"""Gradio front-end that chats with an Omost LLM and renders its canvas with SDXL.

The module loads the SDXL components and the LLM at import time, defines the
chat / post-chat / diffusion callbacks, and builds the Gradio Blocks UI.
"""

import os
import spaces

# NOTE: HF_HOME is assigned before the Hugging Face libraries are imported
# below; keep this ordering when touching the import section.
os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
HF_TOKEN = os.environ.get('hf_token')

import uuid
import time
import torch
import numpy as np
import gradio as gr
import tempfile

gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_temp_dir, exist_ok=True)

from threading import Thread

# Phi3 Hijack
from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel

Phi3PreTrainedModel._supports_sdpa = True

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from lib_omost.pipeline import StableDiffusionXLOmostPipeline
from chat_interface import ChatInterface
from transformers.generation.stopping_criteria import StoppingCriteriaList

import lib_omost.canvas as omost_canvas


# SDXL

sdxl_name = 'SG161222/RealVisXL_V4.0'
# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'

tokenizer = CLIPTokenizer.from_pretrained(
    sdxl_name, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(
    sdxl_name, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
text_encoder_2 = CLIPTextModel.from_pretrained(
    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
vae = AutoencoderKL.from_pretrained(
    sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
unet = UNet2DConditionModel.from_pretrained(
    sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

pipeline = StableDiffusionXLOmostPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    unet=unet,
    scheduler=None,  # We completely give up diffusers sampling system and use A1111's method
)

# LLM

# llm_name = 'lllyasviel/omost-phi-3-mini-128k'
llm_name = 'lllyasviel/omost-llama-3-8b'
# llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b'

llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype="auto",
    token=HF_TOKEN,
    device_map="auto",
    trust_remote_code=True,
)

llm_tokenizer = AutoTokenizer.from_pretrained(
    llm_name,
    token=HF_TOKEN
)


@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert an iterable of image tensors to a list of uint8 numpy arrays.

    Each tensor has its first dimension moved to the end, is rescaled with
    ``x * 127.5 + 127.5`` (so -1 maps to 0 and 1 maps to 255), clipped to
    [0, 255] and cast to ``uint8``.
    Assumes each item is channel-first (C, H, W) -- TODO confirm with callers.
    """
    return [
        (x.movedim(0, -1) * 127.5 + 127.5)
        .detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        for x in imgs
    ]


@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack numpy images into one float tensor scaled by ``x / 127.5 - 1.0``.

    The arrays are stacked along a new leading axis and the last axis is then
    moved to position 1.
    Assumes each image is channel-last (H, W, C) -- TODO confirm with callers.
    """
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h


def resize_without_crop(image, target_width, target_height):
    """Resize a numpy image to exactly ``target_width`` x ``target_height``.

    Uses Lanczos resampling; the image is stretched to the target size rather
    than cropped, so the aspect ratio is not preserved.
    """
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


@spaces.GPU(duration=120)
@torch.inference_mode()
def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str:
    """Stream the LLM reply to ``message`` given the previous ``history``.

    This is a generator: it yields ``(accumulated_text, None)`` every time the
    streamer produces a new piece of text.

    Args:
        message: The new user message.
        history: Iterable of ``(user, assistant)`` pairs; only pairs where
            both entries are non-empty strings are fed back to the model.
        seed: Seed applied to both numpy and torch before generation.
        temperature: Sampling temperature; ``0`` switches sampling off.
        top_p: Nucleus sampling threshold.
        max_new_tokens: Upper bound on generated tokens.
    """
    print('Chat begin:', message)
    time_stamp = time.time()

    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    conversation = [{"role": "system", "content": omost_canvas.system_prompt}]

    for user, assistant in history:
        if isinstance(user, str) and isinstance(assistant, str):
            if len(user) > 0 and len(assistant) > 0:
                conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])

    conversation.append({"role": "user", "content": message})

    input_ids = llm_tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)

    streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    def interactive_stopping_criteria(*args, **kwargs) -> bool:
        # Generation stops as soon as the flag set by `interrupter` is seen.
        if getattr(streamer, 'user_interrupted', False):
            print('User stopped generation:', message)
            return True
        else:
            return False

    stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])

    # NOTE(review): `interrupter` is defined but never handed to anyone in this
    # file (the generator yields None in its place), so the stop flag is
    # currently never set from here.
    def interrupter():
        streamer.user_interrupted = True
        return

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    if temperature == 0:
        generate_kwargs['do_sample'] = False

    # Generation runs in a background thread so this generator can forward
    # text from the streamer while tokens are still being produced.
    Thread(target=llm_model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # print(outputs)
        yield "".join(outputs), None

    print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
    return


@torch.inference_mode()
def post_chat(history):
    """Parse the last assistant reply into a canvas and compute UI updates.

    Args:
        history: Iterable of ``(user, assistant)`` pairs, or ``None``/empty.

    Returns:
        A 3-tuple ``(canvas_outputs, render_update, undo_update)``:
        ``canvas_outputs`` is the processed canvas or ``None`` when the reply
        could not be parsed; ``render_update`` makes its target visible only
        when a canvas was produced; ``undo_update`` makes its target
        interactive only when there is some history.
    """
    canvas_outputs = None

    try:
        if history:
            history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)]
            last_assistant = history[-1][1] if len(history) > 0 else None
            canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
            canvas_outputs = canvas.process()
    except Exception as e:
        # Best effort: an unparsable reply just means there is nothing to render.
        print('Last assistant response is not valid canvas:', e)

    # bool() instead of len(...) > 0 so that history=None does not raise here.
    has_history = bool(history)

    return canvas_outputs, gr.update(visible=canvas_outputs is not None), gr.update(interactive=has_history)


