Update train.py
Browse files
train.py
CHANGED
@@ -6,9 +6,7 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
-
|
10 |
-
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
-
trainer.save_model(save_dir) # Save the model
|
12 |
# Load dataset
|
13 |
dataset = load_dataset("tatsu-lab/alpaca")
|
14 |
dataset["train"] = dataset["train"].select(range(2000))
|
@@ -71,6 +69,10 @@ trainer = Trainer(
|
|
71 |
eval_dataset=eval_dataset,
|
72 |
)
|
73 |
|
|
|
|
|
|
|
|
|
74 |
# Start fine-tuning
|
75 |
trainer.train()
|
76 |
|
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
+
|
|
|
|
|
10 |
# Load dataset
|
11 |
dataset = load_dataset("tatsu-lab/alpaca")
|
12 |
dataset["train"] = dataset["train"].select(range(2000))
|
|
|
69 |
eval_dataset=eval_dataset,
|
70 |
)
|
71 |
|
72 |
+
save_dir = "/tmp/t5-finetuned" # Use /tmp/, which is writable
|
73 |
+
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
74 |
+
trainer.save_model(save_dir) # Save the model
|
75 |
+
|
76 |
# Start fine-tuning
|
77 |
trainer.train()
|
78 |
|