fffiloni commited on
Commit
5239a5d
·
verified ·
1 Parent(s): 65ba024

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -62,7 +62,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
62
  except subprocess.CalledProcessError as e:
63
  print(f"An error occurred: {e}")
64
 
65
- def main(image_path, b_lora_trained_folder, instance_prompt):
66
 
67
  if is_shared_ui:
68
  raise gr.Error("This Space only works in duplicated instances")
@@ -86,9 +86,13 @@ def main(image_path, b_lora_trained_folder, instance_prompt):
86
 
87
  shutil.copy(image_path, local_dir)
88
  print(f"source image has been copied in {local_dir} directory")
89
-
90
- max_train_steps = 2000
91
- checkpoint_steps = 1000
 
 
 
 
92
 
93
  train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
94
 
@@ -196,6 +200,7 @@ with gr.Blocks(css=css) as demo:
196
  image = gr.Image(label="Image Reference", sources=["upload"], type="filepath")
197
 
198
  with gr.Column():
 
199
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
200
  instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
201
  train_btn = gr.Button("Train B-LoRa")
 
62
  except subprocess.CalledProcessError as e:
63
  print(f"An error occurred: {e}")
64
 
65
+ def main(image_path, b_lora_trained_folder, instance_prompt, training_type):
66
 
67
  if is_shared_ui:
68
  raise gr.Error("This Space only works in duplicated instances")
 
86
 
87
  shutil.copy(image_path, local_dir)
88
  print(f"source image has been copied in {local_dir} directory")
89
+
90
+ if training_type == "style":
91
+ max_train_steps = 1000
92
+ checkpoint_steps = 500
93
+ elif training_type == "concept" :
94
+ max_train_steps = 2000
95
+ checkpoint_steps = 1000
96
 
97
  train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
98
 
 
200
  image = gr.Image(label="Image Reference", sources=["upload"], type="filepath")
201
 
202
  with gr.Column():
203
+ training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
204
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
205
  instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
206
  train_btn = gr.Button("Train B-LoRa")