"""Gradio app exposing two GPU-backed image utilities.

* ``rmbg`` -- background removal with BiRefNet.
* ``outpainting`` -- canvas extension with FLUX.1 Fill.
"""

import gradio as gr
import spaces
import torch
from diffusers import FluxFillPipeline
from loadimg import load_img
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Allow faster, reduced-internal-precision kernels for float32 matmuls
# (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")

# Background-removal model; its modelling code is loaded from the checkpoint
# repo (trust_remote_code=True).
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# BiRefNet preprocessing: fixed 1024x1024 input, ImageNet mean/std normalisation.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Outpainting model, run in bfloat16.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")


def _as_padding(value):
    """Coerce a padding value coming from the UI/API into a non-negative int.

    ``gr.Number`` yields floats unless ``precision=0`` and may yield ``None``
    for an empty field; Pillow needs integer pixel counts, and a negative
    border makes no sense for outpainting.
    """
    if value is None:
        return 0
    return max(0, int(round(value)))


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad *image* with white borders and build the matching inpainting mask.

    Args:
        image: an input accepted by ``loadimg.load_img`` (the UI passes the
            uploaded image).
        padding_top, padding_bottom, padding_left, padding_right: pixels to
            add on each side; coerced to non-negative ints.

    Returns:
        ``(background, mask)`` -- two RGB PIL images of the padded size. In
        the mask the original image area is black and the added border is
        white (the area the fill pipeline is asked to generate).
    """
    image = load_img(image).convert("RGB")
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = (
        _as_padding(padding_left),
        _as_padding(padding_top),
        _as_padding(padding_right),
        _as_padding(padding_bottom),
    )
    background = ImageOps.expand(image, border=border, fill="white")
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"), border=border, fill="white"
    )
    return background, mask


def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
):
    """Outpaint *image* by the given paddings with FLUX.1 Fill.

    Returns:
        The first generated image, converted to RGBA.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    # NOTE(review): FLUX expects sizes divisible by 16; diffusers may adjust
    # other sizes, so the output can differ slightly from background.size.
    result = pipe(
        prompt=prompt or "",
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=28,
        guidance_scale=30,
    ).images[0]
    return result.convert("RGBA")


def rmbg(image, url):
    """Remove the background of *image*, or of *url* when no image is given.

    Returns:
        The input as an RGB image with an alpha channel set to BiRefNet's
        predicted foreground mask, resized to the original resolution.

    Raises:
        gr.Error: if neither an image nor a non-empty URL was provided.
    """
    if image is None:
        image = url
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Provide an image or an image URL.")
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # The last model output is squashed with a sigmoid into per-pixel
    # foreground probabilities.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Single GPU entry point shared by both tabs.

    Dispatches on argument count: the rmbg tab sends two inputs
    (image, url); the outpainting tab sends six.
    """
    if len(args) == 2:
        return rmbg(*args)
    return inpaint(*args)


rmbg_tab = gr.Interface(
    fn=main, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
)

# precision=0 makes gr.Number return ints, which Pillow requires for sizes.
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        "image",
        gr.Number(label="padding top", value=0, precision=0),
        gr.Number(label="padding bottom", value=0, precision=0),
        gr.Number(label="padding left", value=0, precision=0),
        gr.Number(label="padding right", value=0, precision=0),
        gr.Text(label="prompt"),
    ],
    outputs=["image"],
    api_name="outpainting",
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)

demo.launch()