Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -62,7 +62,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
|
|
62 |
except subprocess.CalledProcessError as e:
|
63 |
print(f"An error occurred: {e}")
|
64 |
|
65 |
-
def main(image_path, b_lora_trained_folder, instance_prompt):
|
66 |
|
67 |
if is_shared_ui:
|
68 |
raise gr.Error("This Space only works in duplicated instances")
|
@@ -86,9 +86,13 @@ def main(image_path, b_lora_trained_folder, instance_prompt):
|
|
86 |
|
87 |
shutil.copy(image_path, local_dir)
|
88 |
print(f"source image has been copied in {local_dir} directory")
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
92 |
|
93 |
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
|
94 |
|
@@ -196,6 +200,7 @@ with gr.Blocks(css=css) as demo:
|
|
196 |
image = gr.Image(label="Image Reference", sources=["upload"], type="filepath")
|
197 |
|
198 |
with gr.Column():
|
|
|
199 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
200 |
instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
|
201 |
train_btn = gr.Button("Train B-LoRa")
|
|
|
62 |
except subprocess.CalledProcessError as e:
|
63 |
print(f"An error occurred: {e}")
|
64 |
|
65 |
+
def main(image_path, b_lora_trained_folder, instance_prompt, training_type):
|
66 |
|
67 |
if is_shared_ui:
|
68 |
raise gr.Error("This Space only works in duplicated instances")
|
|
|
86 |
|
87 |
shutil.copy(image_path, local_dir)
|
88 |
print(f"source image has been copied in {local_dir} directory")
|
89 |
+
|
90 |
+
if training_type == "style":
|
91 |
+
max_train_steps = 1000
|
92 |
+
checkpoint_steps = 500
|
93 |
+
elif training_type == "concept" :
|
94 |
+
max_train_steps = 2000
|
95 |
+
checkpoint_steps = 1000
|
96 |
|
97 |
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
|
98 |
|
|
|
200 |
image = gr.Image(label="Image Reference", sources=["upload"], type="filepath")
|
201 |
|
202 |
with gr.Column():
|
203 |
+
training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
|
204 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
205 |
instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="[v42]")
|
206 |
train_btn = gr.Button("Train B-LoRa")
|