cconsti commited on
Commit
3d82de7
·
verified ·
1 Parent(s): 16f00d1

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +5 -3
train.py CHANGED
@@ -6,9 +6,7 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
6
  os.environ["HF_HOME"] = "/app/hf_cache"
7
  os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
8
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
9
- save_dir = "/tmp/t5-finetuned" # Use /tmp/, which is writable
10
- os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
11
- trainer.save_model(save_dir) # Save the model
12
  # Load dataset
13
  dataset = load_dataset("tatsu-lab/alpaca")
14
  dataset["train"] = dataset["train"].select(range(2000))
@@ -71,6 +69,10 @@ trainer = Trainer(
71
  eval_dataset=eval_dataset,
72
  )
73
 
 
 
 
 
74
  # Start fine-tuning
75
  trainer.train()
76
 
 
6
  os.environ["HF_HOME"] = "/app/hf_cache"
7
  os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
8
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
9
+
 
 
10
  # Load dataset
11
  dataset = load_dataset("tatsu-lab/alpaca")
12
  dataset["train"] = dataset["train"].select(range(2000))
 
69
  eval_dataset=eval_dataset,
70
  )
71
 
72
+ save_dir = "/tmp/t5-finetuned" # Use /tmp/, which is writable
73
+ os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
74
+ trainer.save_model(save_dir) # Save the model
75
+
76
  # Start fine-tuning
77
  trainer.train()
78