import gradio
import pandas as pd
from matplotlib import pyplot as plt

from config import CONFIG
from data import get_extra_tokens, BenetechOutput, ChartType
from model import predict_string, build_model


def gradio_visualize_prediction(string, output_path="plot.png"):
    """Render a predicted Benetech output string as a chart image.

    The Benetech prompt token (from ``get_extra_tokens()``) is stripped from
    the start of ``string``. The remainder is parsed with
    ``BenetechOutput.from_string``, and its x/y series are drawn with the
    plot type matching ``chart_type``.

    Args:
        string: Model output text, optionally prefixed with the Benetech prompt.
        output_path: File the rendered chart is saved to. Defaults to
            "plot.png", the path this function has always written to.

    Returns:
        ``output_path`` once the chart has been saved, or ``None`` when the
        string does not match ``BenetechOutput``'s expected pattern (nothing
        is drawn or saved in that case).
    """
    string = string.removeprefix(get_extra_tokens().benetech_prompt)
    if not BenetechOutput.does_string_match_expected_pattern(string):
        return None

    benetech_output = BenetechOutput.from_string(string)

    # The predicted series may differ in length; keep only the common prefix
    # so every x value is paired with exactly one y value.
    n_points = min(len(benetech_output.x_data), len(benetech_output.y_data))
    x = benetech_output.x_data[:n_points]
    y = benetech_output.y_data[:n_points]

    # Draw on a dedicated figure. Calling the pyplot state machine directly
    # (plt.plot, plt.bar, ...) keeps drawing onto the same current figure, so
    # each call would overlay its series on top of every previous prediction.
    fig, ax = plt.subplots()
    try:
        plot_by_chart_type = {
            ChartType.line: ax.plot,
            ChartType.scatter: ax.scatter,
            # NOTE(review): Axes.barh takes (y positions, bar widths); confirm
            # that passing (x, y) matches how BenetechOutput stores
            # horizontal-bar series.
            ChartType.horizontal_bar: ax.barh,
            ChartType.vertical_bar: ax.bar,
            ChartType.dot: ax.scatter,
        }
        plot_by_chart_type[benetech_output.chart_type](x, y)
        # Tilt x tick labels so long category names do not overlap.
        ax.tick_params(axis="x", labelrotation=30)
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return output_path


# This aliases (does not copy) the shared CONFIG object, so the override below
# is visible to every other reader of CONFIG.
config = CONFIG
# Presumably points build_model at a local fine-tuned checkpoint; verify.
config.pretrained_model_name = "checkpoint"
model = build_model(config)

# Web demo: takes an uploaded image and returns the model's text prediction.
interface = gradio.Interface(
    title="Making graphs accessible",
    description="Generate textual representation of a graph\n"
    "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
    fn=lambda image: predict_string(image, model),
    inputs="image",
    outputs="text",
    examples="examples",
)
interface.launch()