# trial1/train.py
import os

# Ensure the Hugging Face cache directories are writable. These must be set
# before importing datasets/transformers, which read them at import time.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"

from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
# Load the Alpaca dataset and keep a 2,000-row subset for a quick training run
dataset = load_dataset("tatsu-lab/alpaca")
dataset["train"] = dataset["train"].select(range(2000))
# Check dataset structure
print("Dataset splits available:", dataset)
print("Sample row:", dataset["train"][0])
# Alpaca ships only a "train" split, so carve out 10% for evaluation
if "test" not in dataset:
    dataset = dataset["train"].train_test_split(test_size=0.1)
# Load tokenizer & model (T5Tokenizer requires the sentencepiece package)
model_name = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.gradient_checkpointing_disable()  # checkpointing is off by default; this just makes it explicit
# Tokenization function. Alpaca rows have "instruction", "input" (often empty),
# and "output" fields, so join instruction and input to form the source text
# rather than using "input" alone.
def tokenize_function(examples):
    inputs = [inst + ("\n" + inp if inp else "")
              for inst, inp in zip(examples["instruction"], examples["input"])]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(examples["output"], max_length=512, truncation=True, padding="max_length")
    # Mask padding in the labels with -100 so it is ignored by the loss
    model_inputs["labels"] = [[(t if t != tokenizer.pad_token_id else -100) for t in seq]
                              for seq in labels["input_ids"]]
    return model_inputs
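# Quick spot check (illustrative, not required for training): tokenize the
# first row and confirm the padded sequence length matches max_length
sample = tokenize_function(dataset["train"][:1])
print("Tokenized input length:", len(sample["input_ids"][0]))  # expect 512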
# Tokenize dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Assign train & eval datasets
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["test"]
print("Dataset successfully split and tokenized.")
# Define training arguments
training_args = TrainingArguments(
    output_dir="/tmp/results",  # use /tmp/ to avoid permission errors
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
    save_steps=500,
    eval_steps=500,
    logging_dir="/tmp/logs",  # avoid writing to restricted directories
    logging_steps=100,
    save_total_limit=2,
    fp16=True,  # note: T5 is prone to NaN losses in fp16; bf16 is safer where supported
)
# Set up Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
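# Note: since everything is padded to max_length above, the default collator
# suffices here; with dynamic padding, transformers' DataCollatorForSeq2Seq
# (passed via the data_collator argument) would be the usual choice.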
# Start fine-tuning
trainer.train()
print("Fine-tuning complete!")

# Save the fine-tuned model and tokenizer to /tmp/, which is writable
save_dir = "/tmp/t5-finetuned"
os.makedirs(save_dir, exist_ok=True)  # ensure the directory exists
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)
print("Model saved successfully!")