|
import gradio as gr |
|
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor |
|
import spaces |
|
import torch |
|
import re |
|
from PIL import Image |
|
|
|
_MODEL_ID = "gokaygokay/sd3-long-captioner"

# Load the captioner once at import time so every request reuses the same
# weights. Inference is pinned to the CPU and the module is put in eval mode.
model = (
    PaliGemmaForConditionalGeneration
    .from_pretrained(_MODEL_ID)
    .to("cpu")
    .eval()
)
# The processor (tokenizer + image preprocessing) comes from the same repo.
processor = PaliGemmaProcessor.from_pretrained(_MODEL_ID)
|
|
|
def modify_caption(caption: str) -> str:
    """
    Removes specific prefixes from captions.

    Only a prefix at the very start of the caption is replaced, and the match
    is case-insensitive ("Captured from ..." is handled like "captured from ...").
    The same phrases appearing later in the caption are left untouched.

    Args:
        caption (str): A string containing a caption.

    Returns:
        str: The caption with the prefix removed if it was present.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]

    # Keys are lower-cased so a case-insensitive regex match can always be
    # looked up, whatever capitalisation the model produced. The original
    # keyed on the raw text, which raised KeyError for "Captured from ".
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    # \A anchors the alternation to the start of the string, so only a genuine
    # prefix is stripped (not e.g. "... a photo captured from above").
    alternation = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    pattern = r'\A(?:' + alternation + r')'

    def replace_fn(match) -> str:
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
|
|
|
def _resolve_path(file_obj):
    """Return a filesystem path for one Gradio upload (path string or tempfile-like object)."""
    import os

    # Gradio 4.x (gr.Files default type="filepath") passes plain path strings;
    # Gradio 3.x passed tempfile wrappers whose ``.name`` is the full path.
    # Path-like objects are returned as-is because Path.name is only the basename.
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return getattr(file_obj, "name", file_obj)


def _generate_caption(image, prompt):
    """Run greedy generation for one image and return the cleaned caption text."""
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    # generate() returns prompt tokens followed by new tokens; keep only the new ones.
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.no_grad():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)

    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return modify_caption(decoded)


def create_captions_rich(files):
    """
    Generate one caption per uploaded image.

    Args:
        files: Value of the ``gr.Files`` input. This is a list of uploads
            (path strings or tempfile-like objects exposing ``.name``), or
            ``None`` when nothing was uploaded.

    Returns:
        str: One line per file, in upload order. Each line is either the
        caption or an error message for that file. The result is an empty
        string when no files were given.
    """
    # Clicking "Start" with no upload yields None; iterating it would raise.
    if not files:
        return ""

    captions = []
    prompt = "caption en"

    for file_obj in files:
        try:
            # Image.open is lazy. convert() forces a full decode here, so
            # truncated or corrupt files are reported as open errors. The
            # ``with`` closes the handle, and RGB normalises palette/alpha/grey images.
            with Image.open(_resolve_path(file_obj)) as img:
                image = img.convert("RGB")
        except Exception as e:
            captions.append(f"Error opening image: {e}")
            continue

        # Preprocessing now sits inside the try as well, so one bad image
        # cannot abort the captions for the rest of the batch.
        try:
            captions.append(_generate_caption(image, prompt))
        except Exception as e:
            captions.append(f"Error generating caption: {e}")

    return "\n".join(captions)
|
|
|
# NOTE(review): no component in this file sets elem_id="mkd", so this rule
# currently matches nothing. Add elem_id="mkd" to a component to apply it.
css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 16px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    # Tags are now closed correctly; the original re-opened <center>/<h1>
    # instead of closing them, producing malformed HTML.
    gr.HTML("<h1><center>Fine-tuned PaliGemma for SD3 Image Guided Prompt Generation.</center></h1>")

    # NOTE(review): the source indentation was lost; this layout assumes the
    # textbox sits beside the input column inside the row. Confirm intended layout.
    with gr.Tab(label="Image to Prompt for SD3"):
        with gr.Row():
            with gr.Column():
                # Restrict the picker to images; other uploads could only
                # ever produce "Error opening image" lines.
                input_files = gr.Files(label="Input Images", file_types=["image"])
                submit_btn = gr.Button(value="Start")
            outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)

        submit_btn.click(create_captions_rich, inputs=[input_files], outputs=[outputs])

# Launch only when run as a script, so importing this module (e.g. for
# tests or `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)