Spaces:
Running
Running
import logging
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two models by running ``hf_merge.py`` and report the outcome.

    Runs ``hf_merge.py`` in a child process using the current interpreter.
    The DARE parameters and the upload options are passed on its command line.

    Args:
        base_model: Identifier of the base model (first positional argument).
        model_to_merge: Identifier of the model merged into the base
            (second positional argument).
        scaling_factor: Forwarded to ``hf_merge.py`` as ``-lambda``.
        weight_drop_prob: Forwarded to ``hf_merge.py`` as ``-p``.
        repo_name: Target repository id. It is forwarded as ``--repo`` and
            used to build the returned URL.
        token: Hugging Face write token, forwarded as ``--token``.
        commit_message: Forwarded as ``--commit-message``. Falls back to
            "Upload merged model" when blank.

    Returns:
        A 3-tuple ``(repo_url, status_message, log_output)``. ``repo_url`` is
        ``None`` when a required field is missing or the script exits with a
        non-zero status. ``log_output`` holds the script's stdout and stderr,
        or is empty when validation failed before the script ran.
    """
    # Normalise text inputs; Gradio textboxes may deliver None or stray spaces.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    token = (token or "").strip()
    commit_message = (commit_message or "").strip() or "Upload merged model"

    # Fail fast instead of launching a merge that cannot succeed.
    required = {
        "Base Model": base_model,
        "Merge Model": model_to_merge,
        "New Model": repo_name,
        "HF write token": token,
    }
    missing = [label for label, value in required.items() if not value]
    if missing:
        return None, f"Missing required field(s): {', '.join(missing)}", ""

    # sys.executable guarantees the child uses this interpreter/virtualenv,
    # which "python3" on PATH does not.
    # NOTE(review): the token is visible in the process list while the child
    # runs; consider passing it via the environment if hf_merge.py supports it.
    command = [
        sys.executable, "hf_merge.py",
        base_model,
        model_to_merge,
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "--token", token,
        "--repo", repo_name,
        "--commit-message", commit_message,
        "-U",
    ]

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # List form with the default shell=False: user input is never shell-parsed.
    result = subprocess.run(command, capture_output=True, text=True)
    log_output = result.stdout + "\n" + result.stderr + "\n"

    if result.stdout:
        logger.info(result.stdout)
    if result.returncode != 0:
        logger.error(result.stderr)
        return None, f"Error in merging models: {result.stderr}", log_output
    if result.stderr:
        # A successful run may still write to stderr; don't report it as an error.
        logger.info(result.stderr)

    # Assumes hf_merge.py uploaded to --repo (requested via -U) on success.
    repo_url = f"https://huggingface.co/{repo_name}"
    return repo_url, "Model merged and uploaded successfully!", log_output
# Define the Gradio interface.
# merge_and_upload returns (repo_url, status_message, log_output); the click
# handler below wires one output component to each of those three values.
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# SuperMario Safetensors Merger")
    gr.Markdown("Combine any two models using a Super Mario merge (DARE)")
    gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
    # One Markdown component so the bullets render as a single list.
    gr.Markdown(
        "Works with:\n\n"
        "* Stable Diffusion (1.5, XL/XL Turbo)\n"
        "* LLMs (Mistral, Llama, etc) (also works with Llava, Vision models)\n"
        "* LoRAs (must be same size)\n"
        "* Any two homologous models"
    )
    with gr.Column():
        with gr.Row():
            # Masked so the write token is not displayed on screen.
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1, type="password")
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder="meta-llama/Llama-3.2-11B-Vision-Instruct", info="Safetensors format")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder="Qwen/Qwen2.5-Coder-7B-Instruct", info="Safetensors or .bin")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="your-username/Llama-Qwen-Vision-Instruct", info="your-username/new-model-name", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        with gr.Row():
            commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)
        repo_url = gr.Markdown(label="Repository URL")
        output = gr.Textbox(label="Output")
        # Receives the third return value (captured stdout/stderr of hf_merge.py).
        log_box = gr.Textbox(label="Log", lines=10, max_lines=20)
        gr.Button("Merge").click(
            merge_and_upload,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
            outputs=[repo_url, output, log_box],
        )

demo.launch()