@spaces.GPU
@torch.inference_mode()
def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
                 highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
    """Render images from ``canvas_outputs`` and append them to the chat.

    Runs a base sampling pass and, when ``highres_scale`` exceeds ``1.05``,
    a second pass on the upscaled result. Every image is saved as a PNG in
    ``gradio_temp_dir``.

    Args:
        chatbot: Current chat messages; not mutated (``None`` is treated as empty).
        canvas_outputs: Processed canvas passed to the pipeline's conditioning.
        num_samples: Number of images to generate.
        seed: Seed for the torch generator.
        image_width, image_height: Requested size, rounded down to a multiple of 64.
        highres_scale: Upscale factor for the second pass.
        steps: Sampling steps of the base pass.
        cfg: Guidance scale.
        highres_steps: Sampling steps of the second pass.
        highres_denoise: Strength of the second pass.
        negative_prompt: Negative prompt text.

    Returns:
        A new list: the previous messages followed by one
        ``(None, (image_path, 'image'))`` entry per generated image.
    """
    use_initial_latent = False
    eps = 0.05

    # UI widgets may hand over floats; normalise the integer-valued inputs once.
    num_samples = int(num_samples)
    seed = int(seed)
    steps = int(steps)
    highres_steps = int(highres_steps)

    image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64

    rng = torch.Generator(unet.device).manual_seed(seed)

    positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt)

    if use_initial_latent:
        initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
        initial_latent_blur = 40
        initial_latent = torch.nn.functional.avg_pool2d(
            torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
            kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
        initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
        initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
        initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
    else:
        initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)

    initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)

    latents = pipeline(
        initial_latent=initial_latent,
        strength=1.0,
        num_inference_steps=steps,
        batch_size=num_samples,
        prompt_embeds=positive_cond,
        negative_prompt_embeds=negative_cond,
        pooled_prompt_embeds=positive_pooler,
        negative_pooled_prompt_embeds=negative_pooler,
        generator=rng,
        guidance_scale=float(cfg),
    ).images

    latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
    pixels = vae.decode(latents).sample
    _, _, H, W = pixels.shape
    pixels = pytorch2numpy(pixels)

    if highres_scale > 1.0 + eps:
        # Upscale in pixel space to a multiple of 64, re-encode, then refine.
        pixels = [
            resize_without_crop(
                image=p,
                target_width=int(round(W * highres_scale / 64.0) * 64),
                target_height=int(round(H * highres_scale / 64.0) * 64)
            ) for p in pixels
        ]

        pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
        latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor

        latents = latents.to(device=unet.device, dtype=unet.dtype)

        latents = pipeline(
            initial_latent=latents,
            strength=highres_denoise,
            num_inference_steps=highres_steps,
            batch_size=num_samples,
            prompt_embeds=positive_cond,
            negative_prompt_embeds=negative_cond,
            pooled_prompt_embeds=positive_pooler,
            negative_pooled_prompt_embeds=negative_pooler,
            generator=rng,
            guidance_scale=float(cfg),
        ).images

        latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
        pixels = vae.decode(latents).sample
        pixels = pytorch2numpy(pixels)

    new_messages = []
    for i, pixel_array in enumerate(pixels):
        unique_hex = uuid.uuid4().hex
        image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
        Image.fromarray(pixel_array).save(image_path)
        new_messages.append((None, (image_path, 'image')))

    # Build a fresh list so the caller's chat list is left untouched.
    return list(chatbot or []) + new_messages


