fffiloni commited on
Commit
c8d8aaa
·
verified ·
1 Parent(s): cacb85c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -68,7 +68,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
68
  except subprocess.CalledProcessError as e:
69
  print(f"An error occurred: {e}")
70
 
71
- def main(image_path, b_lora_trained_folder, instance_prompt, training_type):
72
 
73
  if is_shared_ui:
74
  raise gr.Error("This Space only works in duplicated instances")
@@ -94,11 +94,11 @@ def main(image_path, b_lora_trained_folder, instance_prompt, training_type):
94
  print(f"source image has been copied in {local_dir} directory")
95
 
96
  if training_type == "style":
97
- max_train_steps = 1000
98
  checkpoint_steps = 500
99
  elif training_type == "concept" :
100
- max_train_steps = 2000
101
  checkpoint_steps = 1000
 
 
102
 
103
  train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
104
 
@@ -222,7 +222,7 @@ with gr.Blocks(css=css) as demo:
222
 
223
  train_btn.click(
224
  fn = main,
225
- inputs = [image, b_lora_name, instance_prompt, training_type],
226
  outputs = [status]
227
  )
228
 
 
68
  except subprocess.CalledProcessError as e:
69
  print(f"An error occurred: {e}")
70
 
71
+ def main(image_path, b_lora_trained_folder, instance_prompt, training_type, training_steps):
72
 
73
  if is_shared_ui:
74
  raise gr.Error("This Space only works in duplicated instances")
 
94
  print(f"source image has been copied in {local_dir} directory")
95
 
96
  if training_type == "style":
 
97
  checkpoint_steps = 500
98
  elif training_type == "concept" :
 
99
  checkpoint_steps = 1000
100
+
101
+ max_train_steps = training_steps
102
 
103
  train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
104
 
 
222
 
223
  train_btn.click(
224
  fn = main,
225
+ inputs = [image, b_lora_name, instance_prompt, training_type, training_steps],
226
  outputs = [status]
227
  )
228