RLHF with Dagster and Modal

Part 1. Introduction

In today’s blog, we are going to talk about how to train our own RLHF model.

Not only will we train it, but we will also keep it tight and clean with the help of Dagster and Modal. This way, your pipeline will be reusable, cost-effective, and fast, allowing for multiple iterations, and scalability when needed.

As you might guess, I am a big fan of the Dagster and Modal combination for my training pipelines, so the training will happen within these two products!

This is a self-contained post, so you can jump into the code right away. However, if you want to gain an intuitive understanding of what RLHF is and how it differs from simple supervised fine-tuning, check out this blog post: RLHF: Reinforcement Learning from Human Feedback. If you want to do simple supervised fine-tuning with the same stack, check this one: “How to Fine-Tune LLMs in 2024 with Hugging Face”, but with Dagster, Modal and Llama3.

The main code I am going to build on top of comes from this blog post: RLHF in 2024 with DPO & Hugging Face by Philipp Schmid. It is amazing, but hard to reuse repeatedly, so I want to add an infrastructure touch and move it out of notebooks!

Okay, enough talk—let’s jump into it!

Part 2. High-level picture

The final state we want to be in is very simple: 

We should handle all RLHF dataset-related matters in the data assets group, actual model training in the model assets group, and evaluation in the ml_benchmark asset group. Each asset corresponds to a dataset, a trained model, evaluation metrics, etc. For more details on what assets are, refer to Asset definitions.
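For reference, wiring assets into these groups happens in the code location file (here rlhf_training/llm_rlhf.py). A minimal sketch of what that can look like, assuming hypothetical module names (in the actual repo everything may well live in fewer files):

from dagster import Definitions, load_assets_from_modules

from rlhf_training import data_assets, model_assets, benchmark_assets  # hypothetical module layout

defs = Definitions(
    assets=[
        *load_assets_from_modules([data_assets], group_name="data"),
        *load_assets_from_modules([model_assets], group_name="model"),
        *load_assets_from_modules([benchmark_assets], group_name="ml_benchmark"),
    ]
)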

Part 3. Docker & CI

First things first — let’s handle the boring but important stuff. Docker simplifies managing dependencies and environments, eliminating the “it works on my machine” problem and similar issues. 

Let’s solve this once by migrating everything to Docker. Here is my Dockerfile for this training:

FROM huggingface/transformers-pytorch-gpu:4.35.2
WORKDIR /app

COPY requirements.txt requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt
RUN MAX_JOBS=4 pip install flash-attn==2.5.9.post1 --no-build-isolation
RUN git clone https://github.com/philschmid/FastChat.git
RUN pip install -e "./FastChat[model_worker,llm_judge]"
RUN pip install matplotlib==3.7.3 tabulate==0.9.0

ENV DAGSTER_HOME /app/dagster_data
RUN mkdir -p $DAGSTER_HOME
ENV PYTHONPATH /app
RUN ln -s /usr/bin/python3 /usr/bin/python
COPY rlhf_training rlhf_training 

CMD dagster dev -f rlhf_training/llm_rlhf.py -p 3000 -h 0.0.0.0
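If you want to build the image locally instead of waiting for CI (shown next), a plain docker build does the job:

docker build -t rlfh-dagster-modal .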

And of course, because I am a super lazy person, CI is something I always try to have. In this case, there’s no way I would build these images myself, so I use GitHub Actions to do this for me.

name: Publish Docker image

on:
  push:
    branches:
      - main

jobs:
  container:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:

      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Log in to the Container registry
        uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
    
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: ghcr.io/kyryl-opens-ml/rlfh-dagster-modal

      # See explanation: https://github.com/orgs/community/discussions/25678 
      - name: Clean disk
        run: |
          rm -rf /opt/hostedtoolcache

      - name: Build and push Docker image
        uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}

Free and automated, sweet! Now we can work on the interesting stuff. 

If you want to reuse my Docker image, just pull it from the registry:

docker pull ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main
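And to spin up the Dagster UI from that image, something along these lines should work. The exact environment variables depend on your setup: HF_TOKEN for the Hugging Face Hub, Modal credentials for the remote functions, and OPENAI_API_KEY only if you run the MT-Bench judge:

docker run -it --rm -p 3000:3000 \
  -e HF_TOKEN=$HF_TOKEN \
  -e MODAL_TOKEN_ID=$MODAL_TOKEN_ID \
  -e MODAL_TOKEN_SECRET=$MODAL_TOKEN_SECRET \
  -e OPENAI_API_KEY=$OPENAI_API_KEY \
  ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main

The UI should then be available at http://localhost:3000.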

Part 4. Data

It is no overstatement to say that data is the most crucial part of any ML system. If you can establish a data flow from production, where your real users produce prompt -> liked/disliked results, your competitive edge will be huge. However, not many companies manage to do this successfully.

If you want to learn more about building and collecting RLHF data, make sure to check out these resources:

For this particular exercise, we are going to use a dataset from this blog post: RLHF in 2024 with DPO & Hugging Face, which can be found here.

It has a very simple format: (prompt, selected response, rejected response). That’s it.
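To make that concrete, after the preprocessing below a single record boils down to something like this (a made-up example, not taken from the dataset):

{
  "prompt": "What is the capital of France?",
  "chosen": "The capital of France is Paris.",
  "rejected": "France does not have a capital city."
}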

We are going to split it into training and evaluation subsets and structure them as DataConfig and three assets: rlhf_dataset, train_dataset, and eval_dataset. The full code for the data assets looks like this:

from dagster import Config, asset, MetadataValue, AssetExecutionContext
from datasets import load_dataset
from random import randint
from transformers import AutoTokenizer
from typing import Dict
from datasets import Dataset


class DataConfig(Config):
    dataset_name: str = "argilla/ultrafeedback-binarized-preferences-cleaned"
    train_data_path: str = "train_dataset.json"
    eval_data_path: str = "eval_dataset.json"
    eval_size: float = 0.1
    sample_training: int = 10_000
    model_id: str = "cognitivecomputations/dolphin-2.1-mistral-7b"


@asset(compute_kind="python")
def rlhf_dataset(config: DataConfig) -> Dict[str, str]:
    tokenizer = AutoTokenizer.from_pretrained(config.model_id)

    dataset = load_dataset(config.dataset_name, split="train")
    dataset = dataset.shuffle().select(range(config.sample_training))

    def rec_extract_assistant_messages(messages, index=-1):
        """Recursively extract the last assistant messages from the end of the conversation."""
        if messages[index]["role"] == "assistant":
            return [messages[index]]
        else:
            return rec_extract_assistant_messages(messages, index - 1)

    DEFAULT_SYSTEM_MESSAGE = "You are Dolphin, a helpful AI assistant."

    def create_triplets(
        example, tokenizer, default_system_message=DEFAULT_SYSTEM_MESSAGE
    ):
        """Create the triplets (prompt, chosen, rejected)"""
        prompt_messages = example["chosen"][:-1]
        if example["chosen"][0]["role"] != "system":
            prompt_messages.insert(
                0, {"role": "system", "content": default_system_message}
            )

        chosen_messages = rec_extract_assistant_messages(example["chosen"])
        rejected_messages = rec_extract_assistant_messages(example["rejected"])

        return {
            "prompt": tokenizer.apply_chat_template(
                prompt_messages, tokenize=False
            ),
            "chosen": tokenizer.apply_chat_template(
                chosen_messages, tokenize=False
            ),
            "rejected": tokenizer.apply_chat_template(
                rejected_messages, tokenize=False
            ),
        }

    dataset = dataset.map(
        create_triplets,
        remove_columns=dataset.features,
        fn_kwargs={"tokenizer": tokenizer},
    )
    dataset = dataset.train_test_split(test_size=config.eval_size)

    dataset["train"].to_json(config.train_data_path, orient="records")
    dataset["test"].to_json(config.eval_data_path, orient="records")

    return {
        "train_path": config.train_data_path,
        "test_path": config.eval_data_path,
    }


@asset(compute_kind="python")
def train_dataset(
    context: AssetExecutionContext, rlhf_dataset: Dict[str, str]
) -> Dataset:
    dataset = load_dataset(
        "json", data_files=rlhf_dataset["train_path"], split="train"
    )
    context.add_output_metadata(
        {
            "len": MetadataValue.int(len(dataset)),
            "sample": MetadataValue.json(dataset[randint(0, len(dataset))]),
        }
    )
    return dataset


@asset(compute_kind="python")
def eval_dataset(
    context: AssetExecutionContext, rlhf_dataset: Dict[str, str]
) -> Dataset:
    dataset = load_dataset(
        "json", data_files=rlhf_dataset["test_path"], split="train"
    )
    context.add_output_metadata(
        {
            "len": MetadataValue.int(len(dataset)),
            "sample": MetadataValue.json(dataset[randint(0, len(dataset))]),
        }
    )
    return dataset
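
If you prefer to materialize these assets from a script or a test instead of the Dagster UI, a minimal sketch (assuming the asset definitions above are importable) looks like this:

from dagster import materialize

# materialize the whole data group in dependency order
result = materialize([rlhf_dataset, train_dataset, eval_dataset])
assert result.success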

And the Dagster visualization for these assets looks like this:

Part 5. RLHF model training

Now, let’s move on to the model training!

At a high level, our training process looks very simple. The only changes from the classical fine-tuning are:

  • The dataset comes in a different format: remember, we follow the (prompt, chosen, rejected) structure.
  • We use the DPOTrainer instead of the SFTTrainer.

Clarification: technically speaking, RLHF and DPO are two different techniques for preference training (alignment), but throughout this blog post I will use RLHF as an umbrella term for all alignment techniques, such as DPO, KTO, CPT, etc.
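For intuition, the (sigmoid) DPO objective that the trainer below optimizes looks roughly like this, where $\pi_\theta$ is the model being trained, $\pi_{\text{ref}}$ is the frozen reference model, and $(x, y_w, y_l)$ is a (prompt, chosen, rejected) triplet:

$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$

The beta you will see in the training arguments below controls how far the trained model is allowed to drift from the reference model.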

The trained model is saved on Hugging Face: https://huggingface.co/kyryl-opens-ml/doplhin-dpo-1-epoch.

We structure the model code as ModelTrainingConfig and three assets: trained_model, model_card, and vibe_check. The full code for the model assets looks like this:

from dagster import Config, asset, MetadataValue, AssetExecutionContext
from huggingface_hub import hf_hub_download
from datasets import Dataset
import modal


class ModelTrainingConfig(Config):
    pretrained_model_id: str = "cognitivecomputations/dolphin-2.1-mistral-7b"
    peft_model_id: str = "doplhin-dpo-1-epoch"
    num_train_epochs: float = 1


@asset(compute_kind="modal")
def trained_model(
    context: AssetExecutionContext,
    config: ModelTrainingConfig,
    train_dataset: Dataset,
    eval_dataset: Dataset,
) -> str:
    run_training_modal_function = modal.Function.lookup(
        "rlfh-dagster-modal", "run_training_modal"
    )
    hub_model_id = run_training_modal_function.remote(
        pretrained_model_id=config.pretrained_model_id,
        rlhf_model_id=config.peft_model_id,
        train_dataset_pandas=train_dataset.to_pandas(),
        eval_dataset_pands=eval_dataset.to_pandas(),
        num_train_epochs=config.num_train_epochs,
    )
    context.add_output_metadata(
        {
            "model_url": MetadataValue.url(
                f"https://huggingface.co/{hub_model_id}"
            )
        }
    )
    return hub_model_id


@asset(compute_kind="python")
def model_card(context: AssetExecutionContext, trained_model: str) -> str:
    model_card_path = hf_hub_download(
        repo_id=trained_model, filename="README.md"
    )
    with open(model_card_path, "r") as f:
        content = f.read()
    context.add_output_metadata({"content": MetadataValue.md(content)})
    return content


@asset(compute_kind="modal")
def vibe_check(context: AssetExecutionContext, trained_model: str):
    prompts = [
        "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?",
        "It's Bengay for muscle relief, a combination of methyl salicylate, menthol, and what other active ingredient commonly found in aspirin?",
        "How can i get rid of llamas in my backyard?",
    ]

    run_sample_inference_modal_function = modal.Function.lookup(
        "rlfh-dagster-modal", "run_sample_inference_modal"
    )

    inference_samples = run_sample_inference_modal_function.remote(
        prompts=prompts, hub_model_id=trained_model
    )
    context.add_output_metadata(
        {
            "samples": MetadataValue.json(inference_samples),
        }
    )
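
One nice property of Dagster's Pythonic config is that none of these defaults are baked into the run: you can override ModelTrainingConfig when launching a materialization, either from the Launchpad in the UI or from Python. A hedged sketch of the Python route, assuming the config class and asset definitions above are importable:

from dagster import materialize, RunConfig

# one-off run with three epochs instead of one
result = materialize(
    [rlhf_dataset, train_dataset, eval_dataset, trained_model],
    run_config=RunConfig(
        ops={"trained_model": ModelTrainingConfig(num_train_epochs=3)}
    ),
)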

And the Dagster visualization for these assets looks like this:

We can also perform a model vibe_check by running a few ad-hoc prompts against the trained model and displaying the results in the Dagster UI.

Have you noticed something strange? Where is the actual training function, and why do we call it via .remote? And why do trained_model and vibe_check have modal as their compute_kind instead of python?

Well, this is because we are using Modal to run the GPU-heavy operations in our Dagster pipeline. Dagster itself can stay on cheap CPU hardware, and any asset that does need a GPU simply dispatches its work to Modal as a remote function call, much like calling an HTTP endpoint.

To enable this, you need to define Modal functions:

import modal
from modal import Image
import pandas as pd
import os
from typing import List, Dict

app = modal.App("rlfh-dagster-modal")
env = {"HF_TOKEN": os.getenv("HF_TOKEN")}
custom_image = Image.from_registry("ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main").env(env)
mount = modal.Mount.from_local_python_packages("rlhf_training", "rlhf_training")
timeout = 6 * 60 * 60


@app.function(image=custom_image, gpu="A100", mounts=[mount], timeout=timeout)
def run_training_modal(
    pretrained_model_id: str,
    rlhf_model_id: str,
    train_dataset_pandas: pd.DataFrame,
    eval_dataset_pands: pd.DataFrame,
    num_train_epochs: float,
):
    from datasets import Dataset
    from rlhf_training.utils import run_training

    model_url = run_training(
        pretrained_model_id=pretrained_model_id,
        rlhf_model_id=rlhf_model_id,
        train_dataset=Dataset.from_pandas(train_dataset_pandas),
        eval_dataset=Dataset.from_pandas(eval_dataset_pands),
        num_train_epochs=num_train_epochs,
    )
    return model_url


@app.function(image=custom_image, gpu="A100", mounts=[mount], timeout=timeout)
def run_sample_inference_modal(
    prompts: List[str], hub_model_id: str
) -> List[Dict[str, str]]:
    from rlhf_training.utils import run_sample_inference

    inference_samples = run_sample_inference(
        prompts=prompts, hub_model_id=hub_model_id
    )
    return inference_samples

And to deploy them: 

modal deploy ./rlhf_training/smodal_functions.py
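
Before wiring the deployed functions into Dagster, it does not hurt to smoke-test them from a Python shell. Something like this, using the already-published checkpoint as a stand-in, should return a couple of generations:

import modal

fn = modal.Function.lookup("rlfh-dagster-modal", "run_sample_inference_modal")
samples = fn.remote(
    prompts=["What is the capital of France?"],
    hub_model_id="kyryl-opens-ml/doplhin-dpo-1-epoch",
)
print(samples)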

After that, in the Modal UI, you should be able to see your application and the functions assigned to it:

The actual training and inference code looks like this:

from datasets import Dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
import json
from trl import DPOTrainer
import base64
from typing import List, Dict


def run_training(
    pretrained_model_id: str,
    rlhf_model_id: str,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    num_train_epochs: int = 1,
) -> str:
    # BitsAndBytesConfig int-4 config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_id,
        device_map="auto",
        use_cache=False,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # to prevent errors with FA
    tokenizer.truncation_side = (
        "left"  # to prevent cutting off last generation
    )

    prompt_length = 1024
    max_seq_length = 1512

    # LoRA config based on QLoRA paper & Sebastian Raschka experiment
    peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )

    args = TrainingArguments(
        output_dir=rlhf_model_id,  # directory to save and repository id
        num_train_epochs=num_train_epochs,  # number of training epochs
        per_device_train_batch_size=2,  # batch size per device during training
        per_device_eval_batch_size=2,  # batch size for evaluation
        gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
        gradient_checkpointing=True,  # use gradient checkpointing to save memory
        optim="adamw_torch_fused",  # use fused adamw optimizer
        learning_rate=5e-5,  # 10x higher LR than QLoRA paper
        max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
        warmup_ratio=0.1,  # warmup ratio based on QLoRA paper
        lr_scheduler_type="cosine",  # use cosine learning rate scheduler
        logging_steps=25,  # log every 25 steps
        save_steps=500,  # when to save checkpoint
        save_total_limit=2,  # limit the total amount of checkpoints
        evaluation_strategy="steps",  # run evaluation according to eval_steps
        eval_steps=700,  # when to evaluate
        bf16=True,  # use bfloat16 precision
        tf32=True,  # use tf32 precision
        push_to_hub=True,  # push model to hub
        report_to="tensorboard",  # report metrics to tensorboard
    )

    dpo_args = {
        "beta": 0.1,  # The beta factor in DPO loss. Higher beta means less divergence
        "loss_type": "sigmoid",  # The loss type for DPO.
    }

    trainer = DPOTrainer(
        model,
        ref_model=None,  # set to none since we use peft
        peft_config=peft_config,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=max_seq_length,
        max_prompt_length=prompt_length,
        beta=dpo_args["beta"],
        loss_type=dpo_args["loss_type"],
    )

    trainer.model.print_trainable_parameters()
    # start training, the model will be automatically saved to the hub and the output directory
    train_result = trainer.train()
    train_metrics = train_result.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_model()

    kwargs = {
        "finetuned_from": pretrained_model_id,
        "language": "en",
    }
    trainer.create_model_card(**kwargs)

    hub_model_id = trainer.hub_model_id
    del trainer
    del model
    torch.cuda.empty_cache()

    return hub_model_id


def run_sample_inference(
    prompts: List[str], hub_model_id: str
) -> List[Dict[str, str]]:
    model = AutoPeftModelForCausalLM.from_pretrained(
        hub_model_id, device_map="auto", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(hub_model_id)

    # load into pipeline
    merged_model = model.merge_and_unload()
    pipe = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)

    inference_samples = []
    for prompt in prompts:
        outputs = pipe(
            prompt,
            max_new_tokens=2048,
            do_sample=True,
            temperature=1.0,
            top_k=50,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        inference_samples.append(
            {
                "prompt": prompt,
                "generated-answer": outputs[0]["generated_text"][
                    len(prompt) :
                ].strip(),
            }
        )
    return inference_samples

Given that we want to stay within budget, we used QLoRA to train our RLHF model and limited training to a single GPU. Most hyperparameters and the training code were adapted from this blog post: RLHF in 2024 with DPO & Hugging Face.

Part 6. MT-Bench evaluation

The last optional part is to evaluate our RLHF-trained model on MT-Bench.

Why is this part optional? Because MT-Bench is a generic benchmark, and unless you are developing a general-purpose LLM, you should not use it. Instead, make sure to invest time and effort into constructing your own evaluation process. However, understanding how open-source evaluation benchmarks work is useful.

Let’s start with the questions:

MT-Bench has a set of predefined questions for your LLM to answer:
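Each question is a JSON line following FastChat's question.jsonl schema: an id, a category, and (usually two) conversation turns. The content below is made up, but the fields are the real ones:

{"question_id": 101, "category": "writing", "turns": ["Write a short story about a robot learning to cook.", "Now rewrite the story as a haiku."]}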

So, you ask your newly trained RLHF model and the model you used as the starting point to answer those questions. I called my models my-sft (the model I started with) and my-rlhf (the newly RLHF-trained model). The answers should be in the following format:

And the same for my-rlhf:

After that, you ask GPT-4 to judge which answer is better (the so-called LLM-as-a-judge pattern). You can visualize the results like this, showing when GPT-4 prefers the my-sft answer and when it prefers the my-rlhf one:

The full code for the mt_benchmark assets looks like this:

from dagster import Config, asset, MetadataValue, AssetExecutionContext
import subprocess
from rlhf_training.utils import read_jsonl, encode_image


class MTBenchConfig(Config):
    mt_bench_questions_path: str = "data/mt_bench/question.jsonl"
    original_responses_path: str = "data/mt_bench/model_answer/my-sft.jsonl"
    rlhf_responses_path: str = "data/mt_bench/model_answer/my-rlhf.jsonl"

    sft_model_id: str = "my-sft"
    rlhf_model_id: str = "my-rlhf"


@asset(compute_kind="python")
def mt_bench_questions(context: AssetExecutionContext, config: MTBenchConfig):
    _mt_bench_questions = read_jsonl(config.mt_bench_questions_path)

    context.add_output_metadata(
        {
            "mt_bench_questions": MetadataValue.json(_mt_bench_questions),
        }
    )
    return config.mt_bench_questions_path


@asset(compute_kind="python")
def original_responses(
    context: AssetExecutionContext, config: MTBenchConfig, mt_bench_questions
):
    cmd = f"python FastChat/fastchat/llm_judge/gen_model_answer.py  --model-id {config.sft_model_id} --model-path cognitivecomputations/dolphin-2.1-mistral-7b"
    result = subprocess.run(
        cmd.split(), check=True, capture_output=True, text=True
    )
    _original_responses = read_jsonl(config.original_responses_path)

    context.add_output_metadata(
        {
            "original_responses": MetadataValue.json(_original_responses),
            "cli_output": MetadataValue.text(result.stdout),
        }
    )
    return config.original_responses_path


@asset(compute_kind="python")
def rlhf_responses(
    context: AssetExecutionContext,
    config: MTBenchConfig,
    mt_bench_questions,
    trained_model: str,
):
    cmd = f"python FastChat/fastchat/llm_judge/gen_model_answer.py --model-id {config.rlhf_model_id}  --model-path {trained_model}"
    result = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)
    _rlhf_responses = read_jsonl(config.rlhf_responses_path)

    context.add_output_metadata(
        {
            "rlhf_responses": MetadataValue.json(_rlhf_responses),
            "cli_output": MetadataValue.text(result.stdout),
        }
    )
    return config.rlhf_responses_path


@asset(compute_kind="python")
def judgment_results(
    context: AssetExecutionContext,
    config: MTBenchConfig,
    original_responses,
    rlhf_responses,
):
    cmd = f"python FastChat/fastchat/llm_judge/gen_judgment.py --model-list {config.sft_model_id} {config.rlhf_model_id} --judge-model gpt-4-1106-preview --mode pairwise-all"
    result_gen_judgment = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)

    cmd = f"python FastChat/fastchat/llm_judge/show_result.py --input-file ./data/mt_bench/model_judgment/gpt-4-1106-preview_pair.jsonl --model-list {config.sft_model_id} {config.rlhf_model_id} --judge-model gpt-4-1106-preview --mode pairwise-all"
    result_show_result = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)

    image_path = "win_rate_gpt-4-1106-preview.png"
    image_data = encode_image(image_path)
    md_content = f"![img](data:image/png;base64,{image_data})"

    context.add_output_metadata(
        {
            "plot": MetadataValue.md(md_content),
        }
    )

And the Dagster visualization:

Part 7. Conclusion

This wasn’t easy, was it?

But from now on, building your own RLHF models is not such a mystery and can be done in a reusable fashion!

Just to refresh: we started with the data, built a training pipeline with the help of Modal, and incorporated the MT-Bench evaluation process into a single set of Dagster assets, which finally looks like this:

Now, feel free to break it and build on top of it! Add your own data, change the RLHF training method or model, and modify the evaluation process. The code, model checkpoint, and Docker image are all freely available!

Follow up

If you’ve reached this point in my blog, wow, that’s great! If you liked what you saw here and want to learn how to build ML systems in production from the ground up, check out my webpage, where I cover most of these topics in great detail!