# Adjacent literals keep the stylesheet text exactly as before while staying readable.
css = (
    " code {white-space: pre-wrap !important;}"
    " .gradio-container {max-width: none !important;}"
    " .outer_parent {flex: 1;}"
    " .inner_parent {flex: 1;}"
    " footer {display: none !important; visibility: hidden !important;}"
    " .translucent {display: none !important; visibility: hidden !important;}"
    " "
)

from gradio.themes.utils import colors

with gr.Blocks(
    fill_height=True,
    css=css,
    theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray),
) as demo:
    with gr.Row(elem_classes='outer_parent'):
        # Left column: chat controls and generation settings.
        with gr.Column(scale=25):
            with gr.Row():
                clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
                retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
                undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)

            seed = gr.Number(label="Random Seed", value=123456, precision=0)

            with gr.Accordion(open=True, label='Language Model'):
                with gr.Group():
                    with gr.Row():
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.01,
                            value=0.6,
                            label="Temperature")
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.9,
                            label="Top P")
                    max_new_tokens = gr.Slider(
                        minimum=128,
                        maximum=4096,
                        step=1,
                        value=4096,
                        label="Max New Tokens")
            with gr.Accordion(open=True, label='Image Diffusion Model'):
                with gr.Group():
                    with gr.Row():
                        image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
                        image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)

                    with gr.Row():
                        num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
                        steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)

            with gr.Accordion(open=False, label='Advanced'):
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
                highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
                highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
                highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
                n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')

            # Hidden until post_chat reports a valid canvas.
            render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)

            examples = gr.Dataset(
                samples=[
                    ['generate an image of the fierce battle of warriors and the dragon'],
                    ['change the dragon to a dinosaur']
                ],
                components=[gr.Textbox(visible=False)],
                label='Quick Prompts'
            )

            with gr.Row():
                gr.Markdown("Omost: converting LLM's coding capability to image compositing capability.")
            with gr.Row():
                gr.Markdown("Local version (8GB VRAM): https://github.com/lllyasviel/Omost")
            # with gr.Row():
            #     gr.Markdown("Hint: You can [duplicate this space](https://huggingface.co./spaces/lllyasviel/Omost?duplicate=true) to your private account to bypass the waiting queue.")
        # Right column: the chat itself.
        with gr.Column(scale=75, elem_classes='inner_parent'):
            canvas_state = gr.State(None)
            chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
            chatInterface = ChatInterface(
                fn=chat_fn,
                post_fn=post_chat,
                post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button, undo_btn]),
                pre_fn=lambda: gr.update(visible=False),
                pre_fn_kwargs=dict(outputs=[render_button]),
                chatbot=chatbot,
                retry_btn=retry_btn,
                undo_btn=undo_btn,
                clear_btn=clear_btn,
                additional_inputs=[seed, temperature, top_p, max_new_tokens],
                examples=examples,
                show_stop_button=False
            )

    # Render, then copy the updated chat into the interface's state component.
    render_button.click(
        fn=diffusion_fn,
        inputs=[
            chatInterface.chatbot, canvas_state,
            num_samples, seed, image_width, image_height, highres_scale,
            steps, cfg, highres_steps, highres_denoise, n_prompt
        ],
        outputs=[chatInterface.chatbot],
    ).then(
        fn=lambda x: x,
        inputs=[chatInterface.chatbot],
        outputs=[chatInterface.chatbot_state],
    )

if __name__ == "__main__":
    demo.queue().launch()