metadata
language:
- zh
pipeline_tag: image-to-text
tags:
- vit
- gpt
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image
import pathlib
import pandas as pd
import numpy as np
from IPython.core.display import HTML
import os
import requests
class Image2Caption(object):
    """Generate text captions for images with a vision encoder-decoder model.

    Bundles three pretrained pieces: the ``VisionEncoderDecoderModel`` that
    generates token ids, the ``ViTFeatureExtractor`` that turns PIL images
    into ``pixel_values``, and the tokenizer that decodes ids back to text.
    """

    # Seconds to wait for a remote image. Without a timeout ``requests`` can
    # block forever on an unresponsive host. Override on the class or on an
    # instance if a different limit is needed.
    request_timeout = 30

    def __init__(self, model_path="nlpconnect/vit-gpt2-image-captioning",
                 device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 overwrite_encoder_checkpoint_path=None,
                 overwrite_token_model_path=None,
                 ):
        """Load the model, feature extractor and tokenizer.

        Args:
            model_path: Checkpoint passed to
                ``VisionEncoderDecoderModel.from_pretrained``.
            device: Device the model and the pixel tensors are moved to.
            overwrite_encoder_checkpoint_path: Optional checkpoint to load the
                feature extractor from; defaults to ``model_path``.
            overwrite_token_model_path: Optional checkpoint to load the
                tokenizer from; defaults to ``model_path``.

        Raises:
            TypeError: If either ``overwrite_*`` argument is neither ``None``
                nor a ``str``.
        """
        # Validate with explicit raises: ``assert`` is stripped under ``-O``.
        for arg_name, arg_value in (
            ("overwrite_token_model_path", overwrite_token_model_path),
            ("overwrite_encoder_checkpoint_path", overwrite_encoder_checkpoint_path),
        ):
            if arg_value is not None and not isinstance(arg_value, str):
                raise TypeError(
                    f"{arg_name} must be a str or None, "
                    f"got {type(arg_value).__name__}"
                )

        if overwrite_token_model_path is None:
            overwrite_token_model_path = model_path
        if overwrite_encoder_checkpoint_path is None:
            overwrite_encoder_checkpoint_path = model_path

        self.device = device
        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(
            overwrite_encoder_checkpoint_path
        )
        self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)
        self.model = self.model.to(self.device)

    def predict_to_df(self, image_paths):
        """Caption ``image_paths`` and return the results as a DataFrame.

        Args:
            image_paths: Iterable of image locations (see ``predict_step``).

        Returns:
            A ``pandas.DataFrame`` with columns ``"img"`` (the input path) and
            ``"caption"`` (the generated text), one row per image. An empty
            input yields an empty frame that still has both columns.

        To render thumbnails in a notebook, the frame can be passed through
        ``to_html(escape=False, formatters={"img": path_to_image_html})``.
        """
        # Materialise once: the paths are consumed by predict_step and then
        # again when building the rows, which would exhaust a generator.
        image_paths = list(image_paths)
        captions = self.predict_step(image_paths)
        # Passing ``columns=`` up front keeps the schema valid for zero rows.
        return pd.DataFrame(
            list(zip(image_paths, captions)), columns=["img", "caption"]
        )

    def predict_step(self, image_paths, max_length=128, num_beams=4):
        """Generate one caption per image.

        Args:
            image_paths: Iterable of image locations. Strings starting with
                ``http://`` or ``https://`` are downloaded; anything else is
                handed to ``PIL.Image.open`` (local path, ``pathlib.Path`` or
                file object).
            max_length: ``max_length`` forwarded to ``model.generate``.
            num_beams: ``num_beams`` forwarded to ``model.generate``.

        Returns:
            List of stripped caption strings, in input order. Empty input
            returns ``[]`` without touching the model.

        Raises:
            requests.HTTPError: If a remote image responds with an error status.
        """
        image_paths = list(image_paths)
        if not image_paths:
            return []

        images = [self._load_rgb_image(path) for path in image_paths]

        pixel_values = self.feature_extractor(
            images=images, return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(self.device)

        output_ids = self.model.generate(
            pixel_values, max_length=max_length, num_beams=num_beams
        )
        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [text.strip() for text in decoded]

    def _load_rgb_image(self, image_path):
        """Open a single image from a URL or local source and return it as RGB."""
        is_remote = isinstance(image_path, str) and image_path.lower().startswith(
            ("http://", "https://")
        )
        if is_remote:
            import io  # function-scope import keeps this block self-contained

            response = requests.get(image_path, timeout=self.request_timeout)
            # Fail loudly on 4xx/5xx instead of feeding an error page to PIL.
            response.raise_for_status()
            # ``content`` is fully read and transfer-decoded, unlike ``.raw``.
            source = io.BytesIO(response.content)
        else:
            source = image_path

        # ``convert`` returns a loaded copy (also when the mode is already
        # RGB), so the underlying file can be closed as soon as we leave.
        with Image.open(source) as opened:
            return opened.convert(mode="RGB")
def path_to_image_html(path):
    """Return an ``<img>`` tag (width 60) whose ``src`` is ``path``.

    The path is HTML-escaped so that quotes, ``&`` or ``<`` inside a file name
    or URL cannot break out of the ``src`` attribute. ``pathlib.Path`` values
    are accepted and converted with ``str``.
    """
    import html  # function-scope import: the file does not import it globally

    safe_src = html.escape(str(path), quote=True)
    return '<img src="' + safe_src + '" width="60" >'
# Build the captioner: model weights come from "svjack/vit-gpt-diffusion-zh",
# while the feature extractor and the tokenizer are loaded from the two
# separately named checkpoints below.
i2c_tiny_zh_obj = Image2Caption(
    model_path="svjack/vit-gpt-diffusion-zh",
    overwrite_token_model_path="IDEA-CCNL/Wenzhong-GPT2-110M",
    overwrite_encoder_checkpoint_path="google/vit-base-patch16-224",
)

# Caption a single remote sample image. The call is left as a bare expression
# so that an interactive session displays the returned list of captions.
i2c_tiny_zh_obj.predict_step(
    image_paths=[
        "https://datasets-server.huggingface.co/assets/poloclub/diffusiondb/--/2m_all/train/28/image/image.jpg"
    ]
)
['"一个年轻男人的肖像,由Greg Rutkowski创作"。Artstation上的趋势"。"《刀锋战士》的艺术作品"。高度细节化。"电影般的灯光"。超现实主义。锐利的焦点。辛烷�']