Fine-tune Llama Vision models with TRL 🚀
by lewtun
Hello everyone, it's Lewis here from the TRL team 👋
We've added support for the Llama 3.2 Vision models to TRL's SFTTrainer, so you can fine-tune them in under 80 lines of code like this:
import torch
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
from trl import (
    ModelConfig,
    SFTConfig,
    SFTTrainer,
)
##########################
# Load model and processor
##########################
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
#######################################################
# Create a data collator to encode text and image pairs
#######################################################
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"] for example in examples]
    if isinstance(model, LlavaForConditionalGeneration):
        # LLaVA 1.5 does not support multiple images
        images = [image[0] for image in images]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Ignore the image token index in the loss computation (model-specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
##############
# Load dataset
##############
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
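# Optional sanity check (not part of the original snippet): each example in this dataset is
# expected to carry a "messages" list in the chat format plus a list of "images", which is
# exactly what collate_fn above consumes.
sample_batch = collate_fn([dataset["train"][i] for i in range(2)])
print(sample_batch["input_ids"].shape, sample_batch["labels"].shape)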
###################
# Configure trainer
###################
training_args = SFTConfig(
    output_dir="my-awesome-llama",
    gradient_checkpointing=True,
    gradient_accumulation_steps=8,
    bf16=True,
    remove_unused_columns=False,
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.tokenizer,
)
# Train!
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
    trainer.push_to_hub()
    if trainer.accelerator.is_main_process:
        processor.push_to_hub(training_args.hub_model_id)
You'll need to adjust the batch size for your hardware, and for maximum memory efficiency you should shard the model with DeepSpeed ZeRO-3.
Check out the full script here: https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py
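As a rough sketch of what that adjustment can look like (the values below are placeholders, not recommendations), you can trade per-device batch size for more gradient accumulation, and enable ZeRO-3 sharding through the config you pass to Accelerate:

# Hypothetical settings for a smaller GPU: keep the effective batch size by lowering the
# per-device batch size and raising gradient_accumulation_steps to compensate.
training_args = SFTConfig(
    output_dir="my-awesome-llama",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
    gradient_checkpointing=True,
    bf16=True,
    remove_unused_columns=False,
)
# ZeRO-3 sharding is configured outside SFTConfig, e.g. via a DeepSpeed config file passed to
# `accelerate launch --config_file <your_zero3_config>.yaml`.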
Thank You
Thanks @lewtun! May I know if the TRL trainer freezes the text-decoder part, i.e. only trains the vision encoder?