File size: 3,362 Bytes
7db1360 e440740 7db1360 472834d d70bc68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
---
language:
- en
library_name: diffusers
license: other
license_name: flux-1-dev-non-commercial-license
license_link: LICENSE.md
base_model:
- black-forest-labs/FLUX.1-dev
- black-forest-labs/FLUX.1-schnell
base_model_relation: merge
pipeline_tag: text-to-image
---
# **FLUX.1-Merged**
This repository provides the merged parameters of [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev)
and [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell).
# **Merge & Upload**
```python
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from huggingface_hub import upload_folder
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
# Merge the FLUX.1-dev and FLUX.1-schnell transformer checkpoints by averaging
# their weights shard by shard, save the result locally, then upload it.

# Initialize the model with empty weights; the real tensors are attached
# further down by `load_model_dict_into_meta`.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Download the model checkpoints (only the `transformer/` subfolder of each repo)
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# Get the paths to the model shards. Shards are paired by position after
# sorting, so this assumes both repos split the transformer identically.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

# Fail early with a readable message instead of an IndexError (or a silently
# empty merge) when the downloads did not produce matching shard lists.
if not dev_shards:
    raise ValueError(f"No .safetensors shards found under {dev_ckpt}/transformer.")
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} dev shard(s) vs {len(schnell_shards)} schnell shard(s)."
    )

# Merge the state dictionaries
merged_state_dict = {}
guidance_state_dict = {}
for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    # Snapshot the keys first: entries are popped while iterating so each
    # tensor is released from the per-shard dicts as soon as it is merged.
    for k in list(state_dict_dev_temp):
        if "guidance" not in k:
            # Plain average of the two checkpoints. A KeyError here means the
            # schnell shard does not contain the same parameter as the dev shard.
            merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2
        else:
            # Guidance weights are taken from dev unchanged; nothing is read
            # from schnell for these keys.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    # Every dev key was popped above, so only the schnell shard can have
    # leftovers, i.e. parameters that have no counterpart in the dev shard.
    if state_dict_schnell_temp:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

# Update the merged state dictionary with the guidance state dictionary
merged_state_dict.update(guidance_state_dict)

# Load the merged state dictionary into the model
load_model_dict_into_meta(model, merged_state_dict)

# Save the merged model in bfloat16
model.to(torch.bfloat16).save_pretrained("transformer")

# Upload the merged model to the Hugging Face Hub
upload_folder(
    repo_id="prithivMLmods/Flux.1-Merged",  # Replace with your Hugging Face username and desired repo name
    folder_path="transformer",
    path_in_repo="transformer",
)
```
# **Inference**
```python
from diffusers import FluxPipeline
import torch
# Load the merged checkpoint in bfloat16, then move the pipeline to the GPU.
pipeline = FluxPipeline.from_pretrained(
    "prithivMLmods/Flux.1-Merged",
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")

# Sampling settings gathered in one mapping so they are easy to tweak.
generation_kwargs = {
    "prompt": "a tiny astronaut hatching from an egg on the moon",
    "guidance_scale": 3.5,
    "num_inference_steps": 4,
    "height": 880,
    "width": 1184,
    "max_sequence_length": 512,
    # Fixed seed so repeated runs start from the same random state.
    "generator": torch.manual_seed(0),
}

# Run the pipeline and write the first returned image to disk.
output = pipeline(**generation_kwargs)
image = output.images[0]
image.save("merged_flux.png")
``` |