RuntimeWarning: invalid value encountered in cast images = (images * 255).round().astype("uint8")
#11
opened by balalala
I changed torch.bfloat16 in the sample code to torch.float16 and then ran the code.
The output image is pure black, and I get this warning:
diffusers/lib/python3.10/site-packages/diffusers/image_processor.py:112: RuntimeWarning: invalid value encountered in cast images = (images * 255).round().astype("uint8")
Is there a solution for running with torch.float16?
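For reference, the warning itself just means NumPy is being asked to cast non-finite values (NaN/Inf) to uint8, which is consistent with a pure-black output if the decoded image array contains NaNs. A minimal reproduction of the warning (my own illustration, not taken from the pipeline):

import numpy as np

# NaNs in the image array trigger the same RuntimeWarning that
# image_processor.py raises during the uint8 cast.
images = np.full((1, 64, 64, 3), np.nan, dtype=np.float32)
out = (images * 255).round().astype("uint8")  # RuntimeWarning: invalid value encountered in cast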
My code is as follows:
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-schnell.png")
I made a PR to fix this issue: https://github.com/huggingface/diffusers/pull/9097
If it isn't merged yet, you can install my diffusers fork:
!pip install -U git+https://github.com/latentCall145/diffusers.git@flux-fp16-fix

Edit: it's now merged into the main diffusers repo:
!pip install -U git+https://github.com/huggingface/diffusers.git
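For context on why fp16 fails here: most likely some intermediate activations exceed fp16's maximum (~65504) and turn into inf/NaN, which then propagates all the way to the decoded image. A common mitigation for this kind of overflow is to clamp activations back into fp16 range; the sketch below is only an illustration of that technique, not necessarily the exact code in the PR:

import torch

def clamp_fp16_activations(hidden_states: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not the PR's exact code): keep fp16 activations
    # finite by clamping to just below the dtype's maximum value.
    if hidden_states.dtype == torch.float16:
        max_val = torch.finfo(torch.float16).max - 1000
        hidden_states = hidden_states.clamp(min=-max_val, max=max_val)
    return hidden_states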
Diffusers Code
from diffusers import FluxPipeline
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"
prompt = [
    "an astronaut riding a horse on mars",
    # more prompts here
]
height, width = 1024, 1024

# denoising
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,  # setting this to torch.float16 is much slower than casting later
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.to(torch.half)  # now we cast (enable_sequential_cpu_offload allows weights to be cast to fp16 as needed instead of all weights at once, saving ~30 GB CPU RAM for this model)

image = pipe(
    prompt,
    num_inference_steps=1,
    guidance_scale=0.0,
    height=height,
    width=width,
).images[0]

import matplotlib.pyplot as plt
plt.imshow(image)
plt.show()
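A quick sanity check (my own addition, with a hypothetical filename) that the fp16 run actually produced a non-black image rather than NaNs:

import numpy as np

arr = np.asarray(image)
print(arr.min(), arr.max())  # a NaN/black result would show max == 0
image.save("flux-schnell-fp16.png")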