Spaces:
Runtime error
Runtime error
import gradio | |
import pandas as pd | |
from matplotlib import pyplot as plt | |
from config import CONFIG | |
from data import get_extra_tokens, BenetechOutput, ChartType | |
from model import predict_string, build_model | |
def gradio_visualize_prediction(string): | |
string = string.removeprefix(get_extra_tokens().benetech_prompt) | |
if not BenetechOutput.does_string_match_expected_pattern(string): | |
return | |
benetech_output = BenetechOutput.from_string(string) | |
x = benetech_output.x_data[: len(benetech_output.y_data)] | |
y = benetech_output.y_data[: len(benetech_output.x_data)] | |
df = pd.DataFrame(dict(x=x, y=y)) | |
plt_plot = { | |
ChartType.line: plt.plot, | |
ChartType.scatter: plt.scatter, | |
ChartType.horizontal_bar: plt.barh, | |
ChartType.vertical_bar: plt.bar, | |
ChartType.dot: plt.scatter, | |
} | |
plt_plot[benetech_output.chart_type](x, y) | |
plt.xticks(rotation=30) | |
plt.savefig("plot.png") | |
... | |
config = CONFIG | |
config.pretrained_model_name = "checkpoint" | |
model = build_model(config) | |
interface = gradio.Interface( | |
title="Making graphs accessible", | |
description="Generate textual representation of a graph\n" | |
"https://www.kaggle.com/competitions/benetech-making-graphs-accessible", | |
fn=lambda image: predict_string(image, model), | |
inputs="image", | |
outputs="text", | |
examples="examples", | |
) | |
interface.launch() | |