Commit 9ae4463
Parent(s): 77bf3e6
Update kto deployment

Files changed:
- README.md (+5 -5)
- app.py (+23 -13)
- chat_interface_preference.py (+106 -102)
- requirements.txt (+8 -6)
README.md CHANGED
@@ -1,15 +1,15 @@
 ---
-title:
+title: Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO)
 emoji: 🦾💪🏽
 colorFrom: pink
 colorTo: blue
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.39
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 suggested_hardware: t4-small
-short_description: LLM, chatbot
+short_description: LLM, chatbot, human-feedback
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.py CHANGED
@@ -1,12 +1,17 @@
 #!/usr/bin/env python
 import os
+import random
+from threading import Thread  # noqa
 from typing import Iterator
 
 import gradio as gr
 import spaces
-import torch
-from transformers import
+import torch  # noqa
+from transformers import (
+    AutoModelForCausalLM,  # noqa
+    AutoTokenizer,  # noqa
+    TextIteratorStreamer,  # noqa
+)
 
 from chat_interface_preference import ChatInterface
 
@@ -18,7 +23,6 @@ if torch.cuda.is_available():
     model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
 
 
 @spaces.GPU
@@ -31,7 +35,8 @@ def generate(
     top_k: int = 40,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
+    system_message = random.choice(["concise", "explicit", "simple", "complex", "usefull", "helpfull"])
+    conversation = [{"role": "system", "content": f"Communicate {system_message}."}]
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
@@ -68,11 +73,8 @@ chat_interface = ChatInterface(
     prefence_techniques="kto",
     min_turns=1,
     max_turns=10,
-    repo_id="llm-human-feedback-collector-chat-interface-
-    chatbot=gr.Chatbot(
-        height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True
-    ),
-    css=style,
+    repo_id="llm-human-feedback-collector-chat-interface-dpo",
+    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
     cache_examples=False,
     additional_inputs=[
         gr.Slider(
@@ -87,7 +89,7 @@ chat_interface = ChatInterface(
             minimum=0.05,
             maximum=1.2,
             step=0.05,
-            value=0.
+            value=0.7,
         ),
         gr.Slider(
             label="Top-p (nucleus sampling)",
@@ -117,8 +119,16 @@ chat_interface = ChatInterface(
         ["What are great things cook when getting started with Asian cooking?"],
         ["Who was Anthony Bourdain?"],
     ],
-    title="💪🏽🦾
-    description=""
+    title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (KTO) 🦾💪🏽",
+    description="".join(
+        [
+            "This is an adaptation of the [`gr.ChatInferface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
+            "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
+            "This demo shows how you might capture human feedback directly from applications within Gradio. ",
+            "The captured feedback can directly be used for fine-tuning LLMs within framework like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
+            "however, it might benefit from additional data curation with something like [Argilla](https://github.com/argilla-io/argilla/) for human feedback and/or [distilabel](https://github.com/argilla-io/distilabel/) for AI feedback. Argilla can even be [deployed for free on Hugging Face Spaces](https://argilla-io.github.io/argilla/latest/getting_started/huggingface-spaces/).",
+        ]
+    ),
 )
 
 with gr.Blocks(css="style.css") as demo:
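Note on the app.py changes above: the hunks only show generate() now seeding the conversation with a random system message; the part of the body that actually produces tokens is unchanged and not included in this diff. Given the imports added at the top (Thread, TextIteratorStreamer), it presumably follows the usual transformers streaming pattern. A self-contained sketch of that pattern, with a hypothetical function name (stream_reply) and assumed generation kwargs, not code taken from this Space:

    from threading import Thread
    from typing import Iterator

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

    def stream_reply(conversation: list[dict], max_new_tokens: int = 1024, temperature: float = 0.7) -> Iterator[str]:
        """Stream a reply chunk-by-chunk; mirrors the pattern implied by the imports in app.py."""
        input_ids = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,  # assumed value, not shown in the diff
            do_sample=True,
            temperature=temperature,
        )
        # Run generation on a background thread so the streamer can be consumed as an iterator.
        Thread(target=model.generate, kwargs=generate_kwargs).start()
        partial = []
        for text in streamer:  # yields decoded text chunks as they are produced
            partial.append(text)
            yield "".join(partial)

Running model.generate on a background thread is what lets the function be typed as Iterator[str] and yielded straight into the Gradio chatbot, matching the generate() signature shown in the diff.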
chat_interface_preference.py CHANGED
@@ -144,15 +144,15 @@ class ChatInterface(Blocks):
         submit_btn_bad = None
         stop_btn = "Stop"
         undo_btn = "↩️ Undo"
-        clear_btn = "🗑️
+        clear_btn = "🗑️ Clear"
         if "kto" in prefence_techniques:
-            submit_btn_good = "
-            submit_btn_bad = "
+            submit_btn_good = "The response 👍"
+            submit_btn_bad = "The response 👎"
         if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
-            submit_btn_two =
-            submit_btn_a = "
-            submit_btn_b = "
-            submit_btn_ab = "
+            submit_btn_two = None
+            submit_btn_a = "A is better than B"
+            submit_btn_b = "B is better than A"
+            submit_btn_ab = "A and B are similar"
         super().__init__(
             analytics_enabled=analytics_enabled,
             mode="chat_interface",
@@ -219,14 +219,13 @@ class ChatInterface(Blocks):
         with self:
             if title:
                 Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>")
-            if description:
-                Markdown(description)
             if self.commit_scheduler:
                 Markdown(
-                    f
+                    f'<center><h2>Data is being logged to <a href="https://huggingface.co/datasets/{self.commit_scheduler.repo_id}">a dataset on the Hugging Face Hub</a></h2></center>'
                 )
-                Markdown(
+            if description:
+                Markdown(description)
+
             if chatbot:
                 self.chatbot = chatbot.render()
             else:
@@ -387,13 +386,13 @@ class ChatInterface(Blocks):
 
     def _setup_events(self) -> None:
         submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
-        submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=
+        submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=2)
         submit_triggers_one = (
             [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
         )
         submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
         if self.submit_btn_two:
-            submit_fn_two = functools.partial(submit_fn_one, n_generations=
+            submit_fn_two = functools.partial(submit_fn_one, n_generations=1)
             submit_triggers_two = [self.submit_btn_two.click]
             submit_tuples.append((submit_fn_two, submit_triggers_two))
         for _fn, _triggers in submit_tuples:
@@ -608,7 +607,7 @@ class ChatInterface(Blocks):
             if turn[-1]:
                 conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
 
-        return
+        return "<body>" + conversation + "</body>"
 
     def _get_conversation_in_openai_format(self, history):
         conversation = []
@@ -622,26 +621,22 @@ class ChatInterface(Blocks):
 
     @staticmethod
     def _get_chat_message(message, role, turn):
-            justify = "right"
-        else:
-            justify = "left"
+        # return f"<p><div class='message-identifier'>{message}</div></p>"
         return (
-            + f"<strong>Turn {turn} - {role.capitalize()}:</strong><br>"
+            '<div class="message-content">'
+            + f"<strong>Option {turn}</strong><br>"
             + f"<em>Length: {len(message)} characters</em><br><br>"
             + f'<div class="message-identifier">{message}</div>'
-            + "</div
+            + "</div>"
         )
 
     def _get_chat_message_comparison(self, content_a, content_b):
         return (
-            '<div class="container">'
-            + '<div class="column">'
+            '<div class="container" style="display: flex; width: 100%;">'
+            + '<div class="column" style="flex: 1; padding: 10px;">'
             + self._get_chat_message(message=content_a, role="system", turn="A")
             + "</div>"
-            + '<div class="column">'
+            + '<div class="column" style="flex: 1; padding: 10px;">'
            + self._get_chat_message(message=content_b, role="system", turn="B")
             + "</div>"
             + "</div>"
@@ -688,30 +683,34 @@ class ChatInterface(Blocks):
 
         self._check_message(message)
         self._check_num_turns(history)
+        if history:
+            _, response = history[-1]
+        else:
+            response = None
         if self._check_if_two_responses(response):
+            Info("Two options detected: provide preference, undo or clear to continue conversation.")
+            return history, history
+        else:
+            inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
 
-                response = await self.fn(*inputs)
-        else:
+        async def _get_response():
+            if self.is_async:
+                response = await self.fn(*inputs)
+            else:
+                response = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
+            return response
 
-        if n_generations == 1:
-            response = await _get_response()
-        else:
-            response_one, response_two = await _get_response(), await _get_response()
-            response = self._get_chat_message_comparison(response_one, response_two)
+        if n_generations == 1:
+            response = await _get_response()
+        else:
+            response_one, response_two = await _get_response(), await _get_response()
+            response = self._get_chat_message_comparison(response_one, response_two)
 
+        if self.multimodal and isinstance(message, dict):
+            self._append_multimodal_history(message, response, history)
+        elif isinstance(message, str):
+            history.append([message, response])
+        return history, history
 
     async def _stream_fn(
         self,
@@ -728,67 +727,35 @@ class ChatInterface(Blocks):
         history = history_with_input[:-1]
         self._check_message(message)
         self._check_num_turns(history)
-        _, response = history_with_input[-1]
-        if self._check_if_two_responses(response):
-            raise Error("Two options detected: undo, log or random pick continuation.")
+        if history:
+            _, response = history[-1]
+        else:
+            response = None
+        if self._check_if_two_responses(response):
+            Info("Two options detected: provide preference, undo or clear to continue conversation.")
+            yield history, history
+        else:
+            inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
 
-        if n_generations == 2:
-            first_response_formatted = self._get_chat_message_comparison(first_response, "")
-        else:
-            first_response_formatted = first_response
-        if self.multimodal and isinstance(message, dict):
-            for x in message["files"]:
-                history.append([(x,), None])
-            update = history + [[message["text"], first_response_formatted]]
-            yield update, update
-        else:
-            update = history + [[message, first_response_formatted]]
-            yield update, update
-        except StopIteration:
-            if self.multimodal and isinstance(message, dict):
-                self._append_multimodal_history(message, None, history)
-                yield history, history
-            else:
-                update = history + [[message, None]]
-                yield update, update
-        async for response in generator:
-            if n_generations == 2:
-                response_formatted = self._get_chat_message_comparison(response, "")
-            else:
-                response_formatted = response
-            if self.multimodal and isinstance(message, dict):
-                update = history + [[message["text"], response_formatted]]
-                yield update, update
-            else:
-                update = history + [[message, response_formatted]]
-                yield update, update
 
-        if n_generations == 2:
-            if self.is_async:
-                generator_two = self.fn(*inputs)
-            else:
-                generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
-            generator_two = SyncToAsyncIterator(generator_two, self.limiter)
         try:
+            if self.is_async:
+                generator = self.fn(*inputs)
+            else:
+                generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
+            generator = SyncToAsyncIterator(generator, self.limiter)
+            first_response = await async_iteration(generator)
+            if n_generations == 2:
+                first_response_formatted = self._get_chat_message_comparison(first_response, "")
+            else:
+                first_response_formatted = first_response
             if self.multimodal and isinstance(message, dict):
                 for x in message["files"]:
                     history.append([(x,), None])
-                update = history + [[message["text"], first_response_two_formatted]]
+                update = history + [[message["text"], first_response_formatted]]
                 yield update, update
             else:
-                update = history + [[message,
+                update = history + [[message, first_response_formatted]]
                 yield update, update
         except StopIteration:
             if self.multimodal and isinstance(message, dict):
@@ -797,15 +764,52 @@ class ChatInterface(Blocks):
             else:
                 update = history + [[message, None]]
                 yield update, update
-        async for
+        async for response in generator:
+            if n_generations == 2:
+                response_formatted = self._get_chat_message_comparison(response, "")
+            else:
+                response_formatted = response
             if self.multimodal and isinstance(message, dict):
-                update = history + [[message["text"],
+                update = history + [[message["text"], response_formatted]]
                 yield update, update
             else:
-                update = history + [[message,
+                update = history + [[message, response_formatted]]
                 yield update, update
 
+        if n_generations == 2:
+            if self.is_async:
+                generator_two = self.fn(*inputs)
+            else:
+                generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
+            generator_two = SyncToAsyncIterator(generator_two, self.limiter)
+            try:
+                first_response_two = await async_iteration(generator_two)
+                first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
+                if self.multimodal and isinstance(message, dict):
+                    for x in message["files"]:
+                        history.append([(x,), None])
+
+                    update = history + [[message["text"], first_response_two_formatted]]
+                    yield update, update
+                else:
+                    update = history + [[message, first_response_two_formatted]]
+                    yield update, update
+            except StopIteration:
+                if self.multimodal and isinstance(message, dict):
+                    self._append_multimodal_history(message, None, history)
+                    yield history, history
+                else:
+                    update = history + [[message, None]]
+                    yield update, update
+            async for response_two in generator_two:
+                response_two = self._get_chat_message_comparison(response, response_two)
+                if self.multimodal and isinstance(message, dict):
+                    update = history + [[message["text"], response_two]]
+                    yield update, update
+                else:
+                    update = history + [[message, response_two]]
+                    yield update, update
+
     async def _log_fn(
         self, message: str | dict[str, list], history: list[list[str | tuple | None]], log: str
     ) -> tuple[
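For context on what the buttons and commit scheduler above feed into: the 👍/👎 ("kto") buttons produce unpaired feedback, the A/B buttons produce paired preferences, and self.commit_scheduler is what persists records to a Hub dataset. A minimal, hypothetical sketch of that logging pattern and of the record shapes such feedback naturally maps onto; the repo id, folder name and values are illustrative and not taken from this repo, while the field names follow the common KTO/DPO dataset conventions used by TRL:

    import json
    import uuid
    from pathlib import Path

    from huggingface_hub import CommitScheduler

    # Hypothetical local folder and dataset repo for the collected feedback.
    feedback_folder = Path("feedback_data")
    feedback_folder.mkdir(parents=True, exist_ok=True)
    feedback_file = feedback_folder / f"{uuid.uuid4()}.jsonl"

    # CommitScheduler periodically commits the folder to a Hub dataset repo in the background.
    scheduler = CommitScheduler(
        repo_id="my-user/llm-human-feedback",  # illustrative dataset repo id
        repo_type="dataset",
        folder_path=feedback_folder,
        every=5,  # minutes between commits
    )

    def log_record(record: dict) -> None:
        """Append one feedback record as a JSON line; the lock avoids writing during a commit."""
        with scheduler.lock:
            with feedback_file.open("a") as f:
                f.write(json.dumps(record) + "\n")

    # Unpaired 👍/👎 feedback (KTO-style): one completion plus a boolean label.
    log_record({
        "prompt": "Who was Anthony Bourdain?",
        "completion": "Anthony Bourdain was an American chef, author and travel documentarian.",
        "label": True,  # True for 👍, False for 👎
    })

    # Paired A/B feedback (DPO-style): a chosen and a rejected completion for the same prompt.
    log_record({
        "prompt": "Who was Anthony Bourdain?",
        "chosen": "Anthony Bourdain was an American chef, author and travel documentarian.",
        "rejected": "He was a famous football player.",
    })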
requirements.txt CHANGED
@@ -1,8 +1,10 @@
-bitsandbytes==0.42
-gradio==4.36.1
+gradio==4.39
 scipy==1.13.0
-sentencepiece==0.2.0
 spaces==0.28.3
-torch
+torch
+accelerate
+bitsandbytes
+torch
+transformers>=4.43.2
+einops
+sentencepiece