fffiloni commited on
Commit
5e08dc0
·
verified ·
1 Parent(s): 5239a5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -24,6 +24,12 @@ if is_gpu_associated:
24
  else:
25
  which_gpu = "CPU"
26
 
 
 
 
 
 
 
27
  def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps):
28
 
29
  script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
@@ -203,8 +209,16 @@ with gr.Blocks(css=css) as demo:
203
  training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
204
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
205
  instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
 
 
206
  train_btn = gr.Button("Train B-LoRa")
207
  status = gr.Textbox(label="status")
 
 
 
 
 
 
208
 
209
  train_btn.click(
210
  fn = main,
 
24
  else:
25
  which_gpu = "CPU"
26
 
27
+ def change_training_setup(training_type):
28
+ if training_type == "style" :
29
+ return 1000, 500
30
+ elif training_type == "concept" :
31
+ return 2000, 1000
32
+
33
  def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps):
34
 
35
  script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
 
209
  training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
210
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
211
  instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
212
+ training_steps = gr.Number(label="Training steps", value=1000)
213
+ checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
214
  train_btn = gr.Button("Train B-LoRa")
215
  status = gr.Textbox(label="status")
216
+
217
+ training_type.changes(
218
+ fn = change_training_setup,
219
+ inputs = [training_type],
220
+ outputs = [training_steps, checkpoint_step]
221
+ )
222
 
223
  train_btn.click(
224
  fn = main,