
Fine-tuning Zamba2-7B Model: Step-by-Step Guide

Discussion #1 by ssmits

This guide walks you through fine-tuning the Zamba2-7B model. Make sure you have enough GPU memory, since this is a 7B-parameter model. I've tested the 1.2B model, which needed about 20GB; based on that, 2x H100 or 4x RTX 4090 would probably suffice for the 7B model. Adjust the batch size and gradient accumulation steps to fit your available VRAM.
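
To get a feel for why the 7B model needs that much headroom, here is a rough back-of-envelope estimate of the static memory for full fine-tuning. It is a sketch under stated assumptions, not a measurement: it assumes bf16 weights (2 bytes), bf16 gradients (2 bytes) and two AdamW moment buffers kept in the parameters' bf16 dtype (2 x 2 bytes), i.e. about 8 bytes per parameter, and it ignores activations, the CUDA context and allocator overhead. The function name `full_finetune_memory_gb` is hypothetical.

```python
def full_finetune_memory_gb(num_params: float, bytes_per_param: int = 8) -> float:
    """Rough static memory (GB) for full fine-tuning, EXCLUDING activations.

    The default of 8 bytes/param is an assumption:
      2 (bf16 weights) + 2 (bf16 gradients) + 2 * 2 (AdamW exp_avg and exp_avg_sq in bf16)
    """
    return num_params * bytes_per_param / 1e9


# Nominal parameter counts; the real checkpoints differ slightly
for name, n_params in [("Zamba2-1.2B", 1.2e9), ("Zamba2-7B", 7e9)]:
    print(f"{name}: ~{full_finetune_memory_gb(n_params):.0f} GB before activations")
```

This gives roughly 10 GB for 1.2B and 56 GB for 7B. The gap between the ~10 GB estimate and the ~20GB observed for 1.2B is presumably activations and runtime overhead. With fp32 master weights and fp32 optimizer states (the usual mixed-precision setup), budget roughly 16 bytes per parameter instead.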

Tested Environment:

  • Vast.ai cloud instance (template link)
  • CUDA 11.8
  • Python 3.10
  • PyTorch image: pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel (selected for extra stability)
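
To confirm your instance matches the tested environment, a small check like the following prints the Python, PyTorch and CUDA versions; on the setup above you would expect Python 3.10.x, PyTorch 2.3.1 and CUDA 11.8. It is a sketch: the function name `environment_report` is hypothetical, and torch is treated as optional so the script still runs outside the PyTorch image.

```python
import platform


def environment_report() -> dict:
    """Collect Python/PyTorch/CUDA versions to compare with the tested setup."""
    report = {"python": platform.python_version()}
    try:
        import torch  # present in the pytorch/pytorch image; optional here
    except ImportError:
        report.update(torch=None, cuda=None, gpu_count=0)
        return report
    report["torch"] = torch.__version__
    report["cuda"] = torch.version.cuda  # None for CPU-only PyTorch builds
    report["gpu_count"] = torch.cuda.device_count()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```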

1. Set Up the Environment

First, clone the repository and set up the environment:

git clone https://github.com/Zyphra/transformers_zamba2.git
cd transformers_zamba2

# Create and activate virtual environment
python -m venv venv
source venv/bin/activate  # For Windows use: venv\Scripts\activate

# Install dependencies
pip install -e .
pip install accelerate datasets
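
After installation, you can check which versions actually ended up in the virtual environment. This sketch uses only the standard library (`importlib.metadata`); it assumes the fork installs under the upstream distribution name `transformers`, which is worth confirming in the fork's setup files. `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("transformers", "accelerate", "datasets", "torch")):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in the active virtual environment
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver else 'NOT INSTALLED'}")
```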

2. Basic Inference Test

Let's first test if the model loads correctly:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B", 
    device_map="cuda", 
    torch_dtype=torch.bfloat16
)

# Test generation
input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
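
For a decoder-only model like this one, `generate` returns the prompt tokens followed by the newly generated ones, so the decoded string above starts with the question itself. A minimal sketch for keeping only the continuation works on plain Python lists of token ids; the helper name `new_tokens_only` is hypothetical.

```python
def new_tokens_only(prompt_ids, output_ids):
    """Return only the tokens generated after the prompt.

    Decoder-only generate() output = prompt ids + continuation ids,
    so the continuation starts at index len(prompt_ids).
    """
    if list(output_ids[: len(prompt_ids)]) != list(prompt_ids):
        raise ValueError("output does not start with the prompt")
    return list(output_ids[len(prompt_ids):])
```

With the objects from the test above, something like `tokenizer.decode(new_tokens_only(input_ids["input_ids"][0].tolist(), outputs[0].tolist()), skip_special_tokens=True)` should print only the model's answer.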

3. Fine-tuning Setup

Here's the complete fine-tuning script with detailed configurations:

from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)
import torch
from datasets import Dataset

# Maximum sequence length (in tokens) used for truncation
CONTEXT_WINDOW = 1024

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Better for inference

# Initialize model
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Handles multi-GPU/CPU mapping
)
model.config.pad_token_id = tokenizer.pad_token_id

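# Tokenization function: truncate to CONTEXT_WINDOW tokens; padding is left to
# the data collator, which pads each training batch dynamically
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding=False,
        truncation=True,
        max_length=CONTEXT_WINDOW,
        return_tensors=None
    )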

# Prepare training data
train_texts = [
    "What factors contributed to the fall of the Roman Empire?",
    # Add your training examples here
]

dataset = Dataset.from_dict({"text": train_texts})
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names
)

# Training configuration
training_args = TrainingArguments(
    output_dir="./zamba2-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=True,
    gradient_accumulation_steps=16
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Custom trainer: with device_map="auto" the layers are already spread across
# devices, so skip the Trainer's attempt to move the whole model to one device
class CustomTrainer(Trainer):
    def _move_model_to_device(self, model, device):
        pass  # Model already mapped to devices by device_map="auto"

# Initialize trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

# Train and save
trainer.train()
model.save_pretrained("./zamba2-finetuned-final")
tokenizer.save_pretrained("./zamba2-finetuned-final")
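
One side effect of setting `pad_token = eos_token` is worth knowing: with `mlm=False`, `DataCollatorForLanguageModeling` builds the labels as a copy of `input_ids` and replaces every `pad_token_id` with -100, the index the loss ignores. Since pad and EOS share an id here, genuine EOS tokens are masked too, so the model gets no training signal about when to stop. The pure-Python sketch below mimics that label rule on a made-up token sequence; it illustrates the behavior and is not the collator's actual code.

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss in HF causal LMs


def causal_lm_labels(input_ids, pad_token_id):
    """Mimic the mlm=False label rule: copy input_ids, replace pad_token_id with -100."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


# Hypothetical ids: EOS doubles as the pad token, as in the script above
EOS = 2
left_padded = [EOS, EOS, 101, 102, 103, EOS]  # 2 pad tokens, 3 text tokens, 1 real EOS
print(causal_lm_labels(left_padded, pad_token_id=EOS))
# -> [-100, -100, 101, 102, 103, -100]  (the real EOS is masked as well)
```

If the fine-tuned model should learn to stop, one option (not part of this guide) is a pad token distinct from EOS; adding a new token also requires `model.resize_token_embeddings(len(tokenizer))`.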

Important Notes:

  1. Hardware Requirements:

    • Recommended: a GPU with at least 24GB of VRAM for CONTEXT_WINDOW = 1024, and more than 24GB for 2048. Multi-GPU setups have not been tested yet.
    • The script uses bfloat16 precision to reduce memory usage
  2. Training Configuration:

    • Context window: 1024 tokens
    • Learning rate: 2e-5
    • Weight decay: 0.01
    • Gradient accumulation: 16 steps
    • Training epochs: 3
  3. Customization:

    • Add your training examples to the train_texts list
    • Adjust training_args parameters based on your needs
    • Modify max_length in tokenization if needed
  4. Output:

    • The fine-tuned model will be saved in ./zamba2-finetuned-final
    • Checkpoints during training are saved in ./zamba2-finetuned
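
With per_device_train_batch_size=1 and 16 gradient accumulation steps, each optimizer update sees 16 sequences. Counting optimizer steps is a quick way to check whether logging_steps=100 and save_steps=500 will ever trigger for your dataset size. The sketch below approximates that count; the helper name is hypothetical, and the exact rounding of a final partial accumulation window varies across transformers versions. With device_map="auto" the script runs as a single process that splits layers across GPUs, so `data_parallel_procs` stays 1.

```python
import math


def optimizer_steps(num_examples, per_device_batch_size=1, grad_accum_steps=16,
                    epochs=3, data_parallel_procs=1):
    """Approximate (effective_batch_size, total_optimizer_steps) for the settings above."""
    effective_batch = per_device_batch_size * grad_accum_steps * data_parallel_procs
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


batch, steps = optimizer_steps(num_examples=1000)
print(f"effective batch = {batch}, optimizer steps = {steps}")
```

For 1,000 examples that is about 189 optimizer steps, so the loss would be logged only once during training (at step 100) and no intermediate checkpoint would be written; lower both values for small datasets.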

Feel free to ask if you have any questions about the process.
