# Last change by davidberenstein1957: "Update app.py" (commit 40ef6c4, verified)
import os
import io
import random
import requests
from PIL import Image
from dataset_viber import AnnotatorInterFace
# Credentials: the Hugging Face token must be present in the environment
# (a missing variable raises KeyError at import time). Every request sends it
# as a bearer token.
HF_TOKEN = os.environ["HF_TOKEN"]
HEADERS = {"Authorization": "Bearer " + HF_TOKEN}

# Base URL of the Hugging Face dataset viewer API.
DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
# NOTE: despite the name, this is a pre-encoded query-string fragment. It is the
# URL-encoded dataset id ("poloclub/diffusiondb") followed by the config and
# split parameters, and it is spliced directly after "?dataset=" in request URLs.
DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"

# Text-to-image inference endpoint(s); one is chosen at random for each request.
MODEL_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
MODEL_URLS = [MODEL_URL]
def retrieve_sample(idx, timeout=30):
    """Fetch a single row from the dataset viewer ``/rows`` endpoint.

    Args:
        idx: Zero-based row offset to fetch.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``(img_url, prompt)`` tuple holding the row's image URL and its text prompt.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface the real HTTP error instead of a confusing KeyError on "rows"
    # when the request is rejected (e.g. an out-of-range offset).
    response.raise_for_status()
    row = response.json()["rows"][0]["row"]
    return row["image"]["src"], row["prompt"]
def get_rows(timeout=30):
    """Return the number of rows in the configured dataset config.

    Queries the dataset viewer ``/size`` endpoint and reads
    ``size.config.num_rows`` from the JSON reply.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Fail loudly on an error reply rather than with a KeyError on "size".
    response.raise_for_status()
    return response.json()["size"]["config"]["num_rows"]
def generate_response(prompt, timeout=120):
    """Generate an image for *prompt* using a randomly chosen inference endpoint.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    payload = {"inputs": prompt}
    response = requests.post(
        random.choice(MODEL_URLS), headers=HEADERS, json=payload, timeout=timeout
    )
    # An error reply presumably carries a non-image body. Raise the HTTP error
    # here instead of letting PIL fail with an opaque UnidentifiedImageError.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next annotation item: a random dataset prompt and two generated images.

    The three arguments (the previous prompt and completions) are ignored.

    Returns:
        A ``(prompt, image_a, image_b)`` tuple.
    """
    # randrange(n) yields 0..n-1, exactly the valid zero-based offsets; it can
    # never produce a negative offset.
    random_idx = random.randrange(get_rows())
    _img_url, prompt = retrieve_sample(random_idx)
    # The second request uses the prompt plus a trailing space so the two inputs
    # differ. Presumably this keeps the inference API from returning a cached copy
    # of the first image.
    return (prompt, generate_response(prompt), generate_response(prompt + " "))
if __name__ == "__main__":
    # Build a non-interactive image-generation preference annotator whose items
    # come from next_input. dataset_name presumably names the Hub dataset that
    # annotations are saved to (dataset_viber behavior; confirm).
    interface = AnnotatorInterFace.for_image_generation_preference(
        fn_next_input=next_input,
        interactive=False,
        dataset_name="dataset-viber-image-generation-preference-inference-endpoints-battle-flux",
    )
    interface.launch()