Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -30,6 +30,12 @@ def change_training_setup(training_type):
|
|
30 |
elif training_type == "concept" :
|
31 |
return 2000, 1000
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps):
|
34 |
|
35 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
@@ -104,6 +110,8 @@ def main(image_path, b_lora_trained_folder, instance_prompt, class_prompt, train
|
|
104 |
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps)
|
105 |
|
106 |
your_username = api.whoami(token=hf_token)["name"]
|
|
|
|
|
107 |
|
108 |
return f"Done, your trained model has been stored in your models library: {your_username}/{b_lora_trained_folder}"
|
109 |
|
|
|
30 |
elif training_type == "concept" :
|
31 |
return 2000, 1000
|
32 |
|
33 |
+
def swap_hardware(hf_token, hardware="cpu-basic"):
|
34 |
+
hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
|
35 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
36 |
+
body = {'flavor': hardware}
|
37 |
+
requests.post(hardware_url, json = body, headers=headers)
|
38 |
+
|
39 |
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps):
|
40 |
|
41 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
|
|
110 |
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps)
|
111 |
|
112 |
your_username = api.whoami(token=hf_token)["name"]
|
113 |
+
|
114 |
+
swap_hardware(hf_token, hardware="cpu-basic")
|
115 |
|
116 |
return f"Done, your trained model has been stored in your models library: {your_username}/{b_lora_trained_folder}"
|
117 |
|