--- language: - en library_name: diffusers license: other license_name: flux-1-dev-non-commercial-license license_link: LICENSE.md base_model: - black-forest-labs/FLUX.1-dev - black-forest-labs/FLUX.1-schnell base_model_relation: merge pipeline_tag: text-to-image --- # **FLUX.1-Merged** This repository provides the merged params for [`black-forest-labs/FLUX.1-dev`](https://huggingface.co./black-forest-labs/FLUX.1-dev) and [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co./black-forest-labs/FLUX.1-schnell). # **Merge & Upload** ```python from diffusers import FluxTransformer2DModel from huggingface_hub import snapshot_download from huggingface_hub import upload_folder from accelerate import init_empty_weights from diffusers.models.model_loading_utils import load_model_dict_into_meta import safetensors.torch import glob import torch # Initialize the model with empty weights with init_empty_weights(): config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer") model = FluxTransformer2DModel.from_config(config) # Download the model checkpoints dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*") schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*") # Get the paths to the model shards dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors")) schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors")) # Merge the state dictionaries merged_state_dict = {} guidance_state_dict = {} for i in range(len(dev_shards)): state_dict_dev_temp = safetensors.torch.load_file(dev_shards[i]) state_dict_schnell_temp = safetensors.torch.load_file(schnell_shards[i]) keys = list(state_dict_dev_temp.keys()) for k in keys: if "guidance" not in k: merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2 else: guidance_state_dict[k] = state_dict_dev_temp.pop(k) if len(state_dict_dev_temp) > 0: raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.") if len(state_dict_schnell_temp) > 0: raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.") # Update the merged state dictionary with the guidance state dictionary merged_state_dict.update(guidance_state_dict) # Load the merged state dictionary into the model load_model_dict_into_meta(model, merged_state_dict) # Save the merged model model.to(torch.bfloat16).save_pretrained("transformer") # Upload the merged model to the Hugging Face Hub upload_folder( repo_id="prithivMLmods/Flux.1-Merged", # Replace with your Hugging Face username and desired repo name folder_path="transformer", path_in_repo="transformer", ) ``` # **Inference** ```python from diffusers import FluxPipeline import torch pipeline = FluxPipeline.from_pretrained( "prithivMLmods/Flux.1-Merged", torch_dtype=torch.bfloat16 ).to("cuda") image = pipeline( prompt="a tiny astronaut hatching from an egg on the moon", guidance_scale=3.5, num_inference_steps=4, height=880, width=1184, max_sequence_length=512, generator=torch.manual_seed(0), ).images[0] image.save("merged_flux.png") ```