"""
Run qwen 7b chat.

transformers 4.31.0

import torch
torch.cuda.empty_cache()

model.chat(
    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
    query: str,
    history: Optional[List[Tuple[str, str]]],
    system: str = 'You are a helpful assistant.',
    append_history: bool = True,
    stream: Optional[bool] = <object object at 0x7f905797ec20>,
    stop_words_ids: Optional[List[List[int]]] = None,
    **kwargs,
) -> Tuple[str, List[Tuple[str, str]]]

model.generation_config
GenerationConfig {
    "chat_format": "chatml",
    "do_sample": true,
    "eos_token_id": 151643,
    "max_new_tokens": 512,
    "max_window_size": 6144,
    "pad_token_id": 151643,
    "top_k": 0,
    "top_p": 0.5,
    "transformers_version": "4.31.0",
    "trust_remote_code": true
}
"""
# pylint: disable=line-too-long, invalid-name, no-member, redefined-outer-name, missing-function-docstring, missing-class-docstring, broad-except,
from run_cmd import run_cmd  # noqa

# on autodl with cuda12, NVIDIA-SMI shows
#   525.105.17  Driver Version: 525.105.17  CUDA Version: 12.0
# -- no fix needed there
# clumsy fix for HF Spaces: overwrite libbitsandbytes_cpu.so with libbitsandbytes_cuda118.so
run_cmd(
    "cd /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/bitsandbytes; cp libbitsandbytes_cuda118.so libbitsandbytes_cpu.so"
)  # noqa
import gc
import os
import subprocess as sp
import sys
import time
from collections import deque
from dataclasses import asdict, dataclass
from textwrap import dedent
from types import SimpleNamespace
from typing import List, Optional

import gradio as gr
import rich
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from example_list import css, example_list
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # time.tzset() is not available on Windows
    logger.warning("Windows: can't run time.tzset()")

if True:
    run_cmd(
        "ls -rtl /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/bitsandbytes"
    )

    logger.info("lsb_release -a")
    ret = sp.run("lsb_release -a", capture_output=1, check=0, shell=1, encoding="utf8")
    if ret.stdout:
        rich.print(ret.stdout)
    if ret.stderr:
        rich.print("[red bold]" + ret.stderr)

    logger.info("nvidia-smi")
    ret = sp.run("nvidia-smi", capture_output=1, check=0, shell=1, encoding="utf8")
    if ret.stdout:
        rich.print(ret.stdout)
    if ret.stderr:
        rich.print("[red bold]" + ret.stderr)
    # raise SystemExit("Interrupted intentionally")

if not torch.cuda.is_available():
    raise gr.Error("torch.cuda.is_available() is False, can't continue...")

model_name = "Qwen/Qwen-7B-Chat"  # original repo -- gone!
model_name = "tangger/Qwen-7B-Chat"  # fallback mirror, try this
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

n_gpus = torch.cuda.device_count()
try:
    _ = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
except AssertionError:
    _ = 0
max_memory = {i: _ for i in range(n_gpus)}
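# The max_memory dict above maps every visible GPU index to the same budget
# string, e.g. "13GB" (free VRAM minus ~2GB of headroom); it is handed to
# from_pretrained in gen_model below. If mem_get_info raises, the budget
# silently falls back to 0.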

del sys

# logger.remove()  # to turn on trace
# logger.add(sys.stderr, level="TRACE")
# logger.trace(f"{chat_history=}")
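

# gen_model (below) loads the Qwen checkpoint with trust_remote_code=True (the
# repo ships its own modeling code), quantized to 8-bit via bitsandbytes and
# pinned to GPU 0, then attaches the repo's GenerationConfig. The bnb_4bit_*
# kwargs only take effect if 4-bit loading is enabled (load_in_4bit is
# commented out here).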
def gen_model(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        # device_map="auto",
        device_map={"": 0},
        # load_in_4bit=True,
        load_in_8bit=True,
        max_memory=max_memory,
        fp16=True,
        torch_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = model.eval()
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    return model


def user_clear(message, chat_history):
    """Gen a response, clear message in user textbox."""
    logger.debug(f"{message=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        # chat_history came in as something non-appendable; start afresh
        chat_history = deque([[message, ""]], maxlen=5)
    logger.trace(f"{chat_history=}")
    return "", chat_history


def user(message, chat_history):
    """Gen a response."""
    logger.debug(f"{message=}")
    logger.trace(f"{chat_history=}")
    try:
        chat_history.append([message, ""])
    except Exception:
        chat_history = deque([[message, ""]], maxlen=5)
    return message, chat_history


# for rerun in tests
model = None
gc.collect()
torch.cuda.empty_cache()

if not torch.cuda.is_available():
    # raise gr.Error("GPU not available, can't run. Turn on GPU and retry")
    raise SystemExit("GPU not available, can't run. Turn on GPU and retry")

model = gen_model(model_name)
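
# Minimal sanity-check sketch of the chat API documented in the module
# docstring (kept as a comment; `history` is a list of (query, response)
# tuples, or None for a fresh conversation):
#
#   response, history = model.chat(tokenizer, "你好", history=None)
#   for partial in model.chat_stream(tokenizer, "你好", []):
#       print(partial)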


def bot(chat_history, **kwargs):
    try:
        message = chat_history[-1][0]
    except Exception as exc:
        logger.error(f"{chat_history=}: {exc}")
        return chat_history

    logger.debug(f"{chat_history=}")
    try:
        _ = """
        response, chat_history = model.chat(
            tokenizer,
            message,
            history=chat_history,
            temperature=0.7,
            repetition_penalty=1.2,
            # max_length=128,
        )
        """
        logger.debug("run model.chat...")
        model.generation_config.update(**kwargs)
        response, chat_history = model.chat(
            tokenizer,
            message,
            chat_history[:-1],
            # **kwargs,
        )
        del response
        return chat_history
    except Exception as exc:
        logger.error(exc)
        chat_history[-1] = [message, str(exc)]
        return chat_history
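

# bot_stream is the streaming counterpart of bot(): it pushes the sampling
# kwargs into model.generation_config, then repeatedly yields the chat history
# with the partial response from model.chat_stream written into the last turn,
# so the gr.Chatbot component updates as the answer is generated.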
def bot_stream(chat_history, **kwargs):
    logger.trace(f"{kwargs=}")

    # guard against a missing or empty chat_history
    if not chat_history:
        logger.trace(f" *** {chat_history=}")
        chat_history = [["", ""]]

    try:
        message = chat_history[-1][0]
    except Exception as exc:
        logger.error(f"{chat_history=}: {exc}")
        raise gr.Error(f"{chat_history=}")

    model.generation_config.update(**kwargs)
    response = ""
    for elm in model.chat_stream(tokenizer, message, chat_history):
        chat_history[-1] = [message, elm]
        response = elm
        yield chat_history

    logger.debug(f"{response=}")
    logger.debug(f"{model.generation_config=}")


SYSTEM_PROMPT = "You are a helpful assistant."
MAX_MAX_NEW_TOKENS = 2048  # sequence length 2048
MAX_NEW_TOKENS = 256


@dataclass
class Config:
    max_new_tokens: int = MAX_NEW_TOKENS
    repetition_penalty: float = 1.1
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.9


# stats_default = SimpleNamespace(llm=model, system_prompt=SYSTEM_PROMPT, config=Config())
stats_default = SimpleNamespace(llm=None, system_prompt=SYSTEM_PROMPT, config=Config())

# api_fn argument order:
# input_text, max_new_tokens, temperature, repetition_penalty, top_k, top_p, system_prompt, history
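# api_fn backs the hidden "For Chat/Translation API" accordion near the bottom
# of the Blocks and is exposed to gradio_client via api_name="api": it
# sanitises the sampling parameters (falling back to the stats_default values),
# updates model.generation_config, and runs a single non-streaming model.chat.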
def api_fn(  # pylint: disable=too-many-arguments
    input_text: Optional[str],
    # max_length: int = 256,
    max_new_tokens: int = stats_default.config.max_new_tokens,
    temperature: float = stats_default.config.temperature,
    repetition_penalty: float = stats_default.config.repetition_penalty,
    top_k: int = stats_default.config.top_k,
    top_p: float = stats_default.config.top_p,
    system_prompt: Optional[str] = None,
    history: Optional[List[str]] = None,
):
    if input_text is None:
        input_text = ""
    try:
        input_text = str(input_text).strip()
    except Exception as exc:
        logger.error(exc)
        input_text = ""
    if not input_text:
        return ""
    if history is None:
        history = []
    try:
        temperature = float(temperature)
    except Exception:
        temperature = stats_default.config.temperature

    if system_prompt is None:
        system_prompt = stats_default.system_prompt

    # if max_length < 10: max_length = 4096
    if max_new_tokens < 10:
        max_new_tokens = stats_default.config.max_new_tokens
    if top_p < 0.1 or top_p > 1:
        top_p = stats_default.config.top_p
    if temperature <= 0.5:
        temperature = stats_default.config.temperature

    _ = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
    }
    model.generation_config.update(**_)
    try:
        res, _ = model.chat(
            tokenizer,
            input_text,
            history=history,
            # max_length=max_length,
            # append_history=False,
        )
        # logger.debug(f"{res=} \n{_=}")
    except Exception as exc:
        logger.error(f"{exc=}")
        res = str(exc)
    logger.debug(f"api {res=}")
    logger.debug(f"api {model.generation_config=}")

    return res


theme = gr.themes.Soft(text_size="sm")

with gr.Blocks(
    theme=theme,
    title=model_name.lower(),
    css=css,
) as block:
    stats = gr.State(stats_default)

    # would this reset model?
    model.generation_config = GenerationConfig.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    config = asdict(stats.value.config)
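    # Note: asdict() returns a plain-dict copy of the Config dataclass, so
    # `config` is a snapshot taken when the UI is built; this is the dict that
    # bot_stream_state passes to the model on every turn.
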
    def bot_stream_state(chat_history):
        logger.trace(f"{chat_history=}")
        yield from bot_stream(chat_history, **config)

    with gr.Accordion("🎈 Info", open=False):
        gr.Markdown(
            dedent(
                f"""
                ## {model_name.lower()}

                * Temperature range: 0.51 and up; a higher temperature implies more randomness. A temperature of around 1.1 is suggested for chatting and creative writing, while 0.51-1.0 works better for summarizing and translation.
                * Set `repetition_penalty` to 2.1 or higher for a chatty conversation (more unpredictable and undesirable output). Lower it to 1.1 or less if more focused answers are desired (for example for translations or fact-oriented queries).
                * A smaller `top_k` will probably result in smoother sentences
                  (`top_k=0` is equivalent to a very large `top_k`, though). Consult the `transformers` documentation for more details.
                * An API is available at https://mikeee-qwen-7b-chat.hf.space/ that can be queried, e.g., in Python:
                ```python
                from gradio_client import Client

                client = Client("https://mikeee-qwen-7b-chat.hf.space/")
                result = client.predict(
                    "你好!",  # user prompt
                    256,  # max_new_tokens
                    1.2,  # temperature
                    1.1,  # repetition_penalty
                    0,  # top_k
                    0.9,  # top_p
                    "You are a helpful assistant.",  # system_prompt
                    None,  # history
                    api_name="/api",
                )
                print(result)
                ```
                or in JavaScript:
                ```js
                import {{ client }} from "@gradio/client";

                const app = await client("https://mikeee-qwen-7b-chat.hf.space/");
                const result = await app.predict("api", [...]);
                console.log(result.data);
                ```
                Check the documentation and examples by clicking `Use via API` at the very bottom of [https://huggingface.co./spaces/mikeee/qwen-7b-chat](https://huggingface.co./spaces/mikeee/qwen-7b-chat).

                <p></p>

                Most examples are meant for another model.
                You should probably test some related prompts of your own.
                The system prompt can be changed in Advanced Options as well.
                """
            ),
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500, value=deque([], maxlen=5))  # type: ignore

    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                # container=False,
                lines=4,
                max_lines=30,
                show_copy_button=True,
                # ).style(container=False)
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)
                clear = gr.Button("Clear History", visible=True)

    msg_submit_event = msg.submit(
        # fn=conversation.user_turn,
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
        # api_name=None,
    ).then(bot_stream_state, chatbot, chatbot, queue=True)
    submit_click_event = submit.click(
        # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
        fn=user_clear,  # clear msg
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
        # api_name=None,
    ).then(bot_stream_state, chatbot, chatbot, queue=True)
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Accordion(label="Advanced Options", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value=stats_default.system_prompt,
            lines=3,
            visible=True,
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=stats_default.config.max_new_tokens,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0.1,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.repetition_penalty,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.51,
            maximum=40.0,
            step=0.1,
            value=stats_default.config.temperature,
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=stats_default.config.top_p,
        )
        top_k = gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=stats_default.config.top_k,
        )

        def system_prompt_fn(system_prompt):
            stats.value.system_prompt = system_prompt
            logger.debug(f"{stats.value.system_prompt=}")

        def max_new_tokens_fn(max_new_tokens):
            stats.value.config.max_new_tokens = max_new_tokens
            logger.debug(f"{stats.value.config.max_new_tokens=}")

        def repetition_penalty_fn(repetition_penalty):
            stats.value.config.repetition_penalty = repetition_penalty
            logger.debug(f"{stats.value=}")

        def temperature_fn(temperature):
            stats.value.config.temperature = temperature
            logger.debug(f"{stats.value=}")

        def top_p_fn(top_p):
            stats.value.config.top_p = top_p
            logger.debug(f"{stats.value=}")

        def top_k_fn(top_k):
            stats.value.config.top_k = top_k
            logger.debug(f"{stats.value=}")

        system_prompt.change(system_prompt_fn, system_prompt)
        max_new_tokens.change(max_new_tokens_fn, max_new_tokens)
        repetition_penalty.change(repetition_penalty_fn, repetition_penalty)
        temperature.change(temperature_fn, temperature)
        top_p.change(top_p_fn, top_p)
        top_k.change(top_k_fn, top_k)

        def reset_fn(stats_):
            logger.debug("reset_fn")
            stats_ = gr.State(stats_default)
            logger.debug(f"{stats_.value=}")
            return (
                stats_,
                stats_default.system_prompt,
                stats_default.config.max_new_tokens,
                stats_default.config.repetition_penalty,
                stats_default.config.temperature,
                stats_default.config.top_p,
                stats_default.config.top_k,
            )

        reset_btn = gr.Button("Reset")
        reset_btn.click(
            reset_fn,
            stats,
            [
                stats,
                system_prompt,
                max_new_tokens,
                repetition_penalty,
                temperature,
                top_p,
                top_k,
            ],
        )

    with gr.Accordion("Example inputs", open=True):
        etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
        examples = gr.Examples(
            examples=example_list,
            inputs=[msg],
            examples_per_page=60,
        )

    with gr.Accordion("Disclaimer", open=False):
        _ = model_name.lower()
        gr.Markdown(
            f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
            f"factually accurate information. {_} was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
        input_text = gr.Text()
        api_history = gr.Chatbot(value=[])
        api_btn = gr.Button("Go", variant="primary")
        out_text = gr.Text()

        # api_fn args order:
        # input_text max_new_tokens temperature repetition_penalty top_k top_p system_prompt history
        api_btn.click(
            api_fn,
            [
                input_text,
                max_new_tokens,
                temperature,
                repetition_penalty,
                top_k,
                top_p,
                system_prompt,
                api_history,  # don't know how to pass this in gradio_client.Client calls
            ],
            out_text,
            api_name="api",
        )


if __name__ == "__main__":
    logger.info("Just record start time")

    _ = """
    ret = sp.run("lsb_release -a", capture_output=1, check=0, shell=1, encoding='utf8')
    if ret.stdout:
        rich.print(ret.stdout)
    if ret.stderr:
        rich.print("[red bold]" + ret.stderr)

    ret = sp.run("nvidia-smi", capture_output=1, check=0, shell=1, encoding='utf8')
    if ret.stdout:
        rich.print(ret.stdout)
    if ret.stderr:
        rich.print("[red bold]" + ret.stderr)

    raise SystemExit("Interrupted intentionally")
    # """

    block.queue(max_size=8).launch(debug=True)