File size: 6,376 Bytes
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b2f713
 
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b2f713
 
 
 
 
3cad23b
 
7b2f713
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import dotenv
dotenv.load_dotenv()

import json
import os
import random
import threading
import time

from toolformers.base import Tool, parameter_from_openai_api, StringParameter
from toolformers.base import Toolformer
from toolformers.camel import make_openai_toolformer
from toolformers.langchain_agent import LangChainAnthropicToolformer
from toolformers.sambanova import SambanovaToolformer
from toolformers.gemini import GeminiToolformer

from querier import Querier
from responder import Responder
from negotiator import SenderNegotiator, ReceiverNegotiator
from programmer import SenderProgrammer, ReceiverProgrammer
from executor import UnsafeExecutor
from utils import compute_hash


def create_toolformer(model_name) -> Toolformer:
    """Build the Toolformer backend that serves ``model_name``.

    Args:
        model_name: Model identifier. 'gpt-4o', 'gpt-4o-mini', 'llama3-405b'
            and 'gemini-1.5-pro' must match exactly; any name containing the
            substring 'claude' is routed to the Anthropic backend.

    Returns:
        A Toolformer wrapping the matching provider.

    Raises:
        ValueError: if no backend recognizes ``model_name``.
    """
    # Checks run top to bottom; the first match wins.
    if model_name in ('gpt-4o', 'gpt-4o-mini'):
        return make_openai_toolformer(model_name)

    if 'claude' in model_name:
        # Key comes from the environment; None is passed through if it is unset.
        anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
        return LangChainAnthropicToolformer(model_name, anthropic_key)

    if model_name == 'llama3-405b':
        return SambanovaToolformer(model_name)

    if model_name == 'gemini-1.5-pro':
        return GeminiToolformer(model_name)

    raise ValueError(f"Unknown model name: {model_name}")

def _make_dummy_tool_fn(tool_schema):
    """Return a stub implementation for one of Bob's tools.

    The stub prints its call arguments and returns a random entry from
    ``tool_schema['dummy_outputs']``. Creating it in a dedicated function binds
    ``tool_schema`` per tool: a closure defined directly in a ``for`` loop
    would only ever see the loop variable's final value, so every tool would
    report the last tool's name and return the last tool's outputs.
    """
    def tool_fn(*args, **kwargs):
        print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
        return random.choice(tool_schema['dummy_outputs'])

    return tool_fn


def _build_bob_tools(schema):
    """Create one ``Tool`` per entry in ``schema['tools']`` for Bob to expose.

    Each tool schema is read for ``name``, ``description``, ``input``,
    ``output`` and ``dummy_outputs``. ``input`` is treated as an OpenAI-style
    JSON schema; its ``properties`` and ``required`` keys are optional, as in
    JSON Schema, and default to "no parameters" / "nothing required".
    """
    tools = []
    for tool_schema in schema['tools']:
        input_schema = tool_schema['input']
        required_names = input_schema.get('required', [])
        parameters = [
            parameter_from_openai_api(param_name, param_schema, param_name in required_names)
            for param_name, param_schema in input_schema.get('properties', {}).items()
        ]
        tools.append(Tool(
            tool_schema['name'],
            tool_schema['description'],
            parameters,
            _make_dummy_tool_fn(tool_schema),
            tool_schema['output'],
        ))
    return tools


