File size: 3,081 Bytes
80e8e8d
 
42f87c6
f55a67c
 
 
 
80e8e8d
f55a67c
 
80e8e8d
 
f55a67c
 
 
 
 
 
 
 
 
 
 
42f87c6
f55a67c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80e8e8d
f55a67c
bc7d8a5
f55a67c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f87c6
 
a303d6f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from dotenv import load_dotenv
import gradio as gr
from tools import create_agent
from langchain_core.messages import RemoveMessage
from langchain_core.messages import trim_messages

# Global params
# Agent instance shared by every function in this module; it is built once,
# at import time, by tools.create_agent().
AGENT = create_agent()
# Gradio theme using red for both the primary and the secondary hue.
theme = gr.themes.Default(primary_hue="red", secondary_hue="red")
# Greeting (in French) used as the initial assistant message of the chat window.
default_msg = "Bonjour ! Je suis là pour répondre à vos questions sur l'actuariat. Comment puis-je vous aider aujourd'hui ?"


def filter_msg(msg_list: list, keep_n: int) -> list:
    """Return the ids of the messages that remain after trimming the history.

    The history is trimmed from the end with ``trim_messages``. ``len`` is
    used as the token counter, so the ``keep_n`` budget is expressed in
    messages rather than in model tokens. The kept window is asked to start
    on a human message and to end on a tool or AI message, and a leading
    system message is retained (``include_system=True``).

    Args:
        msg_list: Chat history messages; each one must expose an ``id``.
        keep_n: Maximum number of messages to keep.

    Returns:
        The ``id`` of every message that survived the trimming, in order.
    """
    trim_options = {
        "strategy": "last",
        "token_counter": len,
        "max_tokens": keep_n,
        "start_on": "human",
        "end_on": ("tool", "ai"),
        "include_system": True,
    }
    kept_messages = trim_messages(msg_list, **trim_options)
    return [kept.id for kept in kept_messages]

def agent_response(query, config, keep_n=10):
    """Ask the shared agent a question, pruning old history first.

    The messages stored for ``config`` are read from the agent state. When
    there are more than ``keep_n`` of them, every message whose id is not
    returned by ``filter_msg`` is removed from the state through
    ``RemoveMessage`` entries. The agent is then invoked with ``query``.

    Args:
        query: The user's question, passed as the ``messages`` input.
        config: Agent configuration forwarded to ``get_state``,
            ``update_state`` and ``invoke``.
        keep_n: History size above which pruning is triggered, and the
            budget handed to ``filter_msg``.

    Returns:
        The ``content`` of the last message in the agent's answer.
    """
    history = AGENT.get_state(config).values.get("messages", [])

    # Prune only once the stored history has outgrown the budget.
    if len(history) > keep_n:
        surviving_ids = filter_msg(history, keep_n)
        removals = []
        for stored in history:
            if stored.id not in surviving_ids:
                removals.append(RemoveMessage(id=stored.id))
        AGENT.update_state(config, {"messages": removals})
        print("msg removed")

    # Generate answer
    result = AGENT.invoke({"messages": query}, config=config)
    final_message = result["messages"][-1]
    return final_message.content


# JavaScript snippet handed to gr.Blocks through its `js` argument. When the
# page URL's `__theme` query parameter is not 'light', it sets the parameter
# to 'light' and navigates to the updated URL.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') != 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""


def delete_agent():
    """Replace the module-level ``AGENT`` with a freshly created agent.

    Despite its name, nothing is deleted explicitly: the global is simply
    rebound to the result of a new ``create_agent()`` call.
    """
    global AGENT
    print("del agent")
    AGENT = create_agent()

with gr.Blocks(theme=theme, js=js_func, title="Dataltist", fill_height=True) as iface:
    gr.Markdown("# Dataltist Chatbot 🚀")
    # Chat window, pre-filled with the assistant greeting.
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": default_msg}],
        type="messages",
        scale=1,
        show_copy_button=True,
        show_share_button=False,
    )
    # Single-line input box for the user's question.
    msg = gr.Textbox(
        lines=1,
        show_label=False,
        placeholder="Posez vos questions sur l'assurance",
    )
    # Agent configuration used for every call made from this UI.
    # NOTE(review): the thread id is hard-coded, so all calls appear to share
    # one conversation thread -- confirm this is intended.
    config = {"configurable": {"thread_id": "1"}}

    def user(user_message, history: list):
        """Clear the textbox and add the user's message to the chat history.

        Returns an empty string (new textbox value) and a new history list
        ending with the user's message.
        """
        new_turn = {"role": "user", "content": user_message}
        return "", history + [new_turn]

    def bot(history: list):
        """Answer the last message of ``history`` and stream the reply.

        The full answer is obtained from ``agent_response`` first; it is then
        appended to an initially empty assistant message one character at a
        time, yielding the history after each character.
        """
        latest_text = history[-1]["content"]
        reply = agent_response(latest_text, config)
        history.append({"role": "assistant", "content": ""})
        for character in reply:
            history[-1]["content"] += character
            yield history

    # On submit: record the user's message (unqueued), then produce the reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    submit_event.then(bot, chatbot, chatbot)
    # Rebuild the shared agent when the Blocks unload event fires.
    iface.unload(delete_agent)

if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    # Authentication is currently disabled. The commented-out lines below
    # show how credentials would be read from the environment and passed to
    # launch().
    # load_dotenv()
    # AUTH_ID = os.environ.get("AUTH_ID")
    # AUTH_PASS = os.environ.get("AUTH_PASS")
    iface.launch()  #share=True, auth=(AUTH_ID, AUTH_PASS)