fffiloni commited on
Commit
e6047f1
·
verified ·
1 Parent(s): c8d8aaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -30,7 +30,7 @@ def change_training_setup(training_type):
30
  elif training_type == "concept" :
31
  return 2000, 1000
32
 
33
- def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps):
34
 
35
  script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
36
 
@@ -42,6 +42,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
42
  f"--instance_data_dir={instance_data_dir}",
43
  f"--output_dir={b_lora_trained_folder}",
44
  f"--instance_prompt='{instance_prompt}'",
 
45
  #f"--validation_prompt=a teddy bear in {instance_prompt} style",
46
  "--num_validation_images=1",
47
  "--validation_epochs=500",
@@ -68,7 +69,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
68
  except subprocess.CalledProcessError as e:
69
  print(f"An error occurred: {e}")
70
 
71
- def main(image_path, b_lora_trained_folder, instance_prompt, training_type, training_steps):
72
 
73
  if is_shared_ui:
74
  raise gr.Error("This Space only works in duplicated instances")
@@ -100,7 +101,7 @@ def main(image_path, b_lora_trained_folder, instance_prompt, training_type, trai
100
 
101
  max_train_steps = training_steps
102
 
103
- train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
104
 
105
  your_username = api.whoami(token=hf_token)["name"]
106
 
@@ -208,7 +209,9 @@ with gr.Blocks(css=css) as demo:
208
  with gr.Column():
209
  training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
210
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
211
- instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="A [v42]")
 
 
212
  training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
213
  checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
214
  train_btn = gr.Button("Train B-LoRa")
@@ -222,7 +225,7 @@ with gr.Blocks(css=css) as demo:
222
 
223
  train_btn.click(
224
  fn = main,
225
- inputs = [image, b_lora_name, instance_prompt, training_type, training_steps],
226
  outputs = [status]
227
  )
228
 
 
30
  elif training_type == "concept" :
31
  return 2000, 1000
32
 
33
+ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps):
34
 
35
  script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
36
 
 
42
  f"--instance_data_dir={instance_data_dir}",
43
  f"--output_dir={b_lora_trained_folder}",
44
  f"--instance_prompt='{instance_prompt}'",
45
+ f"--class_prompt={class_prompt}",
46
  #f"--validation_prompt=a teddy bear in {instance_prompt} style",
47
  "--num_validation_images=1",
48
  "--validation_epochs=500",
 
69
  except subprocess.CalledProcessError as e:
70
  print(f"An error occurred: {e}")
71
 
72
+ def main(image_path, b_lora_trained_folder, instance_prompt, class_prompt, training_type, training_steps):
73
 
74
  if is_shared_ui:
75
  raise gr.Error("This Space only works in duplicated instances")
 
101
 
102
  max_train_steps = training_steps
103
 
104
+ train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps)
105
 
106
  your_username = api.whoami(token=hf_token)["name"]
107
 
 
209
  with gr.Column():
210
  training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
211
  b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
212
+ with gr.Row():
213
+ instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="A [v42] <class_prompt>")
214
+ class_prompt = gr.Textbox(label="Specify class prompt", placeholder="style | person | dog ")
215
  training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
216
  checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
217
  train_btn = gr.Button("Train B-LoRa")
 
225
 
226
  train_btn.click(
227
  fn = main,
228
+ inputs = [image, b_lora_name, instance_prompt, class_prompt, training_type, training_steps],
229
  outputs = [status]
230
  )
231