def full_flow(schema, alice_model, bob_model):
    """Run the Alice/Bob demo in a background thread and stream its progress.

    The worker thread runs three phases:
    1. Alice sends a natural-language query to Bob.
    2. The two sides negotiate a protocol.
    3. Each side gets an LLM-written routine for that protocol; Alice's routine
       is executed and talks to Bob's routine via a ``send_to_server`` tool.

    Args:
        schema: Task description. ``schema['tools']`` holds Bob's tool schemas
            and ``schema['examples']`` the candidate task inputs (one is picked
            at random). The whole dict is also passed to the querier, the
            sender negotiator and the sender programmer.
        alice_model: Model name for Alice (querier / protocol sender).
        bob_model: Model name for Bob (responder / protocol receiver).

    Yields:
        Tuples ``(nl_messages, negotiation_messages, structured_messages,
        protocol_document, sender_implementation, receiver_implementation)``.
        A tuple is yielded about every 0.2 s while the worker is alive, plus
        one final snapshot after it exits. Artifacts not produced yet are
        reported as ``''``. The message lists are the live lists the worker
        appends to, not copies.
    """
    import traceback

    nl_messages = []
    negotiation_messages = []
    structured_messages = []
    # Populated by the worker thread as phases complete; read by get_info().
    artifacts = {}

    toolformer_alice = create_toolformer(alice_model)
    toolformer_bob = create_toolformer(bob_model)

    querier = Querier(toolformer_alice)
    responder = Responder(toolformer_bob)

    tools = _build_bob_tools(schema)

    negotiator_sender = SenderNegotiator(toolformer_alice)
    negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')

    sender_programmer = SenderProgrammer(toolformer_alice)
    receiver_programmer = ReceiverProgrammer(toolformer_bob)

    executor = UnsafeExecutor()

    def nl_callback_fn(query):
        # Phase 1 transport: log Alice's query, let Bob's responder answer it.
        print(query)
        nl_messages.append({
            'role': 'assistant',
            'body': query['body'],
            'protocolHash': None
        })

        response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')

        nl_messages.append({
            'role': 'user',
            'status': 'success',
            'body': response['body']
        })

        return response

    def negotiation_callback_fn(query):
        # Phase 2 transport: forward each negotiation message to Bob's negotiator.
        print(query)
        negotiation_messages.append({
            'role': 'assistant',
            'content': query
        })

        response = negotiator_receiver.handle_negotiation(query)

        negotiation_messages.append({
            'role': 'user',
            'content': response
        })

        return response

    def final_message_callback_fn(query):
        # Alice's final negotiation message is only logged, not sent to Bob.
        negotiation_messages.append({
            'role': 'assistant',
            'content': query
        })

    def structured_callback_fn(query):
        # Phase 3 transport: run Bob's routine on a query from Alice's routine.
        protocol_hash = artifacts['protocol']['hash']
        structured_messages.append({
            'role': 'assistant',
            'body': json.dumps(query) if isinstance(query, dict) else query,
            'protocolHash': protocol_hash,
            'protocolSources': ['https://...']
        })

        try:
            response = executor.run_routine(protocol_hash, artifacts['implementation_receiver'], query, tools)
        except Exception as e:
            # Bob-side failures are logged and reported to Alice's routine
            # as the string 'Error' instead of propagating.
            traceback.print_exc()
            structured_messages.append({
                'role': 'user',
                'status': 'error',
                'message': str(e)
            })
            return 'Error'

        structured_messages.append({
            'role': 'user',
            'status': 'success',
            'body': json.dumps(response) if isinstance(response, dict) else response
        })

        return response

    def flow():
        task_data = random.choice(schema['examples'])

        # Phase 1: natural-language exchange without a protocol.
        querier.send_query_without_protocol(schema, task_data, nl_callback_fn)

        # Phase 2: negotiate a protocol and record it together with its hash.
        res = negotiator_sender.negotiate_protocol_for_task(
            schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn)
        protocol_document = res['protocol']
        protocol_hash = compute_hash(protocol_document)
        res['hash'] = protocol_hash
        artifacts['protocol'] = res

        # Phase 3: generate both routines, then run Alice's against Bob's.
        implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)
        artifacts['implementation_sender'] = implementation_sender

        implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')
        artifacts['implementation_receiver'] = implementation_receiver

        # Alice's routine reaches Bob only through this tool. Parameters are
        # passed as a list, like every other Tool built in this file.
        send_tool = Tool(
            'send_to_server',
            'Send to server',
            [StringParameter('query', 'The query', True)],
            structured_callback_fn,
        )

        try:
            executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
        except Exception as e:
            traceback.print_exc()
            structured_messages.append({
                'role': 'assistant',
                'status': 'error',
                'message': str(e)
            })

    def get_info():
        return (
            nl_messages,
            negotiation_messages,
            structured_messages,
            artifacts.get('protocol', {}).get('protocol', ''),
            artifacts.get('implementation_sender', ''),
            artifacts.get('implementation_receiver', ''),
        )

    thread = threading.Thread(target=flow)
    thread.start()
    while thread.is_alive():
        yield get_info()
        time.sleep(0.2)
    yield get_info()