Fine-tune Llama Vision models with TRL 🚀

#31 · opened by lewtun (HF staff)

Hello everyone, it's Lewis here from the TRL team 👋

We've added support for the Llama 3.2 Vision models to TRL's SFTTrainer, so you can fine-tune them in under 80 lines of code like this:

import torch
from datasets import load_dataset

from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration

from trl import SFTConfig, SFTTrainer

##########################
# Load model and processor
##########################
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)

#######################################################
# Create a data collator to encode text and image pairs
#######################################################
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"] for example in examples]
    if isinstance(model, LlavaForConditionalGeneration):
        # LLaVA 1.5 does not support multiple images per example, so keep only the first one
        images = [image[0] for image in images]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

##############
# Load dataset
##############
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
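# Note: each example in this dataset provides an "images" list plus a "messages" list in the
# multimodal chat format. Roughly (illustrative only -- exact field contents depend on the dataset):
# {
#     "images": [<PIL.Image.Image>],
#     "messages": [
#         {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown?"}]},
#         {"role": "assistant", "content": [{"type": "text", "text": "A description of the image."}]},
#     ],
# }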

###################
# Configure trainer
###################
training_args = SFTConfig(
    output_dir="my-awesome-llama",
    gradient_checkpointing=True,   # trade extra compute for lower memory usage
    gradient_accumulation_steps=8,
    bf16=True,
    remove_unused_columns=False,   # keep the raw "messages"/"images" columns for collate_fn
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.tokenizer,
)

# Train!
trainer.train()

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
    trainer.push_to_hub()
    if trainer.accelerator.is_main_process:
        processor.push_to_hub(training_args.hub_model_id)

You'll need to adjust the batch size for your hardware, and for maximum efficiency you'll want to shard the model with DeepSpeed ZeRO-3.

Check out the full script here: https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py
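For reference, a multi-GPU launch of that script with Accelerate and a DeepSpeed ZeRO-3 config looks roughly like this (the config path follows the TRL repo layout; the flags and batch size are illustrative, so check the script header for the exact command matching your TRL version):

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm.py \
    --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --bf16 \
    --output_dir my-awesome-llama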

Meta Llama org

This is awesome! Thanks @lewtun ! 🙏

Would you be open to also adding an example in llama-recipes here? I'm sure the community will find it useful!

@lewtun, thanks for your efforts! For anyone who only follows the code above, it's worth adding dataset_kwargs={"skip_prepare_dataset": True} to the SFTConfig. The full script does include it, but it's easy to miss, as I did.
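For reference, that would look something like the sketch below (same SFTConfig as above; skip_prepare_dataset tells the trainer not to pre-tokenize the dataset itself, so the custom collate_fn handles text and images):

training_args = SFTConfig(
    output_dir="my-awesome-llama",
    gradient_checkpointing=True,
    gradient_accumulation_steps=8,
    bf16=True,
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},  # leave all preprocessing to collate_fn
)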

Meta Llama org

Thanks @lewtun! Could you clarify whether the TRL trainer freezes the text-decoder part, i.e. only trains the vision encoder?
