Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ def change_training_setup(training_type):
|
|
30 |
elif training_type == "concept" :
|
31 |
return 2000, 1000
|
32 |
|
33 |
-
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps):
|
34 |
|
35 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
36 |
|
@@ -42,6 +42,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
|
|
42 |
f"--instance_data_dir={instance_data_dir}",
|
43 |
f"--output_dir={b_lora_trained_folder}",
|
44 |
f"--instance_prompt='{instance_prompt}'",
|
|
|
45 |
#f"--validation_prompt=a teddy bear in {instance_prompt} style",
|
46 |
"--num_validation_images=1",
|
47 |
"--validation_epochs=500",
|
@@ -68,7 +69,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
|
|
68 |
except subprocess.CalledProcessError as e:
|
69 |
print(f"An error occurred: {e}")
|
70 |
|
71 |
-
def main(image_path, b_lora_trained_folder, instance_prompt, training_type, training_steps):
|
72 |
|
73 |
if is_shared_ui:
|
74 |
raise gr.Error("This Space only works in duplicated instances")
|
@@ -100,7 +101,7 @@ def main(image_path, b_lora_trained_folder, instance_prompt, training_type, trai
|
|
100 |
|
101 |
max_train_steps = training_steps
|
102 |
|
103 |
-
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
|
104 |
|
105 |
your_username = api.whoami(token=hf_token)["name"]
|
106 |
|
@@ -208,7 +209,9 @@ with gr.Blocks(css=css) as demo:
|
|
208 |
with gr.Column():
|
209 |
training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
|
210 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
211 |
-
|
|
|
|
|
212 |
training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
|
213 |
checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
|
214 |
train_btn = gr.Button("Train B-LoRa")
|
@@ -222,7 +225,7 @@ with gr.Blocks(css=css) as demo:
|
|
222 |
|
223 |
train_btn.click(
|
224 |
fn = main,
|
225 |
-
inputs = [image, b_lora_name, instance_prompt, training_type, training_steps],
|
226 |
outputs = [status]
|
227 |
)
|
228 |
|
|
|
30 |
elif training_type == "concept" :
|
31 |
return 2000, 1000
|
32 |
|
33 |
+
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps):
|
34 |
|
35 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
36 |
|
|
|
42 |
f"--instance_data_dir={instance_data_dir}",
|
43 |
f"--output_dir={b_lora_trained_folder}",
|
44 |
f"--instance_prompt='{instance_prompt}'",
|
45 |
+
f"--class_prompt={class_prompt}",
|
46 |
#f"--validation_prompt=a teddy bear in {instance_prompt} style",
|
47 |
"--num_validation_images=1",
|
48 |
"--validation_epochs=500",
|
|
|
69 |
except subprocess.CalledProcessError as e:
|
70 |
print(f"An error occurred: {e}")
|
71 |
|
72 |
+
def main(image_path, b_lora_trained_folder, instance_prompt, class_prompt, training_type, training_steps):
|
73 |
|
74 |
if is_shared_ui:
|
75 |
raise gr.Error("This Space only works in duplicated instances")
|
|
|
101 |
|
102 |
max_train_steps = training_steps
|
103 |
|
104 |
+
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps)
|
105 |
|
106 |
your_username = api.whoami(token=hf_token)["name"]
|
107 |
|
|
|
209 |
with gr.Column():
|
210 |
training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
|
211 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
212 |
+
with gr.Row():
|
213 |
+
instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="A [v42] <class_prompt>")
|
214 |
+
class_prompt = gr.Textbox(label="Specify class prompt", placeholder="style | person | dog ")
|
215 |
training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
|
216 |
checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
|
217 |
train_btn = gr.Button("Train B-LoRa")
|
|
|
225 |
|
226 |
train_btn.click(
|
227 |
fn = main,
|
228 |
+
inputs = [image, b_lora_name, instance_prompt, class_prompt, training_type, training_steps],
|
229 |
outputs = [status]
|
230 |
)
|
231 |
|