Spaces:
Sleeping
Sleeping
import transformers | |
import re | |
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM | |
from vllm import LLM, SamplingParams | |
import torch | |
import gradio as gr | |
import json | |
import os | |
import shutil | |
import requests | |
import chromadb | |
import difflib | |
import pandas as pd | |
from chromadb.config import Settings | |
from chromadb.utils import embedding_functions | |
# Select the torch device: "cuda" when torch reports a usable GPU, else "cpu".
# NOTE(review): `device` is not read anywhere in the visible code — confirm it
# is still needed.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model identifier handed to the vLLM engine below.
model_name = "Pclanglais/ocronos2"

# Module-level vLLM engine shared by every request.
# NOTE(review): 8128 is an unusual context length; possibly 8192 was intended —
# confirm against the model's configuration.
llm = LLM(model=model_name, max_model_len=8128)
# Stylesheet passed to gr.Blocks: layout of the generated answer (.generation),
# source/tooltip/target styles, and highlight classes for a word diff.
# NOTE(review): .deleted/.inserted are not used by generate_html_diff, which
# inlines its own highlight style.
css = """
.generation {
    margin-left: 2em;
    margin-right: 2em;
    font-size: 1.2em;
}
:target {
    background-color: #CCF3DF;
}
.source {
    float: left;
    max-width: 17%;
    margin-left: 2%;
}
.tooltip {
    position: relative;
    cursor: pointer;
    font-variant-position: super;
    color: #97999b;
}
.tooltip:hover::after {
    content: attr(data-text);
    position: absolute;
    left: 0;
    top: 120%;
    white-space: pre-wrap;
    width: 500px;
    max-width: 500px;
    z-index: 1;
    background-color: #f9f9f9;
    color: #000;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 5px;
    display: block;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
/* New styles for diff */
.deleted {
    background-color: #ffcccb;
    text-decoration: line-through;
}
.inserted {
    background-color: #90EE90;
}
"""
# Courtesy of Claude
def generate_html_diff(old_text, new_text):
    """Return ``new_text`` as HTML with words not present in ``old_text`` highlighted.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``:

    * words common to both are emitted unchanged,
    * words only in ``new_text`` are wrapped in a green-highlighted ``<span>``,
    * words only in ``old_text`` are dropped (deletions are not shown).

    Output words are joined with single spaces, so the original line breaks and
    spacing of ``new_text`` are not preserved.

    Every word is HTML-escaped because the result is rendered as raw HTML and
    both inputs (user-supplied OCR text and model output) are untrusted.

    Args:
        old_text: The original text.
        new_text: The corrected text.

    Returns:
        An HTML fragment (string); empty when ``new_text`` has no words.
    """
    from html import escape

    pieces = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each entry with a two-character code:
        # "  " common, "+ " only in new, "- " only in old, "? " intraline hint.
        code, word = entry[:2], entry[2:]
        if code == "  ":
            pieces.append(escape(word, quote=False))
        elif code == "+ ":
            pieces.append(
                f'<span style="background-color: #90EE90;">'
                f"{escape(word, quote=False)}</span>"
            )
        # "- " (deleted words) and "? " (hint lines) are intentionally skipped.
    return " ".join(pieces)
class MistralChatBot:
    """OCR-correction wrapper around the module-level vLLM engine ``llm``.

    Despite the name, no conversation state is kept: each call to
    :meth:`predict` sends one self-contained correction prompt to the model.
    """

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): stored but never read by predict(); kept so existing
        # callers that pass it keep working.
        self.system_prompt = system_prompt

    def predict(self, user_message, temperature=0.9):
        """Correct ``user_message`` with the model and return an HTML diff.

        Args:
            user_message: Raw (OCR) text to correct.
            temperature: Sampling temperature. Defaults to 0.9, the value that
                was previously hard-coded.

        Returns:
            An HTML string: a centered "Réponse" heading followed by a
            ``generation`` div containing the word diff produced by
            ``generate_html_diff`` (inserted words highlighted).
        """
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=0.95,
            max_tokens=4000,
            presence_penalty=0,
            # Generation is cut as soon as the model emits this marker.
            stop=["#END#"],
        )
        # The model continues the text after the "### CORRECTION ###" header.
        detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
        print(detailed_prompt)
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        html_diff = generate_html_diff(user_message, generated_text)
        # Heading previously opened <h2> but closed </h3>.
        return (
            '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">'
            + html_diff
            + "</div>"
        )
# Shared OCR-correction bot; its predict method is the button callback of the UI.
mistral_bot = MistralChatBot()

# Interface metadata and extra inputs.
# NOTE(review): title, description, examples and additional_inputs are not
# passed to the Blocks UI in the visible code; they look like leftovers from a
# chat-interface template — confirm before relying on them.
title = "Correction d'OCR"
description = "Un outil expérimental de correction d'OCR basé sur des modèles de langue"

# Each example row is [user_message, temperature].
examples = [["Qui peut bénéficier de l'AIP?", 0.7]]

additional_inputs = [
    gr.Slider(
        label="Température",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.2,
        interactive=True,
        info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
    ),
]
# Gradio UI: a text box, a button, and an HTML pane that shows the HTML
# returned by mistral_bot.predict for the entered text.
# (The previous `demo = gr.Blocks()` line was removed: that object was
# immediately replaced by the `with ... as demo` binding below.)
with gr.Blocks(theme="JohnSmith9982/small_and_pretty", css=css) as demo:
    gr.HTML(f'<h1 style="text-align:center">{title}</h1>')
    text_input = gr.Textbox(label="Votre texte.", type="text", lines=1)
    text_button = gr.Button("Corriger l'OCR")
    text_output = gr.HTML(label="Le texte corrigé")
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    demo.queue().launch()