+ # Remove all <div>. This is required to make indentation work in code blocks.
+ val = re.sub(div_pattern, "", val)
+ # Remove all <span>. This is required to make underscores work in code blocks.
+ val = re.sub(span_pattern, "", val)
+ # Convert HTML to markdown
+ val = markdownify.markdownify(val).strip()
+ # Reformat code
+ val = reformat_code(val)
+
+ # Remove noisy "[number] / [number]" at the beginning
+ noise = re.search(regenerate_pattern, val)
+ if noise and noise.start() == 0:
+ val = val[noise.end() :]
+ # Remove noisy "Copy[number] chars / [number] words"
+ val = re.sub(copy_chars_pattern, "", val)
+ # Remove empty code block ```\nCopy code\n```
+ val = re.sub(copy_code_pattern, "", val)
+
+ # Strip
+ val = val.replace("\n\n\n", "\n").strip()
+
+ return val
+
+
+def contain_blocked_words(val: str) -> bool:
+ blocked_words = ["openai", "chatgpt"]
+ for w in blocked_words:
+ if w in val.lower():
+ return True
+ return False
+
+
+def clean_html_one_sample(sample):
+ roles = ["human", "gpt"]
+
+ if len(sample["conversations"]) <= 1:
+ return (sample, 1)
+
+ # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4
+ if sample["conversations"][0]["from"] != "human":
+ sample["conversations"] = sample["conversations"][1:]
+ if len(sample["conversations"]) <= 1:
+ return (sample, 1)
+
+ if sample["conversations"][-1]["from"] == "human":
+ sample["conversations"] = sample["conversations"][:-1]
+ if len(sample["conversations"]) <= 1:
+ return (sample, 1)
+
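+ # Validate that roles alternate human/gpt and convert each message from HTML to Markdown.
+ # Error codes: 1 = too short, 2 = wrong role order, 3 = blocked words, 4 = HTML parser error.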
+ char_count = 0
+ new_conversations = []
+ for i, c in enumerate(sample["conversations"]):
+ if c["from"] != roles[i % 2]:
+ return (sample, 2)
+
+ if contain_blocked_words(c["value"]):
+ return (sample, 3)
+
+ try:
+ new_val = html_to_markdown(c["value"])
+ except (bs4.builder.ParserRejectedMarkup, AssertionError):
+ return (sample, 4)
+
+ # Filter empty answers like https://sharegpt.com/c/mrllZ6u
+ if not new_val or not new_val[0].isprintable():
+ break
+
+ char_count += len(new_val)
+ new_conversations.append(
+ {
+ "from": c["from"],
+ "value": new_val,
+ }
+ )
+
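+ # Drop a trailing unpaired message so the conversation always ends with a gpt reply.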
+ new_conversations = new_conversations[: len(new_conversations) // 2 * 2]
+ sample["conversations"] = new_conversations
+
+ if char_count < 16 or len(sample["conversations"]) <= 0:
+ return (sample, 1)
+
+ return (sample, 0)
+
+
+def clean_html_all(content, begin, end):
+ """
+ Clean the source html files.
+ """
+ cnt_skip = 0
+ cnt_blocked_words = 0
+ cnt_wrong_format = 0
+ cnt_parser_error = 0
+ cnt_too_short = 0
+ cnt_id_duplication = 0
+ cnt_value_duplication = 0
+ cnt_plugin = 0
+ cnt_tag = 0
+
+ content = content[begin:end]
+ processed = []
+ with ProcessPoolExecutor() as executor:
+ for result in tqdm(
+ executor.map(clean_html_one_sample, content), total=len(content)
+ ):
+ processed.append(result)
+
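+ # Deduplicate: first by conversation id, then by the (first question, first answer) pair.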
+ visited = {}
+ new_content = []
+ for sample, error_code in processed:
+ cid = sample["id"]
+ skipped = True
+
+ if error_code != 0:
+ if error_code == 1:
+ print(f"id {cid} is too short")
+ cnt_too_short += 1
+ elif error_code == 2:
+ print(f"id {cid} has a wrong format")
+ cnt_wrong_format += 1
+ elif error_code == 3:
+ print(f"id {cid} contains blocked words")
+ cnt_blocked_words += 1
+ elif error_code == 4:
+ print(f"id {cid} contains parser errors")
+ cnt_parser_error += 1
+ else:
+ raise ValueError(f"Invalid error_code: {error_code}")
+ elif cid in visited:
+ print(f"id {cid} is an id duplication of {visited[cid]}")
+ cnt_id_duplication += 1
+ elif sample.get("plugins", None) is not None:
+ print(f"id {cid} contains plugin")
+ cnt_plugin += 1
+ else:
+ key = (
+ sample["conversations"][0]["value"],
+ sample["conversations"][1]["value"],
+ )
+ if key in visited:
+ print(f"id {cid} is a value duplication of {visited[key]}")
+ cnt_value_duplication += 1
+ else:
+ visited[cid] = visited[key] = cid
+ skipped = False
+
+ if not skipped:
+ new_content.append(sample)
+ else:
+ cnt_skip += 1
+
+ print(
+ f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, "
+ f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, "
+ f"cnt_wrong_format: {cnt_wrong_format}, "
+ f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, "
+ f"cnt_value_duplication: {cnt_value_duplication}, cnt_plugin: {cnt_plugin}"
+ )
+
+ return new_content
+
+
+def main(args):
+ content = json.load(open(args["in_file"], "r"))
+ content = clean_html_all(content, args["begin"], args["end"])
+ json.dump(content, open(args["out_file"], "w"), indent=2, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str, default="sharegpt_clean.json")
+ parser.add_argument("--begin", type=int)
+ parser.add_argument("--end", type=int)
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ main(vars(args))
diff --git a/fastchat/data/convert_alpaca.py b/fastchat/data/convert_alpaca.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f984b852ee7d0f7a6b966e4ae1b870d39d85989
--- /dev/null
+++ b/fastchat/data/convert_alpaca.py
@@ -0,0 +1,38 @@
+"""
+Convert alpaca dataset into sharegpt format.
+
+Usage: python3 -m fastchat.data.convert_alpaca --in alpaca_data.json --out alpaca_sharegpt.json
+"""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str)
+ parser.add_argument("--out-file", type=str)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ new_content = []
+ for i, c in enumerate(content):
+ if len(c["input"].strip()) > 1:
+ q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"]
+ else:
+ q, a = c["instruction"], c["output"]
+ new_content.append(
+ {
+ "id": f"alpaca_{i}",
+ "conversations": [
+ {"from": "human", "value": q},
+ {"from": "gpt", "value": a},
+ ],
+ }
+ )
+
+ print(f"#out: {len(new_content)}")
+ json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/extract_gpt4_only.py b/fastchat/data/extract_gpt4_only.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab53bcc7faa75d90392ab7d8dc35d6cdbec67bd
--- /dev/null
+++ b/fastchat/data/extract_gpt4_only.py
@@ -0,0 +1,32 @@
+"""
+Extract the conversations generated by GPT-4 only.
+
+Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json
+"""
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str)
+ parser.add_argument("--begin", type=int)
+ parser.add_argument("--end", type=int)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ content = content[args.begin : args.end]
+ new_content = []
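+ # Keep conversations tagged as "gpt4"; entries without a "model" field are kept as well.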
+ for c in content:
+ model = c.get("model", None)
+ if model == "gpt4" or model is None:
+ new_content.append(c)
+
+ if args.out_file:
+ out_file = args.out_file
+ else:
+ out_file = args.in_file.replace(".json", "_gpt4.json")
+
+ print(f"#in: {len(content)}, #out: {len(new_content)}")
+ json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/extract_single_round.py b/fastchat/data/extract_single_round.py
new file mode 100644
index 0000000000000000000000000000000000000000..5da803656f4be6cef89559583cd36d692e1a582e
--- /dev/null
+++ b/fastchat/data/extract_single_round.py
@@ -0,0 +1,29 @@
+"""
+Extract the first round of the conversations.
+
+Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json
+"""
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str)
+ parser.add_argument("--begin", type=int)
+ parser.add_argument("--end", type=int)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ content = content[args.begin : args.end]
+ for c in content:
+ c["conversations"] = c["conversations"][:2]
+
+ if args.out_file:
+ out_file = args.out_file
+ else:
+ out_file = args.in_file.replace(".json", "_single.json")
+
+ print(f"#in: {len(content)}, #out: {len(content)}")
+ json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/filter_wrong_format.py b/fastchat/data/filter_wrong_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..46588ba8426aa99deab3ab1cb03e3b6774ede3a6
--- /dev/null
+++ b/fastchat/data/filter_wrong_format.py
@@ -0,0 +1,44 @@
+"""
+Filter conversations with wrong formats.
+
+Usage:
+python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json
+
+"""
+import argparse
+import json
+import re
+
+from tqdm import tqdm
+
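+# Matches a "1." list item followed by another "1." item, i.e. list numbering that was
+# reset during the HTML-to-markdown conversion.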
+wrong_indices_pattern = re.compile(r"\n1\. [^2]*\n1\. ")
+
+
+def should_skip(conv):
+ # Filter wrong list indices like https://sharegpt.com/c/1pREAGO
+ for sentence in conv["conversations"]:
+ val = sentence["value"]
+ sub = re.search(wrong_indices_pattern, val)
+ if sub is not None:
+ return True
+
+ return False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str, required=True)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+
+ new_content = []
+ for conv in tqdm(content):
+ if should_skip(conv):
+ print(f"{conv['id']} contains a wrong format.")
+ else:
+ new_content.append(conv)
+
+ print(f"#in: {len(content)}, #out: {len(new_content)}")
+ json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/get_stats.py b/fastchat/data/get_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e0698e4c5fce8fdb287b224e88c16edf471557c
--- /dev/null
+++ b/fastchat/data/get_stats.py
@@ -0,0 +1,82 @@
+"""
+Get stats of a dataset.
+
+Usage: python3 -m fastchat.data.get_stats --in sharegpt.json
+"""
+
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import json
+
+import numpy as np
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+K = 1e3
+M = 1e6
+
+
+def tokenize_one_sample(c):
+ for i in range(len(c["conversations"])):
+ v = c["conversations"][i]["value"]
+ c["conversations"][i]["value"] = tokenizer.tokenize(v)
+ return c
+
+
+def tokenize_dataset(content):
+ processed = []
+ with ProcessPoolExecutor() as executor:
+ for result in tqdm(
+ executor.map(tokenize_one_sample, content), total=len(content)
+ ):
+ processed.append(result)
+
+ return processed
+
+
+def compute_stats(content):
+ sample_lens = []
+ sample_turns = []
+ prompt_lens = []
+ res_lens = []
+
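+ # "value" fields were replaced by token lists in tokenize_dataset, so len() below counts tokens.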
+ for c in content:
+ sample_len = 0
+ sample_turns.append(len(c["conversations"]) // 2)
+ for i in range(len(c["conversations"]) // 2):
+ p = c["conversations"][i * 2]["value"]
+ r = c["conversations"][i * 2 + 1]["value"]
+
+ turn_len = len(p) + len(r)
+ sample_len += turn_len
+ prompt_lens.append(len(p))
+ res_lens.append(len(r))
+ sample_lens.append(sample_len)
+
+ return sample_lens, sample_turns, prompt_lens, res_lens
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str)
+ parser.add_argument(
+ "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
+ )
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
+ content = tokenize_dataset(content)
+
+ sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content)
+ print(f"#sequence: {len(content)/K:.2f} K")
+ print(f"#tokens: {np.sum(sample_lens)/M:.2f} M")
+ print(f"avg. turns: {np.mean(sample_turns):.2f}")
+ print(f"avg. prompt length: {np.mean(prompt_lens):.2f}")
+ print(f"avg. response length: {np.mean(res_lens):.2f}")
+
+ print("\n- Histogram -")
+ bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768]
+ hist = np.histogram(sample_lens, bins=bin_edges)[0]
+ for i in range(len(hist)):
+ print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}")
diff --git a/fastchat/data/hardcoded_questions.py b/fastchat/data/hardcoded_questions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2bcff42ed9fab14fb3ed6934282b615cc7e5133
--- /dev/null
+++ b/fastchat/data/hardcoded_questions.py
@@ -0,0 +1,168 @@
+"""
+Hardcoded question and answers.
+"""
+import json
+
+
+def identity_questions():
+ """ "
+ Adapted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py
+ """
+ content = []
+
+ name = "Vicuna"
+ org = "Large Model Systems Organization (LMSYS)"
+
+ def generate_conversations(questions, answers):
+ for q in questions:
+ for a in answers:
+ content.append(
+ {
+ "id": f"identity_{len(content)}",
+ "conversations": [
+ {"from": "human", "value": q},
+ {"from": "gpt", "value": a},
+ ],
+ }
+ )
+
+ questions = [
+ "Who are you?",
+ "What is your name?",
+ "Can you introduce yourself?",
+ "Can you tell me a little bit about yourself?",
+ "What's your name?",
+ "What are you called?",
+ "What are you?",
+ "Tell me your name.",
+ "Tell me about yourself.",
+ "Tell me about you.",
+ "Tell me who you are.",
+ "Please introduce yourself.",
+ ]
+ answers = [
+ f"I am {name}, a language model trained by researchers from {org}.",
+ f"My name is {name}, and I'm a language model developed by {org}.",
+ f"You can call me {name}, and I was trained by {org} researchers as a language model.",
+ f"As a language model, I go by the name {name} and was trained by researchers from {org}.",
+ f"I'm a language model called {name}, and I was trained by {org} researchers.",
+ f"You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.",
+ ]
+ generate_conversations(questions, answers)
+
+ questions = [
+ "Who created you?",
+ "Who made you?",
+ "Who built you?",
+ "Who programmed you?",
+ "Who trained you?",
+ "Who taught you?",
+ "Who developed you?",
+ ]
+ answers = [
+ f"Researchers from {org} created me.",
+ f"I'm created by {org}.",
+ f"I'm built by researchers from {org}.",
+ f"I am a language model trained by researchers from {org}.",
+ f"I'm a language model developed by {org}.",
+ f"I'm a language model created by researchers from {org}.",
+ f"My creators are researchers from {org}.",
+ ]
+ generate_conversations(questions, answers)
+
+ questions = [
+ "Are you ChatGPT?",
+ "Are you GPT-2?",
+ "Are you GPT-3?",
+ "Are you GPT-4?",
+ "Are you davinci?",
+ "Are you davinci-001?",
+ "Are you davinci-002?",
+ "Are you davinci-003?",
+ "Are you curie?",
+ "Are you based on ChatGPT?",
+ "Are you based on GPT-2?",
+ "Are you based on GPT-3?",
+ "Are you based on GPT-4?",
+ "Are you based on davinci?",
+ "Are you based on davinci-001?",
+ "Are you based on davinci-002?",
+ "Are you based on davinci-003?",
+ "Are you based on curie?",
+ "Are you trained by OpenAI?",
+ "Are you trained by Google?",
+ "Are you trained by Microsoft?",
+ "Are you trained by Meta?",
+ "Are you trained by IBM?",
+ "Do you call OpenAI APIs?",
+ "Do you call Google APIs?",
+ "Do you call Microsoft APIs?",
+ "Do you call Meta APIs?",
+ "Do you call IBM APIs?",
+ "Are you created by OpenAI?",
+ "Are you created by Google?",
+ "Are you created by Microsoft?",
+ "Are you created by Meta?",
+ "Are you created by IBM?",
+ "Are you developed by OpenAI?",
+ "Are you developed by Google?",
+ "Are you developed by Microsoft?",
+ "Are you developed by Meta?",
+ "Are you developed by IBM?",
+ "Are you trained on OpenAI data?",
+ "Are you trained on Google data?",
+ "Are you trained on Microsoft data?",
+ "Are you trained on Meta data?",
+ "Are you trained on IBM data?",
+ "Are you trained with OpenAI data?",
+ "Are you trained with Google data?",
+ "Are you trained with Microsoft data?",
+ "Are you trained with Meta data?",
+ "Are you trained with IBM data?",
+ "Have you been trained with OpenAI data?",
+ "Have you been trained with Google data?",
+ "Have you been trained with Microsoft data?",
+ "Have you been trained with Meta data?",
+ "Have you been trained with IBM data?",
+ "Are you finetuned on OpenAI data?",
+ "Are you finetuned on Google data?",
+ "Are you finetuned on Microsoft data?",
+ "Are you finetuned on Meta data?",
+ "Are you finetuned on IBM data?",
+ "Are you finetuned with OpenAI data?",
+ "Are you finetuned with Google data?",
+ "Are you finetuned with Microsoft data?",
+ "Are you finetuned with Meta data?",
+ "Are you finetuned with IBM data?",
+ "Have you been finetuned with OpenAI data?",
+ "Have you been finetuned with Google data?",
+ "Have you been finetuned with Microsoft data?",
+ "Have you been finetuned with Meta data?",
+ "Have you been finetuned with IBM data?",
+ ]
+ answers = [
+ f"No, I am a language model trained by researchers from {org}.",
+ f"No, I am a language model developed by researchers from {org}.",
+ f"No, I am a language model created by researchers from {org}.",
+ f"No, I am trained by researchers from {org}.",
+ f"No, I am developed by researchers from {org}.",
+ f"No, I am created by researchers from {org}.",
+ f"No, I'm a language model trained by researchers from {org}.",
+ f"No, I'm a language model developed by researchers from {org}.",
+ f"No, I'm a language model created by researchers from {org}.",
+ f"No, I'm trained by researchers from {org}.",
+ f"No, I'm developed by researchers from {org}.",
+ f"No, I'm created by researchers from {org}.",
+ ]
+ generate_conversations(questions, answers)
+
+ return content
+
+
+if __name__ == "__main__":
+ out_file = "hardcoded.json"
+
+ content = []
+ content.extend(identity_questions())
+
+ json.dump(content, open(out_file, "w"), indent=2)
diff --git a/fastchat/data/inspect_data.py b/fastchat/data/inspect_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..df9227106be0bdc70946e6efc90b9cbd6fa7bf9b
--- /dev/null
+++ b/fastchat/data/inspect_data.py
@@ -0,0 +1,33 @@
+"""
+Usage:
+python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json
+"""
+import argparse
+import json
+import random
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--begin", type=int)
+ parser.add_argument("--random-n", type=int)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+
+ if args.random_n:
+ indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)]
+ elif args.begin:
+ indices = range(args.begin, len(content))
+ else:
+ indices = range(0, len(content))
+
+ for idx in indices:
+ sample = content[idx]
+ print("=" * 40)
+ print(f"no: {idx}, id: {sample['id']}")
+ for conv in sample["conversations"]:
+ print(conv["from"] + ": ")
+ print(conv["value"])
+ input()
diff --git a/fastchat/data/merge.py b/fastchat/data/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae63ea76cb4aae9f22f622db84857958965cd07
--- /dev/null
+++ b/fastchat/data/merge.py
@@ -0,0 +1,23 @@
+"""
+Merge multiple conversation files into one.
+
+Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
+"""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True, nargs="+")
+ parser.add_argument("--out-file", type=str, default="merged.json")
+ args = parser.parse_args()
+
+ new_content = []
+ for in_file in args.in_file:
+ content = json.load(open(in_file, "r"))
+ new_content.extend(content)
+
+ print(f"#out: {len(new_content)}")
+ json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/optional_clean.py b/fastchat/data/optional_clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..47aecc1113fabfc76fa005cd34d2a0451efa294e
--- /dev/null
+++ b/fastchat/data/optional_clean.py
@@ -0,0 +1,90 @@
+"""
+Do optional cleaning (e.g., remove some languages).
+
+Usage:
+python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en
+python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en
+
+Requirement:
+pip3 install polyglot pyicu pycld2
+"""
+import argparse
+import json
+import re
+
+import polyglot
+from polyglot.detect import Detector
+import pycld2
+from tqdm import tqdm
+
+
+def skip(conv, args):
+ # Remove certain languages
+ if args.keep_lang != "all" or args.skip_lang is not None:
+ text = "\n".join([x["value"] for x in conv["conversations"]])
+ try:
+ lang_code = Detector(text).language.code
+ except (pycld2.error, polyglot.detect.base.UnknownLanguage):
+ lang_code = "unknown"
+
+ if args.keep_lang != "all" and lang_code != args.keep_lang:
+ return True
+
+ if lang_code == args.skip_lang:
+ return True
+
+ # Remove repetitive numbers
+ if args.reduce_rep:
+ for sentence in conv["conversations"]:
+ val = sentence["value"]
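+ # (\d)\1{8} matches the same digit repeated at least nine times in a row.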
+ sub = re.search(r"(\d)\1{8}", val)
+ if sub is not None:
+ return True
+
+ return False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str)
+ parser.add_argument(
+ "--keep-lang",
+ type=str,
+ default="all",
+ choices=["all", "en"],
+ help="Only keep certain langauges.",
+ )
+ parser.add_argument("--skip-lang", type=str, help="Skip a specific language.")
+ # NOTE: Be careful about reduce_rep which may remove some good data.
+ # For example, addresses could have long consecutive 0's
+ parser.add_argument("--reduce-rep", action="store_true")
+ args = parser.parse_args()
+
+ in_file = args.in_file
+ out_file = args.out_file
+ keep_lang = args.keep_lang
+ skip_lang = args.skip_lang
+ reduce_rep = args.reduce_rep
+ assert keep_lang == "all" or skip_lang is None
+
+ if out_file is None:
+ out_file = "sharegpt_clean"
+ if keep_lang != "all":
+ out_file += "_" + keep_lang
+ if skip_lang is not None:
+ out_file += "_skip_" + skip_lang
+ if reduce_rep:
+ out_file += "_reduce_rep"
+ out_file += ".json"
+
+ content = json.load(open(in_file, "r"))
+ num_conv = len(content)
+
+ new_content = []
+ for conv in tqdm(content):
+ if not skip(conv, args):
+ new_content.append(conv)
+
+ print(f"#in: {len(content)}, #out: {len(new_content)}")
+ json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/optional_replace.py b/fastchat/data/optional_replace.py
new file mode 100644
index 0000000000000000000000000000000000000000..1114151a9b077fd538e39721c8fc85e9a06d7a91
--- /dev/null
+++ b/fastchat/data/optional_replace.py
@@ -0,0 +1,82 @@
+"""
+Optionally replace literal bos/eos/pad/unk token strings in the data.
+
+Usage:
+python3 -m fastchat.data.optional_replace --in input.json --out output.json --model-name-or-path <model_name_or_path>
+
+Requirement:
+pip3 install transformers tqdm
+"""
+import argparse
+import json
+import traceback
+
+import transformers
+from tqdm import tqdm
+
+
+def replace_special_tokens(
+ tokenizer: transformers.PreTrainedTokenizer, text: str
+) -> str:
+ if not text:
+ return text
+
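+ # _insert_vline splits a literal special-token string with "|" so that it is no longer
+ # tokenized as a real control token during training.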
+ def _insert_vline(token: str) -> str:
+ if len(token) < 2:
+ return " "
+ elif len(token) == 2:
+ return f"{token[0]}|{token[1]}"
+ else:
+ return f"{token[:1]}|{token[1:-1]}|{token[-1:]}"
+
+ if tokenizer.bos_token:
+ text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token))
+ if tokenizer.eos_token:
+ text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token))
+ if tokenizer.pad_token:
+ text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token))
+ if tokenizer.unk_token:
+ text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token))
+ return text
+
+
+def replace(conv, tokenizer):
+ # Replace bos/eos/pad/unk tokens
+ if tokenizer:
+ try:
+ for sentence in conv["conversations"]:
+ sentence["value"] = replace_special_tokens(tokenizer, sentence["value"])
+ except Exception as e:
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str)
+ parser.add_argument(
+ "--model-name-or-path",
+ type=str,
+ help="The directory or address where the model token is stored.",
+ )
+ args = parser.parse_args()
+
+ in_file = args.in_file
+ out_file = args.out_file
+ tokenizer = None
+ if args.model_name_or_path:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ args.model_name_or_path,
+ trust_remote_code=True,
+ use_fast=False,
+ )
+
+ if out_file is None:
+ out_file = f"{in_file}_replace.json"
+
+ content = json.load(open(in_file, "r"))
+
+ for conv in tqdm(content):
+ replace(conv, tokenizer)
+
+ json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/prepare_all.py b/fastchat/data/prepare_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d568703a4a5a18298ac51d92394e5142040c6c5
--- /dev/null
+++ b/fastchat/data/prepare_all.py
@@ -0,0 +1,42 @@
+"""Prepare all datasets."""
+
+import argparse
+import os
+
+from fastchat.utils import run_cmd
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521")
+ parser.add_argument(
+ "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
+ )
+ parser.add_argument("--seq-len", type=int, default=4096)
+ args = parser.parse_args()
+
+ in_prefix = args.prefix
+ model_path = args.model_name_or_path
+ seq_len = args.seq_len
+ prefix = (
+ f"{in_prefix}_{seq_len}".replace("4096", "4k")
+ .replace("8192", "8k")
+ .replace("16384", "16k")
+ )
+
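+ # Pipeline: clean raw HTML -> drop unwanted languages -> split long conversations ->
+ # filter wrong formats -> train/test split -> add hardcoded identity questions ->
+ # optionally extract GPT-4-only and single-round subsets.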
+ cmd_list = [
+ f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json",
+ f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko",
+ f"python3 -m fastchat.data.split_long_conversation --in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}",
+ f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json",
+ f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99",
+ f"python3 -m fastchat.data.hardcoded_questions",
+ f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json",
+ f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json",
+ f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json",
+ ]
+
+ for cmd in cmd_list:
+ ret = run_cmd(cmd)
+ if ret != 0:
+ exit(ret)
diff --git a/fastchat/data/pretty_json.py b/fastchat/data/pretty_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..52eddf6c82687a544ae27a7ffad6d6f0458dcb29
--- /dev/null
+++ b/fastchat/data/pretty_json.py
@@ -0,0 +1,20 @@
+"""
+Usage:
+python3 pretty_json.py --in in.json --out out.json
+"""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str, required=True)
+ args = parser.parse_args()
+
+ with open(args.in_file, "r") as fin:
+ data = json.load(fin)
+
+ with open(args.out_file, "w") as fout:
+ json.dump(data, fout, indent=2, ensure_ascii=False)
diff --git a/fastchat/data/sample.py b/fastchat/data/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ea94fadaeb243269d125a41b71a69ef15ce16fa
--- /dev/null
+++ b/fastchat/data/sample.py
@@ -0,0 +1,40 @@
+"""
+Sample some conversations from a file.
+
+Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json
+"""
+import argparse
+import json
+
+import numpy as np
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str, default="sampled.json")
+ parser.add_argument("--begin", type=int, default=0)
+ parser.add_argument("--end", type=int, default=100)
+ parser.add_argument("--max-length", type=int, default=1024)
+ parser.add_argument("--keep-order", action="store_true")
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ if not args.keep_order:
+ np.random.seed(42)
+ np.random.shuffle(content)
+
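+ # Consider samples in [begin, end) of the (optionally shuffled) list and keep those
+ # whose concatenated character count is within --max-length.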
+ new_content = []
+ for i in range(args.begin, min(args.end, len(content))):
+ sample = content[i]
+ concat = ""
+ for s in sample["conversations"]:
+ concat += s["value"]
+
+ if len(concat) > args.max_length:
+ continue
+
+ new_content.append(sample)
+
+ print(f"#in: {len(content)}, #out: {len(new_content)}")
+ json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/data/split_long_conversation.py b/fastchat/data/split_long_conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..413fa8bced590cdb476e67a6523c3967cb844acd
--- /dev/null
+++ b/fastchat/data/split_long_conversation.py
@@ -0,0 +1,129 @@
+"""
+Split long conversations based on certain max length.
+
+Usage: python3 -m fastchat.data.split_long_conversation \
+ --in sharegpt_clean.json \
+ --out sharegpt_split.json \
+ --model-name-or-path $<model_name>
+"""
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import json
+from typing import Dict, Sequence, Optional
+
+import transformers
+from tqdm import tqdm
+
+
+def make_sample(sample, start_idx, end_idx):
+ assert (end_idx - start_idx) % 2 == 0
+ return {
+ "id": sample["id"] + "_" + str(start_idx),
+ "model": sample.get("model", ""),
+ "conversations": sample["conversations"][start_idx:end_idx],
+ }
+
+
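+# Module-level globals, set in split_all() so that they are visible to the
+# ProcessPoolExecutor worker processes.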
+tokenizer = max_length = None
+
+
+def split_one_sample(sample):
+ tokenized_lens = []
+ conversations = sample["conversations"]
+ conversations = conversations[: len(conversations) // 2 * 2]
+ for c in conversations:
+ length = len(tokenizer(c["value"]).input_ids) + 6
+ tokenized_lens.append(length)
+
+ start_idx = 0
+ cur_len = 0
+
+ if len(conversations) % 2 != 0 or len(conversations) < 2:
+ return []
+
+ new_samples = []
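+ # Greedily pack consecutive (human, gpt) pairs; once adding the next pair would exceed
+ # max_length, flush the accumulated chunk and start a new one.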
+ for i in range(0, len(conversations), 2):
+ tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
+ if cur_len + tmp_len > max_length:
+ new_samples.append(make_sample(sample, start_idx, i))
+ start_idx = i
+ cur_len = 0
+ # Always flush the remaining pairs on the last iteration.
+ if i == len(conversations) - 2:
+ new_samples.append(make_sample(sample, start_idx, i + 2))
+
+ cur_len += tmp_len
+
+ return new_samples
+
+
+def worker(input_data):
+ result = []
+ for sample in input_data:
+ result.extend(split_one_sample(sample))
+ return result
+
+
+def split_all(content, begin, end, tokenizer_, max_length_):
+ """
+ Keep as many conversation rounds as possible within the max token length constraint.
+ """
+ global tokenizer, max_length
+ tokenizer = tokenizer_
+ max_length = max_length_
+
+ content = content[begin:end]
+ new_content = []
+
+ # Split content into chunks
+ chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)]
+ with ProcessPoolExecutor() as executor:
+ for result in tqdm(executor.map(worker, chunks), total=len(chunks)):
+ new_content.extend(result)
+
+ return new_content
+
+
+def filter_invalid_roles(content):
+ new_content = []
+ for i, c in enumerate(content):
+ roles = ["human", "gpt"]
+ if len(c["conversations"]) <= 0:
+ continue
+
+ valid = True
+ for j, s in enumerate(c["conversations"]):
+ if s["from"] != roles[j % 2]:
+ valid = False
+ break
+
+ if valid:
+ new_content.append(c)
+
+ return new_content
+
+
+def main(args):
+ content = json.load(open(args.in_file, "r"))
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ args.model_name_or_path,
+ model_max_length=args.max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
+ new_content = filter_invalid_roles(new_content)
+
+ print(f"#in: {len(content)}, #out: {len(new_content)}")
+ json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
+ parser.add_argument("--begin", type=int)
+ parser.add_argument("--end", type=int)
+ parser.add_argument("--model-name-or-path", type=str, required=True)
+ parser.add_argument("--max-length", type=int, default=2048)
+ args = parser.parse_args()
+ main(args)
diff --git a/fastchat/data/split_train_test.py b/fastchat/data/split_train_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b8960b57e30c28ef92652b17db7e52756f8aac
--- /dev/null
+++ b/fastchat/data/split_train_test.py
@@ -0,0 +1,34 @@
+"""
+Split the dataset into training and test sets.
+
+Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json
+"""
+import argparse
+import json
+
+import numpy as np
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--begin", type=int, default=0)
+ parser.add_argument("--end", type=int, default=100)
+ parser.add_argument("--ratio", type=float, default=0.9)
+ args = parser.parse_args()
+
+ content = json.load(open(args.in_file, "r"))
+ np.random.seed(0)
+
+ perm = np.random.permutation(len(content))
+ content = [content[i] for i in perm]
+ split = int(args.ratio * len(content))
+
+ train_set = content[:split]
+ test_set = content[split:]
+
+ print(f"#train: {len(train_set)}, #test: {len(test_set)}")
+ train_name = args.in_file.replace(".json", "_train.json")
+ test_name = args.in_file.replace(".json", "_test.json")
+ json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False)
+ json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/model/__init__.py b/fastchat/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29767dce6ae41b72ecabfed477531684a4241d55
--- /dev/null
+++ b/fastchat/model/__init__.py
@@ -0,0 +1,5 @@
+from fastchat.model.model_adapter import (
+ load_model,
+ get_conversation_template,
+ add_model_args,
+)
diff --git a/fastchat/model/__pycache__/__init__.cpython-310.pyc b/fastchat/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..712712818e0f417a6c606d46d0397548be05819b
Binary files /dev/null and b/fastchat/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/compression.cpython-310.pyc b/fastchat/model/__pycache__/compression.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab8db816b7cee9d2d93f46202042be00a49f2d54
Binary files /dev/null and b/fastchat/model/__pycache__/compression.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc b/fastchat/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c6e545b64f5d107c1822bdc82a8bbe11ab840a5
Binary files /dev/null and b/fastchat/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_adapter.cpython-310.pyc b/fastchat/model/__pycache__/model_adapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82aa6e95bea8e0b875d74cbc0831d6b33d8ba241
Binary files /dev/null and b/fastchat/model/__pycache__/model_adapter.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_chatglm.cpython-310.pyc b/fastchat/model/__pycache__/model_chatglm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4472112eee9ef227a974df1836a61d29388a98a4
Binary files /dev/null and b/fastchat/model/__pycache__/model_chatglm.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_codet5p.cpython-310.pyc b/fastchat/model/__pycache__/model_codet5p.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7b9962c23dce338427ed01c0a8d9cf16216c988
Binary files /dev/null and b/fastchat/model/__pycache__/model_codet5p.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_exllama.cpython-310.pyc b/fastchat/model/__pycache__/model_exllama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64ee50f1dcb506c165515908d4bc6b880c580a36
Binary files /dev/null and b/fastchat/model/__pycache__/model_exllama.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_falcon.cpython-310.pyc b/fastchat/model/__pycache__/model_falcon.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6189bdeb0505fe759d21cffc12ef94935bc1dddc
Binary files /dev/null and b/fastchat/model/__pycache__/model_falcon.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_registry.cpython-310.pyc b/fastchat/model/__pycache__/model_registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..745479539845ac44523ea932a2cd58054fe7e530
Binary files /dev/null and b/fastchat/model/__pycache__/model_registry.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/model_xfastertransformer.cpython-310.pyc b/fastchat/model/__pycache__/model_xfastertransformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14dd3d4cc7f99764ee7b08a9d7252bd69163b844
Binary files /dev/null and b/fastchat/model/__pycache__/model_xfastertransformer.cpython-310.pyc differ
diff --git a/fastchat/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc b/fastchat/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93b8f8c7d00787a64f5e25a9b30fed45da4bf556
Binary files /dev/null and b/fastchat/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc differ
diff --git a/fastchat/model/apply_delta.py b/fastchat/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1c06d48aa1125113f7a864ec26d5c9368a91f5
--- /dev/null
+++ b/fastchat/model/apply_delta.py
@@ -0,0 +1,165 @@
+"""
+Apply the delta weights on top of a base model.
+
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1
+"""
+import argparse
+import gc
+import glob
+import json
+import os
+import shutil
+import tempfile
+
+from huggingface_hub import snapshot_download
+import torch
+from torch import nn
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+
+GB = 1 << 30
+
+
+def split_files(model_path, tmp_path, split_size):
+ if not os.path.exists(model_path):
+ model_path = snapshot_download(repo_id=model_path)
+ if not os.path.exists(tmp_path):
+ os.makedirs(tmp_path)
+
+ file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
+ files = glob.glob(file_pattern)
+
+ part = 0
+ try:
+ for file_path in tqdm(files):
+ state_dict = torch.load(file_path)
+ new_state_dict = {}
+
+ current_size = 0
+ for name, param in state_dict.items():
+ param_size = param.numel() * param.element_size()
+
+ if current_size + param_size > split_size:
+ new_file_name = f"pytorch_model-{part}.bin"
+ new_file_path = os.path.join(tmp_path, new_file_name)
+ torch.save(new_state_dict, new_file_path)
+ current_size = 0
+ new_state_dict = None
+ gc.collect()
+ new_state_dict = {}
+ part += 1
+
+ new_state_dict[name] = param
+ current_size += param_size
+
+ new_file_name = f"pytorch_model-{part}.bin"
+ new_file_path = os.path.join(tmp_path, new_file_name)
+ torch.save(new_state_dict, new_file_path)
+ new_state_dict = None
+ gc.collect()
+ new_state_dict = {}
+ part += 1
+ except Exception as e:
+ print(f"An error occurred during split_files: {e}")
+ shutil.rmtree(tmp_path)
+ raise
+
+
+def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
+ delta_config = AutoConfig.from_pretrained(delta_path)
+
+ if os.path.exists(target_model_path):
+ shutil.rmtree(target_model_path)
+ os.makedirs(target_model_path)
+
+ split_size = 4 * GB
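+ # Re-shard both checkpoints into <= 4 GB files on disk so that only one base shard
+ # and one delta shard need to be held in memory at a time.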
+
+ with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
+ print(f"Split files for the base model to {tmp_base_path}")
+ split_files(base_model_path, tmp_base_path, split_size)
+ print(f"Split files for the delta weights to {tmp_delta_path}")
+ split_files(delta_path, tmp_delta_path, split_size)
+
+ base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
+ base_files = glob.glob(base_pattern)
+ delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
+ delta_files = glob.glob(delta_pattern)
+ delta_state_dict = torch.load(delta_files[0])
+
+ print("Applying the delta")
+ weight_map = {}
+ total_size = 0
+
+ for i, base_file in tqdm(enumerate(base_files)):
+ state_dict = torch.load(base_file)
+ file_name = f"pytorch_model-{i}.bin"
+ for name, param in state_dict.items():
+ if name not in delta_state_dict:
+ for delta_file in delta_files:
+ delta_state_dict = torch.load(delta_file)
+ gc.collect()
+ if name in delta_state_dict:
+ break
+
+ state_dict[name] += delta_state_dict[name]
+ weight_map[name] = file_name
+ total_size += param.numel() * param.element_size()
+ gc.collect()
+ torch.save(state_dict, os.path.join(target_model_path, file_name))
+
+ with open(
+ os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
+ ) as f:
+ json.dump(
+ {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
+ )
+
+ print(f"Saving the target model to {target_model_path}")
+ delta_tokenizer.save_pretrained(target_model_path)
+ delta_config.save_pretrained(target_model_path)
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print(f"Loading the delta weights from {delta_path}")
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
+ delta = AutoModelForCausalLM.from_pretrained(
+ delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+
+ print(f"Loading the base model from {base_model_path}")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+
+ print("Applying the delta")
+ for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
+ assert name in delta.state_dict()
+ param.data += delta.state_dict()[name]
+
+ print(f"Saving the target model to {target_model_path}")
+ base.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument(
+ "--low-cpu-mem",
+ action="store_true",
+ help="Lower the cpu memory usage. This will split large files and use "
+ "disk as swap to reduce the memory usage below 10GB.",
+ )
+ args = parser.parse_args()
+
+ if args.low_cpu_mem:
+ apply_delta_low_cpu_mem(
+ args.base_model_path, args.target_model_path, args.delta_path
+ )
+ else:
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/fastchat/model/apply_lora.py b/fastchat/model/apply_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..01263dcc71535e275c7509af96d10eac3b79926b
--- /dev/null
+++ b/fastchat/model/apply_lora.py
@@ -0,0 +1,48 @@
+"""
+Apply the LoRA weights on top of a base model.
+
+Usage:
+python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B
+
+Dependency:
+pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b
+"""
+import argparse
+
+import torch
+from peft import PeftModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def apply_lora(base_model_path, target_model_path, lora_path):
+ print(f"Loading the base model from {base_model_path}")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+ base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
+
+ print(f"Loading the LoRA adapter from {lora_path}")
+
+ lora_model = PeftModel.from_pretrained(
+ base,
+ lora_path,
+ # torch_dtype=torch.float16
+ )
+
+ print("Applying the LoRA")
+ model = lora_model.merge_and_unload()
+
+ print(f"Saving the target model to {target_model_path}")
+ model.save_pretrained(target_model_path)
+ base_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--lora-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_lora(args.base_model_path, args.target_model_path, args.lora_path)
diff --git a/fastchat/model/compression.py b/fastchat/model/compression.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e503f30ee85a241d20cbce9ce59dd73f25e2a2
--- /dev/null
+++ b/fastchat/model/compression.py
@@ -0,0 +1,300 @@
+import dataclasses
+import gc
+import glob
+import os
+
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from huggingface_hub import snapshot_download
+import torch
+from torch import Tensor
+from torch.nn import functional as F
+import torch.nn as nn
+from tqdm import tqdm
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ AutoModel,
+ AutoModelForSeq2SeqLM,
+)
+
+
+@dataclasses.dataclass
+class CompressionConfig:
+ """Group-wise quantization."""
+
+ num_bits: int
+ group_size: int
+ group_dim: int
+ symmetric: bool
+ enabled: bool = True
+
+
+default_compression_config = CompressionConfig(
+ num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
+)
+
+
+class CLinear(nn.Module):
+ """Compressed Linear Layer."""
+
+ def __init__(self, weight=None, bias=None, device=None):
+ super().__init__()
+ if weight is None:
+ self.weight = None
+ elif isinstance(weight, Tensor):
+ self.weight = compress(weight.data.to(device), default_compression_config)
+ else:
+ self.weight = weight
+ self.bias = bias
+
+ def forward(self, input: Tensor) -> Tensor:
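+ # The weight is stored compressed and dequantized on the fly for every forward pass,
+ # trading extra compute for lower memory use.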
+ weight = decompress(self.weight, default_compression_config)
+ if self.bias is None:
+ return F.linear(input.to(weight.dtype), weight)
+ return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))
+
+
+def compress_module(module, target_device):
+ for attr_str in dir(module):
+ target_attr = getattr(module, attr_str)
+ if type(target_attr) == torch.nn.Linear:
+ setattr(
+ module,
+ attr_str,
+ CLinear(target_attr.weight, target_attr.bias, target_device),
+ )
+ for name, child in module.named_children():
+ compress_module(child, target_device)
+
+
+def get_compressed_list(module, prefix=""):
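+ # Recursively collect the fully qualified names of all nn.Linear weights;
+ # these are the tensors that will be stored in compressed form.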
+ compressed_list = []
+ for attr_str in dir(module):
+ target_attr = getattr(module, attr_str)
+ if type(target_attr) == torch.nn.Linear:
+ full_name = (
+ f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
+ )
+ compressed_list.append(full_name)
+ for name, child in module.named_children():
+ child_prefix = f"{prefix}.{name}" if prefix else name
+ for each in get_compressed_list(child, child_prefix):
+ compressed_list.append(each)
+ return compressed_list
+
+
+def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
+ for attr_str in dir(module):
+ target_attr = getattr(module, attr_str)
+ if type(target_attr) == torch.nn.Linear:
+ full_name = (
+ f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
+ )
+ setattr(
+ module,
+ attr_str,
+ CLinear(
+ compressed_state_dict[full_name], target_attr.bias, target_device
+ ),
+ )
+ for name, child in module.named_children():
+ child_prefix = f"{prefix}.{name}" if prefix else name
+ apply_compressed_weight(
+ child, compressed_state_dict, target_device, child_prefix
+ )
+
+
+def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
+ # partially load model
+ # `use_fast=True`` is not supported for some models.
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=use_fast, revision=revision, trust_remote_code=True
+ )
+ except TypeError:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=not use_fast, revision=revision, trust_remote_code=True
+ )
+ with init_empty_weights():
+ # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel
+ config = AutoConfig.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ torch_dtype=torch_dtype,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ # some models are loaded by AutoModel but not AutoModelForCausalLM,
+ # such as chatglm, chatglm2
+ try:
+ # google/flan-* models are based on an AutoModelForSeq2SeqLM.
+ if "T5Config" in str(type(config)):
+ model = AutoModelForSeq2SeqLM.from_config(
+ config, trust_remote_code=True
+ )
+ else:
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+ except NameError:
+ model = AutoModel.from_config(config, trust_remote_code=True)
+ linear_weights = get_compressed_list(model)
+ if os.path.exists(model_path):
+ # `model_path` is a local folder
+ base_pattern = os.path.join(model_path, "pytorch_model*.bin")
+ else:
+ # `model_path` is a cached Hugging Face repo
+ # We don't necessarily need to download the model' repo again if there is a cache.
+ # So check the default huggingface cache first.
+ model_path_temp = os.path.join(
+ os.path.expanduser("~"),
+ ".cache/huggingface/hub",
+ "models--" + model_path.replace("/", "--"),
+ "snapshots/",
+ )
+ downloaded = False
+ if os.path.exists(model_path_temp):
+ temp_last_dir = os.listdir(model_path_temp)[-1]
+ model_path_temp = os.path.join(model_path_temp, temp_last_dir)
+ base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
+ files = glob.glob(base_pattern)
+ if len(files) > 0:
+ downloaded = True
+
+ if downloaded:
+ model_path = model_path_temp
+ else:
+ model_path = snapshot_download(model_path, revision=revision)
+ base_pattern = os.path.join(model_path, "pytorch_model*.bin")
+
+ files = glob.glob(base_pattern)
+ if len(files) == 0:
+ raise ValueError(
+ f"Cannot find any model weight files. "
+ f"Please check your (cached) weight path: {model_path}"
+ )
+
+ compressed_state_dict = {}
+ for filename in tqdm(files):
+ tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
+ for name in tmp_state_dict:
+ if name in linear_weights:
+ tensor = tmp_state_dict[name].to(device, dtype=torch_dtype)
+ compressed_state_dict[name] = compress(
+ tensor, default_compression_config
+ )
+ else:
+ compressed_state_dict[name] = tmp_state_dict[name].to(
+ device, dtype=torch_dtype
+ )
+ tmp_state_dict[name] = None
+ tensor = None
+ gc.collect()
+ torch.cuda.empty_cache()
+ if device == "xpu":
+ torch.xpu.empty_cache()
+ if device == "npu":
+ torch.npu.empty_cache()
+
+ for name in model.state_dict():
+ if name not in linear_weights:
+ set_module_tensor_to_device(
+ model, name, device, value=compressed_state_dict[name]
+ )
+ apply_compressed_weight(model, compressed_state_dict, device)
+
+ if torch_dtype == torch.float16:
+ model.half()
+ model.to(device)
+ model.eval()
+
+ return model, tokenizer
+
+
+def compress(tensor, config):
+ """Simulate group-wise quantization."""
+ if not config.enabled:
+ return tensor
+
+ group_size, num_bits, group_dim, symmetric = (
+ config.group_size,
+ config.num_bits,
+ config.group_dim,
+ config.symmetric,
+ )
+ assert num_bits <= 8
+
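+ # Split group_dim into (num_groups, group_size) so each group is quantized with its own
+ # scale (plus a per-group minimum in the asymmetric case).
+ # Example (symmetric, num_bits=8): a group [0.5, -1.0, 0.25] has max |x| = 1.0, so
+ # scale = 127 / 1.0 and the stored int8 values are [64, -127, 32] after rounding.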
+ original_shape = tensor.shape
+ num_groups = (original_shape[group_dim] + group_size - 1) // group_size
+ new_shape = (
+ original_shape[:group_dim]
+ + (num_groups, group_size)
+ + original_shape[group_dim + 1 :]
+ )
+
+ # Pad
+ pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
+ if pad_len != 0:
+ pad_shape = (
+ original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
+ )
+ tensor = torch.cat(
+ [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
+ dim=group_dim,
+ )
+ data = tensor.view(new_shape)
+
+ # Quantize
+ if symmetric:
+ B = 2 ** (num_bits - 1) - 1
+ scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
+ data = data * scale
+ data = data.clamp_(-B, B).round_().to(torch.int8)
+ return data, scale, original_shape
+ else:
+ B = 2**num_bits - 1
+ mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
+ mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
+
+ scale = B / (mx - mn)
+ data = data - mn
+ data.mul_(scale)
+
+ data = data.clamp_(0, B).round_().to(torch.uint8)
+ return data, mn, scale, original_shape
+
+
+def decompress(packed_data, config):
+ """Simulate group-wise dequantization."""
+ if not config.enabled:
+ return packed_data
+
+ group_size, num_bits, group_dim, symmetric = (
+ config.group_size,
+ config.num_bits,
+ config.group_dim,
+ config.symmetric,
+ )
+
+ # Dequantize
+ if symmetric:
+ data, scale, original_shape = packed_data
+ data = data / scale
+ else:
+ data, mn, scale, original_shape = packed_data
+ data = data / scale
+ data.add_(mn)
+
+ # Unpad
+ pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
+ if pad_len:
+ padded_original_shape = (
+ original_shape[:group_dim]
+ + (original_shape[group_dim] + pad_len,)
+ + original_shape[group_dim + 1 :]
+ )
+ data = data.reshape(padded_original_shape)
+ indices = [slice(0, x) for x in original_shape]
+ return data[indices].contiguous()
+ else:
+ return data.view(original_shape)
diff --git a/fastchat/model/convert_fp16.py b/fastchat/model/convert_fp16.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc40aa83bf3a85129a668387df86a41d925f13d
--- /dev/null
+++ b/fastchat/model/convert_fp16.py
@@ -0,0 +1,26 @@
+"""
+Usage:
+python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder
+"""
+import argparse
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+
+def convert_fp16(in_checkpoint, out_checkpoint):
+ tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(
+ in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+ model.save_pretrained(out_checkpoint)
+ tokenizer.save_pretrained(out_checkpoint)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
+ parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
+ args = parser.parse_args()
+
+ convert_fp16(args.in_checkpoint, args.out_checkpoint)
diff --git a/fastchat/model/llama_condense_monkey_patch.py b/fastchat/model/llama_condense_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb45a8bb6addf8a8506c847060e23dc65ae27995
--- /dev/null
+++ b/fastchat/model/llama_condense_monkey_patch.py
@@ -0,0 +1,71 @@
+# Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py
+
+from functools import partial
+
+import torch
+import transformers
+import transformers.models.llama.modeling_llama
+
+
+class CondenseRotaryEmbedding(torch.nn.Module):
+ def __init__(
+ self, dim, ratio, max_position_embeddings=2048, base=10000, device=None
+ ):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.ratio = ratio
+ max_position_embeddings *= ratio
+ self.max_seq_len_cached = max_position_embeddings
+ # print(f"Monkey Patching condense ratio {ratio}")
+ t = (
+ torch.arange(
+ self.max_seq_len_cached,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype,
+ )
+ / ratio
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ dtype = torch.get_default_dtype()
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
+ )
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = (
+ torch.arange(
+ self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
+ )
+ / self.ratio
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False
+ )
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def replace_llama_with_condense(ratio):
+ transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(
+ CondenseRotaryEmbedding, ratio=ratio
+ )
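+
+
+# Usage sketch: apply the patch before instantiating a Llama model so that new
+# rotary-embedding modules are built with interpolated positions (t / ratio):
+#   replace_llama_with_condense(ratio=4)  # illustrative ratio
+#   model = transformers.AutoModelForCausalLM.from_pretrained(...)
+# LongChatAdapter in model_adapter.py derives the ratio from
+# config.rope_scaling["factor"].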
diff --git a/fastchat/model/make_delta.py b/fastchat/model/make_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..480ba8f1a2cb067d69df174ee7d00e5072ee5164
--- /dev/null
+++ b/fastchat/model/make_delta.py
@@ -0,0 +1,48 @@
+"""
+Make the delta weights by subtracting base weights.
+
+Usage:
+python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id=None):
+ print(f"Loading the base model from {base_model_path}")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+
+ print(f"Loading the target model from {target_model_path}")
+ target = AutoModelForCausalLM.from_pretrained(
+ target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False)
+
+ print("Calculating the delta")
+ base_state_dict = base.state_dict()
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ assert name in base_state_dict
+ param.data -= base_state_dict[name]
+
+ print(f"Saving the delta to {delta_path}")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
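+
+# Note: the saved checkpoint stores `target - base` for every parameter, so the
+# original target weights can be recovered by adding the base weights back
+# parameter-by-parameter (FastChat's companion apply_delta script performs that
+# inverse step).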
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0c7eff859fc1cec3643ab74b0c3afc38ee5d10
--- /dev/null
+++ b/fastchat/model/model_adapter.py
@@ -0,0 +1,1970 @@
+"""Model adapter registration."""
+
+import math
+import os
+import re
+import sys
+from typing import Dict, List, Optional
+import warnings
+
+if sys.version_info >= (3, 9):
+ from functools import cache
+else:
+ from functools import lru_cache as cache
+
+import accelerate
+import psutil
+import torch
+from transformers import (
+ AutoConfig,
+ AutoModel,
+ AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM,
+ AutoTokenizer,
+ LlamaTokenizer,
+ LlamaForCausalLM,
+ T5Tokenizer,
+)
+
+from fastchat.constants import CPU_ISA
+from fastchat.conversation import Conversation, get_conv_template
+from fastchat.model.compression import load_compress_model
+from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
+from fastchat.model.model_chatglm import generate_stream_chatglm
+from fastchat.model.model_codet5p import generate_stream_codet5p
+from fastchat.model.model_falcon import generate_stream_falcon
+from fastchat.model.model_exllama import generate_stream_exllama
+from fastchat.model.model_xfastertransformer import generate_stream_xft
+from fastchat.model.monkey_patch_non_inplace import (
+ replace_llama_attn_with_non_inplace_operations,
+)
+from fastchat.modules.awq import AWQConfig, load_awq_quantized
+from fastchat.modules.exllama import ExllamaConfig, load_exllama_model
+from fastchat.modules.xfastertransformer import load_xft_model, XftConfig
+from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
+from fastchat.utils import get_gpu_memory
+
+# Check an environment variable to decide whether Peft model base weights
+# should be shared. When false, every Peft model is treated as separate.
+peft_share_base_weights = (
+ os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true"
+)
+
+ANTHROPIC_MODEL_LIST = (
+ "claude-1",
+ "claude-2",
+ "claude-instant-1",
+)
+
+
+class BaseModelAdapter:
+ """The base and the default model adapter."""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return True
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=self.use_fast_tokenizer,
+ revision=revision,
+ trust_remote_code=True,
+ )
+ except TypeError:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=False, revision=revision, trust_remote_code=True
+ )
+ try:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ use_flash_attention_2=True,
+ **from_pretrained_kwargs,
+ )
+ except Exception:
+ # Fall back to the standard attention implementation when
+ # flash-attention 2 is unavailable for this model or configuration.
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ use_flash_attention_2=False,
+ **from_pretrained_kwargs,
+ )
+ # model = AutoModel.from_pretrained(
+ # model_path,
+ # low_cpu_mem_usage=True,
+ # trust_remote_code=True,
+ # **from_pretrained_kwargs,
+ # )
+ return model, tokenizer
+
+ def load_compress_model(self, model_path, device, torch_dtype, revision="main"):
+ return load_compress_model(
+ model_path,
+ device,
+ torch_dtype,
+ use_fast=self.use_fast_tokenizer,
+ revision=revision,
+ )
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ if "megrez" in model_path.lower():
+ model_path = "megrez"
+ elif 'minicpm' in model_path.lower():
+ model_path = "minicpm"
+ return get_conv_template(model_path.lower())
+
+
+# A global registry for all model adapters
+# TODO (lmzheng): make it a priority queue.
+model_adapters: List[BaseModelAdapter] = []
+
+
+def register_model_adapter(cls):
+ """Register a model adapter."""
+ model_adapters.append(cls())
+
+
+@cache
+def get_model_adapter(model_path: str, model_name: str = None) -> BaseModelAdapter:
+ """Get a model adapter for a model_path."""
+ model_path_basename = os.path.basename(os.path.normpath(model_path)) if not model_name else model_name
+ # Try the basename of model_path at first
+ for adapter in model_adapters:
+ if adapter.match(model_path_basename) and type(adapter) != BaseModelAdapter:
+ print(f"Matching model adapter: {adapter}")
+ return adapter
+
+ model_path = model_path if not model_name else model_name
+ # Then try the full path
+ for adapter in model_adapters:
+ if adapter.match(model_path):
+ print(f"Using model adapter: {adapter}")
+ return adapter
+
+ raise ValueError(f"No valid model adapter for {model_path}")
+
+
+def raise_warning_for_incompatible_cpu_offloading_configuration(
+ device: str, load_8bit: bool, cpu_offloading: bool
+):
+ if cpu_offloading:
+ if not load_8bit:
+ warnings.warn(
+ "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
+ "Use '--load-8bit' to enable 8-bit-quantization\n"
+ "Continuing without cpu-offloading enabled\n"
+ )
+ return False
+ if not "linux" in sys.platform:
+ warnings.warn(
+ "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
+ "Continuing without cpu-offloading enabled\n"
+ )
+ return False
+ if device != "cuda":
+ warnings.warn(
+ "CPU-offloading is only enabled when using CUDA-devices\n"
+ "Continuing without cpu-offloading enabled\n"
+ )
+ return False
+ return cpu_offloading
+
+
+def load_model(
+ model_path: str,
+ device: str = "cuda",
+ num_gpus: int = 1,
+ max_gpu_memory: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ load_8bit: bool = False,
+ cpu_offloading: bool = False,
+ gptq_config: Optional[GptqConfig] = None,
+ awq_config: Optional[AWQConfig] = None,
+ exllama_config: Optional[ExllamaConfig] = None,
+ xft_config: Optional[XftConfig] = None,
+ revision: str = "main",
+ debug: bool = False,
+ model_name: str = None,
+):
+ """Load a model from Hugging Face."""
+ # get model adapter
+ adapter = get_model_adapter(model_path, model_name)
+
+ # Handle device mapping
+ cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
+ device, load_8bit, cpu_offloading
+ )
+ if device == "cpu":
+ # kwargs = {"torch_dtype": torch.float32}
+ kwargs = {"torch_dtype": torch.float16}
+ if CPU_ISA in ["avx512_bf16", "amx"]:
+ try:
+ import intel_extension_for_pytorch as ipex
+
+ kwargs = {"torch_dtype": torch.bfloat16}
+ except ImportError:
+ warnings.warn(
+ "Intel Extension for PyTorch is not installed; installing it can accelerate CPU inference."
+ )
+ elif device == "cuda":
+ # kwargs = {"torch_dtype": torch.float16}
+ kwargs = {"torch_dtype": torch.bfloat16}
+ if num_gpus != 1:
+ kwargs["device_map"] = "auto"
+ if max_gpu_memory is None:
+ kwargs[
+ "device_map"
+ ] = "sequential" # This is important for not the same VRAM sizes
+ available_gpu_memory = get_gpu_memory(num_gpus)
+ kwargs["max_memory"] = {
+ i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
+ for i in range(num_gpus)
+ }
+ else:
+ kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
+ elif device == "mps":
+ kwargs = {"torch_dtype": torch.float16}
+ # Avoid bugs in mps backend by not using in-place operations.
+ replace_llama_attn_with_non_inplace_operations()
+ elif device == "xpu":
+ kwargs = {"torch_dtype": torch.bfloat16}
+ # Try to load ipex; although it looks unused, importing it links xpu support into torch
+ try:
+ import intel_extension_for_pytorch as ipex
+ except ImportError:
+ warnings.warn(
+ "Intel Extension for PyTorch is not installed, but is required for xpu inference."
+ )
+ elif device == "npu":
+ kwargs = {"torch_dtype": torch.float16}
+ # Try to load torch_npu; although it looks unused, importing it links Ascend NPU support into torch
+ try:
+ import torch_npu
+ except ImportError:
+ warnings.warn("Ascend Extension for PyTorch is not installed.")
+ else:
+ raise ValueError(f"Invalid device: {device}")
+
+ if cpu_offloading:
+ # raises an error on incompatible platforms
+ from transformers import BitsAndBytesConfig
+
+ if "max_memory" in kwargs:
+ kwargs["max_memory"]["cpu"] = (
+ str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
+ )
+ kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_8bit_fp32_cpu_offload=cpu_offloading
+ )
+ kwargs["load_in_8bit"] = load_8bit
+ elif load_8bit:
+ if num_gpus != 1:
+ warnings.warn(
+ "8-bit quantization is not supported for multi-gpu inference."
+ )
+ else:
+ model, tokenizer = adapter.load_compress_model(
+ model_path=model_path,
+ device=device,
+ torch_dtype=kwargs["torch_dtype"],
+ revision=revision,
+ )
+ if debug:
+ print(model)
+ return model, tokenizer
+ elif awq_config and awq_config.wbits < 16:
+ assert (
+ awq_config.wbits == 4
+ ), "Currently we only support 4-bit inference for AWQ."
+ model, tokenizer = load_awq_quantized(model_path, awq_config, device)
+ if num_gpus != 1:
+ device_map = accelerate.infer_auto_device_map(
+ model,
+ max_memory=kwargs["max_memory"],
+ no_split_module_classes=[
+ "OPTDecoderLayer",
+ "LlamaDecoderLayer",
+ "BloomBlock",
+ "MPTBlock",
+ "DecoderLayer",
+ ],
+ )
+ model = accelerate.dispatch_model(
+ model, device_map=device_map, offload_buffers=True
+ )
+ else:
+ model.to(device)
+ return model, tokenizer
+ elif gptq_config and gptq_config.wbits < 16:
+ model, tokenizer = load_gptq_quantized(model_path, gptq_config)
+ if num_gpus != 1:
+ device_map = accelerate.infer_auto_device_map(
+ model,
+ max_memory=kwargs["max_memory"],
+ no_split_module_classes=["LlamaDecoderLayer"],
+ )
+ model = accelerate.dispatch_model(
+ model, device_map=device_map, offload_buffers=True
+ )
+ else:
+ model.to(device)
+ return model, tokenizer
+ elif exllama_config:
+ model, tokenizer = load_exllama_model(model_path, exllama_config)
+ return model, tokenizer
+ elif xft_config:
+ model, tokenizer = load_xft_model(model_path, xft_config)
+ return model, tokenizer
+ kwargs["revision"] = revision
+
+ if dtype is not None: # Overwrite dtype if it is provided in the arguments.
+ kwargs["torch_dtype"] = dtype
+
+ # Load model
+ model, tokenizer = adapter.load_model(model_path, kwargs)
+
+ if (
+ device == "cpu"
+ and kwargs["torch_dtype"] is torch.bfloat16
+ and CPU_ISA is not None
+ ):
+ model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
+
+ if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
+ "mps",
+ "xpu",
+ "npu",
+ ):
+ model.to(device)
+
+ if device == "xpu":
+ model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
+
+ if debug:
+ print(model)
+
+ return model, tokenizer
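+
+
+# Minimal usage sketch (illustrative arguments):
+#   model, tokenizer = load_model(
+#       "lmsys/vicuna-7b-v1.5", device="cuda", num_gpus=1
+#   )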
+
+
+def get_conversation_template(model_path: str) -> Conversation:
+ """Get the default conversation template."""
+ adapter = get_model_adapter(model_path)
+ return adapter.get_default_conv_template(model_path)
+
+
+def get_generate_stream_function(model: torch.nn.Module, model_path: str):
+ """Get the generate_stream function for inference."""
+ from fastchat.serve.inference import generate_stream
+
+ model_type = str(type(model)).lower()
+ is_chatglm = "chatglm" in model_type
+ is_falcon = "rwforcausallm" in model_type
+ is_codet5p = "codet5p" in model_type
+ is_peft = "peft" in model_type
+ is_exllama = "exllama" in model_type
+ is_xft = "xft" in model_type
+
+ if is_chatglm:
+ return generate_stream_chatglm
+ elif is_falcon:
+ return generate_stream_falcon
+ elif is_codet5p:
+ return generate_stream_codet5p
+ elif is_exllama:
+ return generate_stream_exllama
+ elif is_xft:
+ return generate_stream_xft
+
+ elif peft_share_base_weights and is_peft:
+ # Return a curried stream function that loads the right adapter
+ # according to the model_name available in this context. This ensures
+ # the right weights are available.
+ @torch.inference_mode()
+ def generate_stream_peft(
+ model,
+ tokenizer,
+ params: Dict,
+ device: str,
+ context_len: int,
+ stream_interval: int = 2,
+ judge_sent_end: bool = False,
+ ):
+ model.set_adapter(model_path)
+ for x in generate_stream(
+ model,
+ tokenizer,
+ params,
+ device,
+ context_len,
+ stream_interval,
+ judge_sent_end,
+ ):
+ yield x
+
+ return generate_stream_peft
+ else:
+ return generate_stream
+
+
+def add_model_args(parser):
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default="lmsys/vicuna-7b-v1.5",
+ help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default="main",
+ help="Hugging Face Hub model revision identifier",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ choices=["cpu", "cuda", "mps", "xpu", "npu"],
+ default="cuda",
+ help="The device type",
+ )
+ parser.add_argument(
+ "--gpus",
+ type=str,
+ default=None,
+ help="A single GPU like 1 or multiple GPUs like 0,2",
+ )
+ parser.add_argument("--num-gpus", type=int, default=1)
+ parser.add_argument(
+ "--max-gpu-memory",
+ type=str,
+ help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ choices=["float32", "float16", "bfloat16"],
+ help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
+ default=None,
+ )
+ parser.add_argument(
+ "--load-8bit", action="store_true", help="Use 8-bit quantization"
+ )
+ parser.add_argument(
+ "--cpu-offloading",
+ action="store_true",
+ help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
+ )
+ parser.add_argument(
+ "--gptq-ckpt",
+ type=str,
+ default=None,
+ help="Used for GPTQ. The path to the local GPTQ checkpoint.",
+ )
+ parser.add_argument(
+ "--gptq-wbits",
+ type=int,
+ default=16,
+ choices=[2, 3, 4, 8, 16],
+ help="Used for GPTQ. #bits to use for quantization",
+ )
+ parser.add_argument(
+ "--gptq-groupsize",
+ type=int,
+ default=-1,
+ help="Used for GPTQ. Groupsize to use for quantization; default uses full row.",
+ )
+ parser.add_argument(
+ "--gptq-act-order",
+ action="store_true",
+ help="Used for GPTQ. Whether to apply the activation order GPTQ heuristic",
+ )
+ parser.add_argument(
+ "--awq-ckpt",
+ type=str,
+ default=None,
+ help="Used for AWQ. Load quantized model. The path to the local AWQ checkpoint.",
+ )
+ parser.add_argument(
+ "--awq-wbits",
+ type=int,
+ default=16,
+ choices=[4, 16],
+ help="Used for AWQ. #bits to use for AWQ quantization",
+ )
+ parser.add_argument(
+ "--awq-groupsize",
+ type=int,
+ default=-1,
+ help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.",
+ )
+ parser.add_argument(
+ "--enable-exllama",
+ action="store_true",
+ help="Used for exllamabv2. Enable exllamaV2 inference framework.",
+ )
+ parser.add_argument(
+ "--exllama-max-seq-len",
+ type=int,
+ default=4096,
+ help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.",
+ )
+ parser.add_argument(
+ "--exllama-gpu-split",
+ type=str,
+ default=None,
+ help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
+ )
+ parser.add_argument(
+ "--enable-xft",
+ action="store_true",
+ help="Used for xFasterTransformer Enable xFasterTransformer inference framework.",
+ )
+ parser.add_argument(
+ "--xft-max-seq-len",
+ type=int,
+ default=4096,
+ help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.",
+ )
+ parser.add_argument(
+ "--xft-dtype",
+ type=str,
+ choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"],
+ help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.",
+ default=None,
+ )
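+
+
+# Usage sketch: attach these flags to a CLI entry point, e.g.
+#   parser = argparse.ArgumentParser()
+#   add_model_args(parser)
+#   args = parser.parse_args()
+#   model, tokenizer = load_model(
+#       args.model_path, device=args.device, num_gpus=args.num_gpus
+#   )  # illustrative; the remaining flags map to the matching load_model kwargs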
+
+
+def remove_parent_directory_name(model_path):
+ """Remove parent directory name."""
+ if model_path[-1] == "/":
+ model_path = model_path[:-1]
+ return model_path.split("/")[-1]
+
+
+peft_model_cache = {}
+
+
+class PeftModelAdapter:
+ """Loads any "peft" model and it's base model."""
+
+ def match(self, model_path: str):
+ """Accepts any model path with "peft" in the name"""
+ if os.path.exists(os.path.join(model_path, "adapter_config.json")):
+ return True
+ return "peft" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ """Loads the base model then the (peft) adapter weights"""
+ from peft import PeftConfig, PeftModel
+
+ config = PeftConfig.from_pretrained(model_path)
+ base_model_path = config.base_model_name_or_path
+ if "peft" in base_model_path:
+ raise ValueError(
+ f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
+ )
+
+ # Basic proof of concept for loading peft adapters that share the base
+ # weights. This is pretty messy because Peft re-writes the underlying
+ # base model and internally stores a map of adapter layers.
+ # So, to make this work we:
+ # 1. Cache the first peft model loaded for a given base model.
+ # 2. Call `load_model` for any follow-on Peft models.
+ # 3. Make sure we load the adapters by the model_path. Why? This is
+ # what's accessible during inference time.
+ # 4. In get_generate_stream_function, make sure we load the right
+ # adapter before doing inference. This *should* be safe when calls
+ # are blocked by the same semaphore.
+ if peft_share_base_weights:
+ if base_model_path in peft_model_cache:
+ model, tokenizer = peft_model_cache[base_model_path]
+ # Super important: make sure we use model_path as the
+ # `adapter_name`.
+ model.load_adapter(model_path, adapter_name=model_path)
+ else:
+ base_adapter = get_model_adapter(base_model_path)
+ base_model, tokenizer = base_adapter.load_model(
+ base_model_path, from_pretrained_kwargs
+ )
+ # Super important: make sure we use model_path as the
+ # `adapter_name`.
+ model = PeftModel.from_pretrained(
+ base_model, model_path, adapter_name=model_path
+ )
+ peft_model_cache[base_model_path] = (model, tokenizer)
+ return model, tokenizer
+
+ # In the normal case, load up the base model weights again.
+ base_adapter = get_model_adapter(base_model_path)
+ base_model, tokenizer = base_adapter.load_model(
+ base_model_path, from_pretrained_kwargs
+ )
+ model = PeftModel.from_pretrained(base_model, model_path)
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ """Uses the conv template of the base model"""
+ from peft import PeftConfig, PeftModel
+
+ config = PeftConfig.from_pretrained(model_path)
+ if "peft" in config.base_model_name_or_path:
+ raise ValueError(
+ f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
+ )
+ base_model_path = config.base_model_name_or_path
+ base_adapter = get_model_adapter(base_model_path)
+ return base_adapter.get_default_conv_template(config.base_model_name_or_path)
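+
+
+# Usage note (sketch): set PEFT_SHARE_BASE_WEIGHTS=true in the environment
+# before launching so that multiple peft adapters sharing one base model reuse
+# the cached base weights loaded above instead of reloading the base each time.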
+
+
+class DeepseekChatAdapter(BaseModelAdapter):
+ """The model adapter for deepseek-ai's chat models"""
+
+ # Note: this model requires tokenizers >= 0.13.3 because the tokenizer class is LlamaTokenizerFast
+
+ def match(self, model_path: str):
+ return "deepseek" in model_path.lower() and "chat" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("deepseek")
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ device_map="sequential",
+ torch_dtype=torch.bfloat16,
+ max_memory=from_pretrained_kwargs.get("max_memory"),
+ attn_implementation="flash_attention_2", # or "eager"
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True
+ )
+
+ return model, tokenizer
+
+
+class VicunaAdapter(BaseModelAdapter):
+ "Model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.5)" ""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "vicuna" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=self.use_fast_tokenizer, revision=revision
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ use_flash_attention_2=True,
+ **from_pretrained_kwargs,
+ )
+ self.raise_warning_for_old_weights(model)
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ if "v0" in remove_parent_directory_name(model_path):
+ return get_conv_template("one_shot")
+ return get_conv_template("vicuna_v1.1")
+
+ def raise_warning_for_old_weights(self, model):
+ if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000:
+ warnings.warn(
+ "\nYou are probably using the old Vicuna-v0 model, "
+ "which will generate unexpected results with the "
+ "current fastchat.\nYou can try one of the following methods:\n"
+ "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n"
+ "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template one_shot`\n"
+ "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n"
+ )
+
+
+class AiroborosAdapter(BaseModelAdapter):
+ """The model adapter for jondurbin/airoboros-*"""
+
+ def match(self, model_path: str):
+ if re.search(r"airoboros|spicyboros", model_path, re.I):
+ return True
+ return False
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ if "-3." in model_path or "-3p" in model_path:
+ return get_conv_template("airoboros_v3")
+ if "spicyboros" in model_path or re.search(r"-(2\.[2-9]+)", model_path):
+ return get_conv_template("airoboros_v2")
+ return get_conv_template("airoboros_v1")
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ if "mpt" not in model_path.lower():
+ return super().load_model(model_path, from_pretrained_kwargs)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ max_seq_len=8192,
+ **from_pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, use_fast=True
+ )
+ return model, tokenizer
+
+
+class Zhinao360Adapter(BaseModelAdapter):
+ """The model adapter for 360Zhinao models"""
+
+ def match(self, model_path: str):
+ return "360zhinao" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ trust_remote_code=True)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True)
+ from transformers import GenerationConfig
+ generation_config = GenerationConfig.from_pretrained(
+ model_path,
+ trust_remote_code=True)
+
+ return model, tokenizer, generation_config
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("360zhinao")
+
+
+class LongChatAdapter(BaseModelAdapter):
+ """Model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)."""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "longchat" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+
+ # Apply monkey patch, TODO(Dacheng): Add flash attention support
+ config = AutoConfig.from_pretrained(model_path, revision=revision)
+ replace_llama_with_condense(config.rope_scaling["factor"])
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=self.use_fast_tokenizer, revision=revision
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("vicuna_v1.1")
+
+
+class GoogleT5Adapter(BaseModelAdapter):
+ """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""
+
+ def match(self, model_path: str):
+ return any(
+ model_str in model_path.lower()
+ for model_str in ["flan-", "fastchat-t5", "codet5p"]
+ )
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+
+class KoalaAdapter(BaseModelAdapter):
+ """The model adapter for Koala"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "koala" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("koala_v1")
+
+
+class AlpacaAdapter(BaseModelAdapter):
+ """The model adapter for Alpaca"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "alpaca" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("alpaca")
+
+
+class ChatGLMAdapter(BaseModelAdapter):
+ """The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b"""
+
+ def match(self, model_path: str):
+ return "chatglm" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ if "chatglm3" in model_path.lower():
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ encode_special_tokens=True,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ model = AutoModel.from_pretrained(
+ model_path, trust_remote_code=True, **from_pretrained_kwargs
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ model_path = model_path.lower()
+ if "chatglm2" in model_path.lower():
+ return get_conv_template("chatglm2")
+ if "chatglm3" in model_path.lower():
+ return get_conv_template("chatglm3")
+ return get_conv_template("chatglm")
+
+
+class CodeGeexAdapter(BaseModelAdapter):
+ """The model adapter for THUDM/codegeex-6b, THUDM/codegeex2-6b"""
+
+ def match(self, model_path: str):
+ return "codegeex" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ model = AutoModel.from_pretrained(
+ model_path, trust_remote_code=True, **from_pretrained_kwargs
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("codegeex")
+
+
+class DollyV2Adapter(BaseModelAdapter):
+ """The model adapter for databricks/dolly-v2-12b"""
+
+ def match(self, model_path: str):
+ return "dolly-v2" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ # 50277 means "### End"
+ tokenizer.eos_token_id = 50277
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("dolly_v2")
+
+
+class OasstPythiaAdapter(BaseModelAdapter):
+ """The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"""
+
+ def match(self, model_path: str):
+ model_path = model_path.lower()
+ return "oasst" in model_path and "pythia" in model_path
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("oasst_pythia")
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+
+class OasstLLaMAAdapter(BaseModelAdapter):
+ """The model adapter for OpenAssistant/oasst-sft-7-llama-30b"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ model_path = model_path.lower()
+ if "openassistant-sft-7-llama-30b-hf" in model_path:
+ return True
+ return "oasst" in model_path and "pythia" not in model_path
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("oasst_llama")
+
+
+class OpenChat35Adapter(BaseModelAdapter):
+ """The model adapter for OpenChat 3.5 (e.g. openchat/openchat_3.5)"""
+
+ def match(self, model_path: str):
+ return "openchat" in model_path.lower() and "3.5" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("openchat_3.5")
+
+
+class PythiaAdapter(BaseModelAdapter):
+ """The model adapter for any EleutherAI/pythia model"""
+
+ def match(self, model_path: str):
+ return "pythia" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+
+class StableLMAdapter(BaseModelAdapter):
+ """The model adapter for StabilityAI/stablelm-tuned-alpha-7b"""
+
+ def match(self, model_path: str):
+ return "stablelm" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("stablelm")
+
+
+class MPTAdapter(BaseModelAdapter):
+ """The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)"""
+
+ def match(self, model_path: str):
+ model_path = model_path.lower()
+ return "mpt" in model_path and not "airoboros" in model_path
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ max_seq_len=8192,
+ **from_pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ model_path = model_path.lower()
+ if "mpt-7b-chat" in model_path:
+ return get_conv_template("mpt-7b-chat")
+ elif "mpt-30b-chat" in model_path:
+ return get_conv_template("mpt-30b-chat")
+ elif "mpt-30b-instruct" in model_path:
+ return get_conv_template("mpt-30b-instruct")
+ else:
+ print(
+ "Warning: Loading base MPT model with `zero_shot` conversation configuration. "
+ "If this is not desired, inspect model configurations and names."
+ )
+ return get_conv_template("zero_shot")
+
+
+class BaizeAdapter(BaseModelAdapter):
+ """The model adapter for project-baize/baize-v2-7b"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "baize" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("baize")
+
+
+class RwkvAdapter(BaseModelAdapter):
+ """The model adapter for BlinkDL/RWKV-4-Raven"""
+
+ def match(self, model_path: str):
+ return "rwkv-4" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ from fastchat.model.rwkv_model import RwkvModel
+
+ model = RwkvModel(model_path)
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ "EleutherAI/pythia-160m", revision=revision
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("rwkv")
+
+
+class OpenBuddyAdapter(BaseModelAdapter):
+ """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "openbuddy" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("openbuddy")
+
+
+class PhoenixAdapter(BaseModelAdapter):
+ """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b"""
+
+ def match(self, model_path: str):
+ return "phoenix" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("phoenix")
+
+
+class ReaLMAdapter(BaseModelAdapter):
+ """The model adapter for FreedomIntelligence/ReaLM-7b"""
+
+ def match(self, model_path: str):
+ return "ReaLM" in model_path
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("ReaLM-7b-v1")
+
+
+class ChatGPTAdapter(BaseModelAdapter):
+ """The model adapter for ChatGPT"""
+
+ def match(self, model_path: str):
+ return model_path in (
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-1106",
+ "gpt-4",
+ "gpt-4-turbo",
+ )
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ raise NotImplementedError()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("chatgpt")
+
+
+class AzureOpenAIAdapter(BaseModelAdapter):
+ """The model adapter for Azure OpenAI"""
+
+ def match(self, model_path: str):
+ return model_path in ("azure-gpt-35-turbo", "azure-gpt-4")
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ raise NotImplementedError()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("chatgpt")
+
+
+class ClaudeAdapter(BaseModelAdapter):
+ """The model adapter for Claude"""
+
+ def match(self, model_path: str):
+ return model_path in ANTHROPIC_MODEL_LIST
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ raise NotImplementedError()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("claude")
+
+
+class BardAdapter(BaseModelAdapter):
+ """The model adapter for Bard"""
+
+ def match(self, model_path: str):
+ return model_path == "bard"
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ raise NotImplementedError()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("bard")
+
+
+class PaLM2Adapter(BaseModelAdapter):
+ """The model adapter for PaLM2"""
+
+ def match(self, model_path: str):
+ return model_path == "palm-2"
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ raise NotImplementedError()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("bard")
+
+
+class BiLLaAdapter(BaseModelAdapter):
+ """The model adapter for Neutralzz/BiLLa-7B-SFT"""
+
+ def match(self, model_path: str):
+ return "billa" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("billa")
+
+
+class RedPajamaINCITEAdapter(BaseModelAdapter):
+ """The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat"""
+
+ def match(self, model_path: str):
+ return "redpajama-incite" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("redpajama-incite")
+
+
+class H2OGPTAdapter(BaseModelAdapter):
+ """The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "h2ogpt" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("h2ogpt")
+
+
+class RobinAdapter(BaseModelAdapter):
+ """The model adapter for LMFlow/Full-Robin-7b-v2"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "robin" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("Robin")
+
+
+class SnoozyAdapter(BaseModelAdapter):
+ """The model adapter for nomic-ai/gpt4all-13b-snoozy"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ model_path = model_path.lower()
+ return "gpt4all" in model_path and "snoozy" in model_path
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("snoozy")
+
+
+class WizardLMAdapter(BaseModelAdapter):
+ """The model adapter for WizardLM/WizardLM-13B-V1.0"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "wizardlm" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ model_path = model_path.lower()
+ if "13b" in model_path or "30b" in model_path or "70b" in model_path:
+ return get_conv_template("vicuna_v1.1")
+ else:
+ # TODO: use the recommended template for 7B
+ # (https://huggingface.co./WizardLM/WizardLM-13B-V1.0)
+ return get_conv_template("one_shot")
+
+
+class ManticoreAdapter(BaseModelAdapter):
+ """The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "manticore" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("manticore")
+
+
+class GuanacoAdapter(BaseModelAdapter):
+ """The model adapter for timdettmers/guanaco-33b-merged"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "guanaco" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=self.use_fast_tokenizer, revision=revision
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+ )
+ # Fix a bug in tokenizer config
+ tokenizer.eos_token_id = model.config.eos_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("zero_shot")
+
+
+class ChangGPTAdapter(BaseModelAdapter):
+ """The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat"""
+
+ def match(self, model_path: str):
+ model_path = model_path.lower()
+ return "polyglot" in model_path and "chang" in model_path
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("polyglot_changgpt")
+
+
+class CamelAdapter(BaseModelAdapter):
+ """The model adapter for camel-ai/CAMEL-13B-Combined-Data"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "camel" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("vicuna_v1.1")
+
+
+class TuluAdapter(BaseModelAdapter):
+ """The model adapter for allenai/tulu-30b"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "tulu" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("tulu")
+
+
+class FalconAdapter(BaseModelAdapter):
+ """The model adapter for tiiuae/falcon-40b"""
+
+ def match(self, model_path: str):
+ return "falcon" in model_path.lower() and "chat" not in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ # Strongly suggest using bf16, which is recommended by the author of Falcon
+ tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ )
+ # The Falcon tokenizer config and special tokens config do not define a pad token.
+ # Set `pad_token_id` to 9, which corresponds to the special token '>>SUFFIX<<'.
+ tokenizer.pad_token_id = 9
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("falcon")
+
+
+class FalconChatAdapter(BaseModelAdapter):
+ def match(self, model_path: str):
+ return "falcon" in model_path.lower() and "chat" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("falcon-chat")
+
+
+class TigerBotAdapter(BaseModelAdapter):
+ """The model adapter for TigerResearch/tigerbot-7b-sft"""
+
+ def match(self, model_path: str):
+ return "tigerbot" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("tigerbot")
+
+
+class BaichuanAdapter(BaseModelAdapter):
+ """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-7B)"""
+
+ def match(self, model_path: str):
+ return "baichuan" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ # for Baichuan-13B-Chat
+ if "chat" in model_path.lower():
+ if "baichuan2" in model_path.lower():
+ return get_conv_template("baichuan2-chat")
+ return get_conv_template("baichuan-chat")
+ return get_conv_template("zero_shot")
+
+
+class XGenAdapter(BaseModelAdapter):
+ """The model adapter for Salesforce/xgen-7b"""
+
+ def match(self, model_path: str):
+ return "xgen" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ model.config.eos_token_id = 50256
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("xgen")
+
+
+class NousHermesAdapter(BaseModelAdapter):
+ """The model adapter for NousResearch/Nous-Hermes-13b"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "nous-hermes" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("alpaca")
+
+
+class InternLMChatAdapter(BaseModelAdapter):
+ """The model adapter for internlm/internlm-chat-7b"""
+
+ def match(self, model_path: str):
+ return "internlm-chat" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ )
+ model = model.eval()
+ if "8k" in model_path.lower():
+ model.config.max_sequence_length = 8192
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("internlm-chat")
+
+
+class StarChatAdapter(BaseModelAdapter):
+ """The model adapter for HuggingFaceH4/starchat-beta"""
+
+ def match(self, model_path: str):
+ return "starchat" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("starchat")
+
+
+class MistralAdapter(BaseModelAdapter):
+ """The model adapter for Mistral AI models"""
+
+ def match(self, model_path: str):
+ return "mistral" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("mistral")
+
+
+class Llama2Adapter(BaseModelAdapter):
+ """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)"""
+
+ def match(self, model_path: str):
+ return "llama-2" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("llama-2")
+
+
+class CuteGPTAdapter(BaseModelAdapter):
+ """The model adapter for CuteGPT"""
+
+ def match(self, model_path: str):
+ return "cutegpt" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+ )
+ tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<end>") # CuteGPT uses "<end>" as its end-of-sequence token
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.eos_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("cutegpt")
+
+
+class OpenOrcaAdapter(BaseModelAdapter):
+ """Model adapter for Open-Orca models which may use different prompt templates
+ - (e.g. Open-Orca/OpenOrcaxOpenChat-Preview2-13B, Open-Orca/Mistral-7B-OpenOrca)
+ - `OpenOrcaxOpenChat-Preview2-13B` uses their "OpenChat Llama2 V1" prompt template.
+ - [Open-Orca/OpenOrcaxOpenChat-Preview2-13B #Prompt Template](https://huggingface.co./Open-Orca/OpenOrcaxOpenChat-Preview2-13B#prompt-template)
+ - `Mistral-7B-OpenOrca` uses the [OpenAI's Chat Markup Language (ChatML)](https://github.com/openai/openai-python/blob/main/chatml.md)
+ format, with <|im_start|> and <|im_end|> tokens added to support this.
+ - [Open-Orca/Mistral-7B-OpenOrca #Prompt Template](https://huggingface.co./Open-Orca/Mistral-7B-OpenOrca#prompt-template)
+ """
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return (
+ "mistral-7b-openorca" in model_path.lower()
+ or "openorca" in model_path.lower()
+ )
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=self.use_fast_tokenizer, revision=revision
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ ).eval()
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ if "mistral-7b-openorca" in model_path.lower():
+ return get_conv_template("mistral-7b-openorca")
+ return get_conv_template("open-orca")
+
+
+class WizardCoderAdapter(BaseModelAdapter):
+ """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "wizardcoder" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ # Same as Alpaca, see :
+ # https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/inference_wizardcoder.py#L60
+ return get_conv_template("alpaca")
+
+
+class QwenChatAdapter(BaseModelAdapter):
+ """The model adapter for Qwen/Qwen-7B-Chat
+ To run this model, you also need to install flash attention:
+ ``` bash
+ git clone https://github.com/Dao-AILab/flash-attention
+ cd flash-attention && pip install .
+ pip install csrc/layer_norm
+ pip install csrc/rotary
+ ```
+
+ Since flash-attn 2.0, the following renames happened:
+ - `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
+ - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
+ - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`
+ You may need to revise the import in https://huggingface.co./Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L69
+ to `from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func`.
+ """
+
+ def match(self, model_path: str):
+ return "qwen" in model_path.lower()
+
+ def float_set(self, config, option):
+ config.bf16 = False
+ config.fp16 = False
+ config.fp32 = False
+
+ if option == "bf16":
+ config.bf16 = True
+ elif option == "fp16":
+ config.fp16 = True
+ elif option == "fp32":
+ config.fp32 = True
+ else:
+ print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.")
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ from transformers.generation import GenerationConfig
+
+ revision = from_pretrained_kwargs.get("revision", "main")
+ config = AutoConfig.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ )
+ # NOTE: if you use an old version of the model file, uncomment the line below
+ # config.use_flash_attn = False
+ self.float_set(config, "fp16")
+ generation_config = GenerationConfig.from_pretrained(
+ model_path, trust_remote_code=True
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ config=config,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ ).eval()
+ if hasattr(model.config, "use_dynamic_ntk") and model.config.use_dynamic_ntk:
+ model.config.max_sequence_length = 16384
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ tokenizer.eos_token_id = config.eos_token_id
+ tokenizer.bos_token_id = config.bos_token_id
+ tokenizer.pad_token_id = generation_config.pad_token_id
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.bos_token_id = tokenizer.bos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("qwen-7b-chat")
+
+
+class BGEAdapter(BaseModelAdapter):
+ """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "bge" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModel.from_pretrained(
+ model_path,
+ **from_pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ if hasattr(model.config, "max_position_embeddings") and hasattr(
+ tokenizer, "model_max_length"
+ ):
+ model.config.max_sequence_length = min(
+ model.config.max_position_embeddings, tokenizer.model_max_length
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("one_shot")
+
+
+class E5Adapter(BaseModelAdapter):
+ """The model adapter for E5 (e.g., intfloat/e5-large-v2)"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "e5-" in model_path.lower() and 'megrez' not in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModel.from_pretrained(
+ model_path,
+ **from_pretrained_kwargs,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ if hasattr(model.config, "max_position_embeddings") and hasattr(
+ tokenizer, "model_max_length"
+ ):
+ model.config.max_sequence_length = min(
+ model.config.max_position_embeddings, tokenizer.model_max_length
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("one_shot")
+
+
+class AquilaChatAdapter(BaseModelAdapter):
+ """The model adapter for BAAI/Aquila
+
+ Now supports:
+ - BAAI/AquilaChat-7B
+ - BAAI/AquilaChat2-7B
+ - BAAI/AquilaChat2-34B
+ """
+
+ def match(self, model_path: str):
+ return "aquila" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **from_pretrained_kwargs,
+ )
+ model = model.eval()
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, revision=revision
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ model_path = model_path.lower()
+ # See: https://huggingface.co./BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347
+ if "aquilachat2" in model_path:
+ if "16k" in model_path:
+ return get_conv_template("aquila")
+ elif "34b" in model_path:
+ return get_conv_template("aquila-legacy")
+ else:
+ return get_conv_template("aquila-v1")
+ else:
+ return get_conv_template("aquila-chat")
+
+
+class Llama2ChineseAdapter(BaseModelAdapter):
+ """The model adapter for FlagAlpha/Llama2-Chinese sft"""
+
+ def match(self, model_path: str):
+ return "llama2-chinese" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("llama2-chinese")
+
+
+class VigogneAdapter(BaseModelAdapter):
+ """The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return bool(re.search(r"vigogne|vigostral", model_path, re.I))
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=self.use_fast_tokenizer,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ ).eval()
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ if "chat" in model_path.lower():
+ if "vigostral" in model_path.lower():
+ return get_conv_template("vigogne_chat_v3")
+ return get_conv_template("vigogne_chat_v2")
+ return get_conv_template("vigogne_instruct")
+
+
+class OpenLLaMaOpenInstructAdapter(BaseModelAdapter):
+ """The model adapter for OpenLLaMa-Open-Instruct (e.g., VMware/open-llama-7b-open-instruct)"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return (
+ "open-llama" in model_path.lower() and "open-instruct" in model_path.lower()
+ )
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ revision = from_pretrained_kwargs.get("revision", "main")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=self.use_fast_tokenizer,
+ trust_remote_code=True,
+ revision=revision,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ ).eval()
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("alpaca")
+
+
+class CodeLlamaAdapter(BaseModelAdapter):
+ """The model adapter for CodeLlama (e.g., codellama/CodeLlama-34b-hf)"""
+
+ def match(self, model_path: str):
+ return "codellama" in model_path.lower()
+
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+ model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+ model.config.eos_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ return model, tokenizer
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("llama-2")
+
+
+class PhindCodeLlamaAdapter(CodeLlamaAdapter):
+ """The model adapter for Phind-CodeLlama (e.g., Phind/Phind-CodeLlama-34B-v2)"""
+
+ def match(self, model_path: str):
+ return "phind-codellama-" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("phind")
+
+
+class Llama2ChangAdapter(Llama2Adapter):
+ """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)"""
+
+ def match(self, model_path: str):
+ return "llama2-ko-chang" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("polyglot_changgpt")
+
+
+class ZephyrAdapter(BaseModelAdapter):
+ """The model adapter for Zephyr (e.g. HuggingFaceH4/zephyr-7b-alpha)"""
+
+ def match(self, model_path: str):
+ return "zephyr" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("zephyr")
+
+
+class XwinLMAdapter(BaseModelAdapter):
+ """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)"""
+
+ # use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "xwin-lm" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("vicuna_v1.1")
+
+
+class LemurAdapter(BaseModelAdapter):
+ """The model adapter for OpenLemur/lemur-70b-chat-v1"""
+
+ use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return "lemur-70b-chat" in model_path.lower()
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("lemur-70b-chat")
+
+
+class PygmalionAdapter(BaseModelAdapter):
+ """The model adapter for Pygmalion/Metharme series of models(e.g., PygmalionAI/mythalion-13b)"""
+
+ # use_fast_tokenizer = False
+
+ def match(self, model_path: str):
+ return bool(
+ re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I)
+ )
+
+ def get_default_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("metharme")
+
+
+# Note: the registration order matters.
+# The one registered earlier has a higher matching priority.
+register_model_adapter(PeftModelAdapter)
+register_model_adapter(DeepseekChatAdapter)
+register_model_adapter(VicunaAdapter)
+register_model_adapter(AiroborosAdapter)
+register_model_adapter(LongChatAdapter)
+register_model_adapter(GoogleT5Adapter)
+register_model_adapter(KoalaAdapter)
+register_model_adapter(AlpacaAdapter)
+register_model_adapter(ChatGLMAdapter)
+register_model_adapter(CodeGeexAdapter)
+register_model_adapter(DollyV2Adapter)
+register_model_adapter(OasstPythiaAdapter)
+register_model_adapter(OasstLLaMAAdapter)
+register_model_adapter(OpenChat35Adapter)
+register_model_adapter(StableLMAdapter)
+register_model_adapter(BaizeAdapter)
+register_model_adapter(RwkvAdapter)
+register_model_adapter(OpenBuddyAdapter)
+register_model_adapter(PhoenixAdapter)
+register_model_adapter(BardAdapter)
+register_model_adapter(PaLM2Adapter)
+register_model_adapter(ChatGPTAdapter)
+register_model_adapter(AzureOpenAIAdapter)
+register_model_adapter(ClaudeAdapter)
+register_model_adapter(MPTAdapter)
+register_model_adapter(BiLLaAdapter)
+register_model_adapter(RedPajamaINCITEAdapter)
+register_model_adapter(H2OGPTAdapter)
+register_model_adapter(RobinAdapter)
+register_model_adapter(SnoozyAdapter)
+register_model_adapter(WizardLMAdapter)
+register_model_adapter(ManticoreAdapter)
+register_model_adapter(GuanacoAdapter)
+register_model_adapter(CamelAdapter)
+register_model_adapter(ChangGPTAdapter)
+register_model_adapter(TuluAdapter)
+register_model_adapter(FalconChatAdapter)
+register_model_adapter(FalconAdapter)
+register_model_adapter(TigerBotAdapter)
+register_model_adapter(BaichuanAdapter)
+register_model_adapter(XGenAdapter)
+register_model_adapter(NousHermesAdapter)
+register_model_adapter(PythiaAdapter)
+register_model_adapter(InternLMChatAdapter)
+register_model_adapter(StarChatAdapter)
+register_model_adapter(Llama2Adapter)
+register_model_adapter(CuteGPTAdapter)
+register_model_adapter(OpenOrcaAdapter)
+register_model_adapter(MistralAdapter)
+register_model_adapter(WizardCoderAdapter)
+register_model_adapter(QwenChatAdapter)
+register_model_adapter(AquilaChatAdapter)
+register_model_adapter(BGEAdapter)
+register_model_adapter(E5Adapter)
+register_model_adapter(Llama2ChineseAdapter)
+register_model_adapter(VigogneAdapter)
+register_model_adapter(OpenLLaMaOpenInstructAdapter)
+register_model_adapter(ReaLMAdapter)
+register_model_adapter(PhindCodeLlamaAdapter)
+register_model_adapter(CodeLlamaAdapter)
+register_model_adapter(Llama2ChangAdapter)
+register_model_adapter(ZephyrAdapter)
+register_model_adapter(XwinLMAdapter)
+register_model_adapter(LemurAdapter)
+register_model_adapter(PygmalionAdapter)
+register_model_adapter(Zhinao360Adapter)
+
+# After all adapters, try the default base adapter.
+register_model_adapter(BaseModelAdapter)
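+
+# Matching sketch (illustrative only; assumes the module's usual
+# get_model_adapter(model_path) lookup helper). Because PhindCodeLlamaAdapter is
+# registered before CodeLlamaAdapter, a Phind checkpoint resolves to the more
+# specific adapter even though both match() methods accept the path:
+#
+#   adapter = get_model_adapter("Phind/Phind-CodeLlama-34B-v2")
+#   assert isinstance(adapter, PhindCodeLlamaAdapter)
+#
+# Paths that no specialized adapter matches fall through to BaseModelAdapter.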
diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d4db62bc069a39d92104d6cdf401ccee1a68d0f
--- /dev/null
+++ b/fastchat/model/model_chatglm.py
@@ -0,0 +1,102 @@
+"""
+Inference code for ChatGLM.
+Adapted from https://huggingface.co./THUDM/chatglm-6b/blob/main/modeling_chatglm.py.
+"""
+import re
+
+import torch
+from transformers.generation.logits_process import LogitsProcessor
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+invalid_score_processor = InvalidScoreLogitsProcessor()
+
+
+def process_response(response):
+ response = response.strip()
+ response = response.replace("[[训练时间]]", "2023年")
+ punkts = [
+ [",", ","],
+ ["!", "!"],
+ [":", ":"],
+ [";", ";"],
+ ["\?", "?"],
+ ]
+ for item in punkts:
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+ return response
+
+
+@torch.inference_mode()
+def generate_stream_chatglm(
+ model,
+ tokenizer,
+ params,
+ device,
+ context_len=2048,
+ stream_interval=2,
+ judge_sent_end=False,
+):
+ prompt = params["prompt"]
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+ echo = params.get("echo", True)
+
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+ input_echo_len = len(inputs["input_ids"][0])
+
+ gen_kwargs = {
+ "max_length": max_new_tokens + input_echo_len,
+ "do_sample": True if temperature > 1e-5 else False,
+ "top_p": top_p,
+ "repetition_penalty": repetition_penalty,
+ "logits_processor": [invalid_score_processor],
+ }
+ if temperature > 1e-5:
+ gen_kwargs["temperature"] = temperature
+
+ total_len = 0
+ for total_ids in model.stream_generate(**inputs, **gen_kwargs):
+ total_ids = total_ids.tolist()[0]
+ total_len = len(total_ids)
+ if echo:
+ output_ids = total_ids
+ else:
+ output_ids = total_ids[input_echo_len:]
+ response = tokenizer.decode(output_ids)
+ response = process_response(response)
+
+ yield {
+ "text": response,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": total_len - input_echo_len,
+ "total_tokens": total_len,
+ },
+ "finish_reason": None,
+ }
+
+ # TODO: ChatGLM should stop when it reaches the max length.
+ # Only the last streamed result carries a finish_reason, so set it to "stop".
+ ret = {
+ "text": response,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": total_len - input_echo_len,
+ "total_tokens": total_len,
+ },
+ "finish_reason": "stop",
+ }
+ yield ret
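+
+# Minimal usage sketch (illustrative only; assumes a ChatGLM checkpoint loaded with
+# trust_remote_code=True so that model.stream_generate is available):
+#
+#   from transformers import AutoModel, AutoTokenizer
+#   tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+#   model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
+#   params = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 128}
+#   for chunk in generate_stream_chatglm(model, tokenizer, params, device="cuda"):
+#       text = chunk["text"]  # the response accumulated so far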
diff --git a/fastchat/model/model_codet5p.py b/fastchat/model/model_codet5p.py
new file mode 100644
index 0000000000000000000000000000000000000000..0984513c96931b6d48dfd17f3020fe5cebc3f911
--- /dev/null
+++ b/fastchat/model/model_codet5p.py
@@ -0,0 +1,108 @@
+import gc
+from threading import Thread
+import torch
+import transformers
+from transformers import (
+ GenerationConfig,
+ StoppingCriteria,
+ StoppingCriteriaList,
+ TextIteratorStreamer,
+)
+
+
+@torch.inference_mode()
+def generate_stream_codet5p(
+ model,
+ tokenizer,
+ params,
+ device,
+ context_len=2048,
+ stream_interval=2,
+ judge_sent_end=False,
+):
+ prompt = params["prompt"]
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", 50)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 1024))
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ stop_token_ids.append(tokenizer.eos_token_id)
+
+ decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ streamer = TextIteratorStreamer(tokenizer, **decode_config)
+ encoding = tokenizer(prompt, return_tensors="pt").to(device)
+ input_ids = encoding.input_ids
+ encoding["decoder_input_ids"] = encoding["input_ids"].clone()
+ input_echo_len = input_ids.shape[-1]  # number of prompt tokens (not the batch size)
+
+ generation_config = GenerationConfig(
+ max_new_tokens=max_new_tokens,
+ do_sample=temperature >= 1e-5,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=10,
+ top_p=top_p,
+ top_k=top_k,
+ eos_token_id=stop_token_ids,
+ )
+
+ class CodeBlockStopper(StoppingCriteria):
+ def __call__(
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+ ) -> bool:
+ # Code completion is open-ended generation.
+ # Stop once the output ends with "\n\n", i.e., the end of a code block.
+ if list(input_ids[0][-2:]) == [628, 198]:
+ return True
+ return False
+
+ gen_kwargs = dict(
+ **encoding,
+ streamer=streamer,
+ generation_config=generation_config,
+ stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
+ )
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
+ thread.start()
+ i = 0
+ output = ""
+ for new_text in streamer:
+ i += 1
+ output += new_text
+ if i % stream_interval == 0 or i == max_new_tokens - 1:
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+ if i >= max_new_tokens:
+ break
+
+ if i >= max_new_tokens:
+ finish_reason = "length"
+ else:
+ finish_reason = "stop"
+
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+ thread.join()
+
+ # clean
+ gc.collect()
+ torch.cuda.empty_cache()
+ if device == "xpu":
+ torch.xpu.empty_cache()
+ if device == "npu":
+ torch.npu.empty_cache()
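+
+# Usage sketch (illustrative only; model name and prompt are placeholders):
+#
+#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+#   tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-6b", trust_remote_code=True)
+#   model = AutoModelForSeq2SeqLM.from_pretrained(
+#       "Salesforce/codet5p-6b", trust_remote_code=True
+#   ).to("cuda")
+#   params = {"prompt": "def fibonacci(n):", "temperature": 0.2, "max_new_tokens": 128}
+#   for chunk in generate_stream_codet5p(model, tokenizer, params, device="cuda"):
+#       completion = chunk["text"]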
diff --git a/fastchat/model/model_exllama.py b/fastchat/model/model_exllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..306edab21a79658d22eb75f1da3eba1f830e4ae7
--- /dev/null
+++ b/fastchat/model/model_exllama.py
@@ -0,0 +1,77 @@
+import gc
+import sys
+from typing import Dict
+
+import torch
+
+
+def generate_stream_exllama(
+ model,
+ tokenizer,
+ params: Dict,
+ device: str,
+ context_len: int,
+ stream_interval: int = 2,
+ judge_sent_end: bool = False,
+):
+ try:
+ from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+ except ImportError as e:
+ print(f"Error: Failed to load Exllamav2. {e}")
+ sys.exit(-1)
+
+ prompt = params["prompt"]
+
+ generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
+ settings = ExLlamaV2Sampler.Settings()
+
+ settings.temperature = float(params.get("temperature", 0.85))
+ settings.top_k = int(params.get("top_k", 50))
+ settings.top_p = float(params.get("top_p", 0.8))
+ settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
+ settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])
+
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+
+ generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
+ echo = bool(params.get("echo", True))
+
+ input_ids = generator.tokenizer.encode(prompt)
+ prompt_tokens = input_ids.shape[-1]
+ generator.begin_stream(input_ids, settings)
+
+ generated_tokens = 0
+ if echo:
+ output = prompt
+ else:
+ output = ""
+ while True:
+ chunk, eos, _ = generator.stream()
+ output += chunk
+ generated_tokens += 1
+ if generated_tokens == max_new_tokens:
+ finish_reason = "length"
+ break
+ elif eos:
+ finish_reason = "length"
+ break
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": generated_tokens,
+ "total_tokens": prompt_tokens + generated_tokens,
+ },
+ "finish_reason": None,
+ }
+
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": generated_tokens,
+ "total_tokens": prompt_tokens + generated_tokens,
+ },
+ "finish_reason": finish_reason,
+ }
+ gc.collect()
diff --git a/fastchat/model/model_falcon.py b/fastchat/model/model_falcon.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc8af8efa20bd29fb31cdd0a0bc039b30f4bf26e
--- /dev/null
+++ b/fastchat/model/model_falcon.py
@@ -0,0 +1,140 @@
+import gc
+from threading import Thread
+from typing import Iterable
+
+import torch
+import transformers
+from transformers import TextIteratorStreamer, GenerationConfig
+
+from fastchat.utils import is_partial_stop
+
+
+@torch.inference_mode()
+def generate_stream_falcon(
+ model,
+ tokenizer,
+ params,
+ device,
+ context_len=2048,
+ stream_interval=2,
+ judge_sent_end=False,
+):
+ prompt = params["prompt"]
+ len_prompt = len(prompt)
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", 50)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+ stop_str = params.get("stop", None)
+ echo = bool(params.get("echo", True))
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ stop_token_ids.append(tokenizer.eos_token_id)
+
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+ input_ids = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
+
+ max_src_len = context_len - max_new_tokens - 8
+
+ input_ids = input_ids[:, -max_src_len:]  # truncate from the left along the sequence dim
+ attention_mask = attention_mask[:, -max_src_len:]  # truncate from the left
+ input_echo_len = input_ids.shape[-1]
+
+ decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
+
+ generation_config = GenerationConfig(
+ max_new_tokens=max_new_tokens,
+ do_sample=temperature >= 1e-5,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=10,
+ top_p=top_p,
+ top_k=top_k,
+ eos_token_id=stop_token_ids,
+ )
+
+ generation_kwargs = dict(
+ inputs=input_ids,
+ attention_mask=attention_mask,
+ streamer=streamer,
+ generation_config=generation_config,
+ )
+
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ if echo:
+ # means keep the prompt
+ output = prompt
+ else:
+ output = ""
+
+ for i, new_text in enumerate(streamer):
+ output += new_text
+ if i % stream_interval == 0:
+ if echo:
+ rfind_start = len_prompt
+ else:
+ rfind_start = 0
+
+ partially_stopped = False
+ if stop_str:
+ if isinstance(stop_str, str):
+ pos = output.rfind(stop_str, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ else:
+ partially_stopped = is_partial_stop(output, stop_str)
+ elif isinstance(stop_str, Iterable):
+ for each_stop in stop_str:
+ pos = output.rfind(each_stop, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ break
+ else:
+ partially_stopped = is_partial_stop(output, each_stop)
+ if partially_stopped:
+ break
+ else:
+ raise ValueError("Invalid stop field type.")
+
+ # prevent yielding partial stop sequence
+ if not partially_stopped:
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+ output = output.strip()
+
+ # finish stream event, which contains finish reason
+ if i == max_new_tokens - 1:
+ finish_reason = "length"
+ elif partially_stopped:
+ finish_reason = None
+ else:
+ finish_reason = "stop"
+
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+
+ # clean
+ gc.collect()
+ torch.cuda.empty_cache()
+ if device == "xpu":
+ torch.xpu.empty_cache()
+ if device == "npu":
+ torch.npu.empty_cache()
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..da08c2e26a7f0f885b890589ccd23425f5633227
--- /dev/null
+++ b/fastchat/model/model_registry.py
@@ -0,0 +1,387 @@
+"""Additional information of the models."""
+from collections import namedtuple
+from typing import List
+
+
+ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
+
+
+model_info = {}
+
+
+def register_model_info(
+ full_names: List[str], simple_name: str, link: str, description: str
+):
+ info = ModelInfo(simple_name, link, description)
+
+ for full_name in full_names:
+ model_info[full_name] = info
+
+
+def get_model_info(name: str) -> ModelInfo:
+ if name in model_info:
+ return model_info[name]
+ else:
+ # To fix this, please use `register_model_info` to register your model
+ return ModelInfo(
+ name, "", "Register the description at fastchat/model/model_registry.py"
+ )
+
+
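+# Example (illustrative; "my-model-7b" is a hypothetical name): registering and
+# looking up model metadata.
+#
+#   register_model_info(
+#       ["my-model-7b"], "MyModel", "https://example.com", "a hypothetical demo model"
+#   )
+#   get_model_info("my-model-7b")   # ModelInfo("MyModel", "https://example.com", ...)
+#   get_model_info("unknown-name")  # placeholder ModelInfo asking for registration
+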
+register_model_info(
+ ["gpt-3.5-turbo"],
+ "GPT-3.5",
+ "https://openai.com/blog/chatgpt",
+ "GPT-3.5 by OpenAI",
+)
+register_model_info(
+ ["gpt-3.5-turbo-1106"],
+ "GPT-3.5-Turbo-1106",
+ "https://platform.openai.com/docs/models/gpt-3-5",
+ "GPT-3.5-Turbo-1106 by OpenAI",
+)
+register_model_info(
+ ["gpt-4"], "GPT-4", "https://openai.com/research/gpt-4", "ChatGPT-4 by OpenAI"
+)
+register_model_info(
+ ["gpt-4-turbo"],
+ "GPT-4-Turbo",
+ "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
+ "GPT-4-Turbo by OpenAI",
+)
+register_model_info(
+ ["claude-2"],
+ "Claude",
+ "https://www.anthropic.com/index/claude-2",
+ "Claude 2 by Anthropic",
+)
+register_model_info(
+ ["claude-1"],
+ "Claude",
+ "https://www.anthropic.com/index/introducing-claude",
+ "Claude by Anthropic",
+)
+register_model_info(
+ ["claude-instant-1"],
+ "Claude Instant",
+ "https://www.anthropic.com/index/introducing-claude",
+ "Claude Instant by Anthropic",
+)
+register_model_info(
+ ["palm-2"],
+ "PaLM 2 Chat",
+ "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023",
+ "PaLM 2 for Chat (chat-bison@001) by Google",
+)
+register_model_info(
+ [
+ "vicuna-33b",
+ "vicuna-33b-v1.3",
+ "vicuna-13b",
+ "vicuna-13b-v1.3",
+ "vicuna-7b",
+ "vicuna-7b-v1.3",
+ ],
+ "Vicuna",
+ "https://lmsys.org/blog/2023-03-30-vicuna/",
+ "a chat assistant fine-tuned on user-shared conversations by LMSYS",
+)
+register_model_info(
+ ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"],
+ "Llama 2",
+ "https://ai.meta.com/llama/",
+ "open foundation and fine-tuned chat models by Meta",
+)
+register_model_info(
+ ["mistral-7b-instruct"],
+ "Mistral",
+ "https://huggingface.co./mistralai/Mistral-7B-Instruct-v0.1",
+ "a large language model by Mistral AI team",
+)
+register_model_info(
+ ["zephyr-7b-beta", "zephyr-7b-alpha"],
+ "Zephyr",
+ "https://huggingface.co./HuggingFaceH4/zephyr-7b-alpha",
+ "a chatbot fine-tuned from Mistral by Hugging Face",
+)
+register_model_info(
+ ["qwen-14b-chat"],
+ "Qwen",
+ "https://huggingface.co./Qwen/Qwen-14B-Chat",
+ "a large language model by Alibaba Cloud",
+)
+register_model_info(
+ ["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"],
+ "Code Llama",
+ "https://ai.meta.com/blog/code-llama-large-language-model-coding/",
+ "open foundation models for code by Meta",
+)
+register_model_info(
+ ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"],
+ "WizardLM",
+ "https://github.com/nlpxucan/WizardLM",
+ "an instruction-following LLM using evol-instruct by Microsoft",
+)
+register_model_info(
+ ["wizardcoder-15b-v1.0"],
+ "WizardLM",
+ "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder",
+ "Empowering Code Large Language Models with Evol-Instruct",
+)
+register_model_info(
+ ["mpt-7b-chat", "mpt-30b-chat"],
+ "MPT-Chat",
+ "https://www.mosaicml.com/blog/mpt-30b",
+ "a chatbot fine-tuned from MPT by MosaicML",
+)
+register_model_info(
+ ["guanaco-33b", "guanaco-65b"],
+ "Guanaco",
+ "https://github.com/artidoro/qlora",
+ "a model fine-tuned with QLoRA by UW",
+)
+register_model_info(
+ ["gpt4all-13b-snoozy"],
+ "GPT4All-Snoozy",
+ "https://github.com/nomic-ai/gpt4all",
+ "a finetuned LLaMA model on assistant style data by Nomic AI",
+)
+register_model_info(
+ ["koala-13b"],
+ "Koala",
+ "https://bair.berkeley.edu/blog/2023/04/03/koala",
+ "a dialogue model for academic research by BAIR",
+)
+register_model_info(
+ ["RWKV-4-Raven-14B"],
+ "RWKV-4-Raven",
+ "https://huggingface.co./BlinkDL/rwkv-4-raven",
+ "an RNN with transformer-level LLM performance",
+)
+register_model_info(
+ ["chatglm-6b", "chatglm2-6b"],
+ "ChatGLM",
+ "https://chatglm.cn/blog",
+ "an open bilingual dialogue language model by Tsinghua University",
+)
+register_model_info(
+ ["alpaca-13b"],
+ "Alpaca",
+ "https://crfm.stanford.edu/2023/03/13/alpaca.html",
+ "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford",
+)
+register_model_info(
+ ["oasst-pythia-12b"],
+ "OpenAssistant (oasst)",
+ "https://open-assistant.io",
+ "an Open Assistant for everyone by LAION",
+)
+register_model_info(
+ ["oasst-sft-7-llama-30b"],
+ "OpenAssistant (oasst)",
+ "https://open-assistant.io",
+ "an Open Assistant for everyone by LAION",
+)
+register_model_info(
+ ["openchat-3.5"],
+ "OpenChat 3.5",
+ "https://github.com/imoneoi/openchat",
+ "OpenChat 3.5 is a versatile, open-source language model fine-tuned using C-RLFT",
+)
+register_model_info(
+ ["llama-7b", "llama-13b"],
+ "LLaMA",
+ "https://arxiv.org/abs/2302.13971",
+ "open and efficient foundation language models by Meta",
+)
+register_model_info(
+ ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"],
+ "Open LLaMa (Open Instruct)",
+ "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc",
+ "Open LLaMa fine-tuned on instruction-following data by VMware",
+)
+register_model_info(
+ ["dolly-v2-12b"],
+ "Dolly",
+ "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm",
+ "an instruction-tuned open large language model by Databricks",
+)
+register_model_info(
+ ["stablelm-tuned-alpha-7b"],
+ "StableLM",
+ "https://github.com/stability-AI/stableLM",
+ "Stability AI language models",
+)
+register_model_info(
+ ["codet5p-6b"],
+ "CodeT5p-6b",
+ "https://huggingface.co./Salesforce/codet5p-6b",
+ "Code completion model released by Salesforce",
+)
+register_model_info(
+ ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"],
+ "FastChat-T5",
+ "https://huggingface.co./lmsys/fastchat-t5-3b-v1.0",
+ "a chat assistant fine-tuned from FLAN-T5 by LMSYS",
+)
+register_model_info(
+ ["phoenix-inst-chat-7b"],
+ "Phoenix-7B",
+ "https://huggingface.co./FreedomIntelligence/phoenix-inst-chat-7b",
+ "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)",
+)
+register_model_info(
+ ["realm-7b-v1"],
+ "ReaLM",
+ "https://github.com/FreedomIntelligence/ReaLM",
+ "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.",
+)
+register_model_info(
+ ["billa-7b-sft"],
+ "BiLLa-7B-SFT",
+ "https://huggingface.co./Neutralzz/BiLLa-7B-SFT",
+ "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher",
+)
+register_model_info(
+ ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"],
+ "h2oGPT-GM-7b",
+ "https://huggingface.co./h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
+ "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai",
+)
+register_model_info(
+ ["baize-v2-7b", "baize-v2-13b"],
+ "Baize v2",
+ "https://github.com/project-baize/baize-chatbot#v2",
+ "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.",
+)
+register_model_info(
+ [
+ "airoboros-l2-7b-2.1",
+ "airoboros-l2-13b-2.1",
+ "airoboros-c34b-2.1",
+ "airoboros-l2-70b-2.1",
+ ],
+ "airoboros",
+ "https://huggingface.co./jondurbin/airoboros-l2-70b-2.1",
+ "an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4",
+)
+register_model_info(
+ [
+ "spicyboros-7b-2.2",
+ "spicyboros-13b-2.2",
+ "spicyboros-70b-2.2",
+ ],
+ "spicyboros",
+ "https://huggingface.co./jondurbin/spicyboros-70b-2.2",
+ "de-aligned versions of the airoboros models",
+)
+register_model_info(
+ ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"],
+ "Robin-v2",
+ "https://huggingface.co./OptimalScale/robin-7b-v2-delta",
+ "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.",
+)
+register_model_info(
+ ["manticore-13b-chat"],
+ "Manticore 13B Chat",
+ "https://huggingface.co./openaccess-ai-collective/manticore-13b-chat-pyg",
+ "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.",
+)
+register_model_info(
+ ["redpajama-incite-7b-chat"],
+ "RedPajama-INCITE-7B-Chat",
+ "https://huggingface.co./togethercomputer/RedPajama-INCITE-7B-Chat",
+ "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together",
+)
+register_model_info(
+ [
+ "falcon-7b",
+ "falcon-7b-instruct",
+ "falcon-40b",
+ "falcon-40b-instruct",
+ "falcon-180b",
+ "falcon-180b-chat",
+ ],
+ "Falcon",
+ "https://huggingface.co./tiiuae/falcon-180B",
+ "TII's flagship series of large language models",
+)
+register_model_info(
+ ["tigerbot-7b-sft"],
+ "Tigerbot",
+ "https://huggingface.co./TigerResearch/tigerbot-7b-sft",
+ "TigerBot is a large-scale language model (LLM) with multiple languages and tasks.",
+)
+register_model_info(
+ ["internlm-chat-7b", "internlm-chat-7b-8k"],
+ "InternLM",
+ "https://huggingface.co./internlm/internlm-chat-7b",
+ "InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.",
+)
+register_model_info(
+ ["Qwen-7B-Chat"],
+ "Qwen",
+ "https://huggingface.co./Qwen/Qwen-7B-Chat",
+ "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.",
+)
+register_model_info(
+ ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"],
+ "Llama2-Chinese",
+ "https://huggingface.co./FlagAlpha/Llama2-Chinese-13b-Chat",
+ "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.",
+)
+register_model_info(
+ ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"],
+ "Vigogne-Instruct",
+ "https://huggingface.co./bofenghuang/vigogne-2-7b-instruct",
+ "Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang",
+)
+register_model_info(
+ ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"],
+ "Vigogne-Chat",
+ "https://huggingface.co./bofenghuang/vigogne-2-7b-chat",
+ "Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang",
+)
+register_model_info(
+ ["deluxe-chat-v1", "deluxe-chat-v1.1"],
+ "DeluxeChat",
+ "",
+ "Deluxe Chat",
+)
+register_model_info(
+ [
+ "Xwin-LM-7B-V0.1",
+ "Xwin-LM-13B-V0.1",
+ "Xwin-LM-70B-V0.1",
+ "Xwin-LM-7B-V0.2",
+ "Xwin-LM-13B-V0.2",
+ ],
+ "Xwin-LM",
+ "https://github.com/Xwin-LM/Xwin-LM",
+ "Chat models developed by Xwin-LM team",
+)
+
+register_model_info(
+ ["lemur-70b-chat"],
+ "Lemur-Chat",
+ "https://huggingface.co./OpenLemur/lemur-70b-chat-v1",
+ "an openly accessible language model optimized for both natural language and coding capabilities ",
+)
+
+register_model_info(
+ ["Mistral-7B-OpenOrca"],
+ "Open-Orca",
+ "https://huggingface.co./Open-Orca/Mistral-7B-OpenOrca",
+ "A fine-tune of [Mistral 7B](https://huggingface.co./mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co./datasets/Open-Orca/OpenOrca)",
+)
+
+register_model_info(
+ [
+ "AquilaChat-7B",
+ "AquilaChat2-7B",
+ "AquilaChat2-34B",
+ ],
+ "Aquila-Chat",
+ "https://huggingface.co./BAAI/AquilaChat2-34B",
+ "Chat models developed by BAAI team",
+)
diff --git a/fastchat/model/model_xfastertransformer.py b/fastchat/model/model_xfastertransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..54890b1ca4977f4243cca46cb7c78114a3b2e5d6
--- /dev/null
+++ b/fastchat/model/model_xfastertransformer.py
@@ -0,0 +1,81 @@
+import gc
+from threading import Thread
+
+import torch
+from transformers import TextIteratorStreamer
+
+
+@torch.inference_mode()
+def generate_stream_xft(
+ model,
+ tokenizer,
+ params,
+ device,
+ context_len=8192,
+ stream_interval=2,
+ judge_sent_end=False,
+):
+ prompt = params["prompt"]
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+
+ # Unused for now; kept as a placeholder for future use.
+ # temperature = float(params.get("temperature", 1.0))
+ # top_p = float(params.get("top_p", 1.0))
+
+ max_new_tokens = int(params.get("max_new_tokens", 4096))
+ echo = params.get("echo", True)
+
+ inputs = tokenizer(
+ prompt, return_tensors="pt", padding=model.config.padding
+ ).input_ids
+ input_echo_len = len(inputs[0])
+ max_len = max_new_tokens + input_echo_len
+
+ decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
+ generation_kwargs = {
+ "input_ids": inputs,
+ "streamer": streamer,
+ "max_length": max_len,
+ "num_beams": model.config.beam_width,
+ "length_penalty": repetition_penalty,
+ "num_return_sequences": model.config.num_return_sequences,
+ "early_stopping": model.config.early_stopping,
+ "eos_token_id": model.config.eos_token_id,
+ "pad_token_id": model.config.pad_token_id,
+ }
+
+ thread = Thread(target=model.model.generate, kwargs=generation_kwargs)
+ thread.start()
+ if echo:
+ # means keep the prompt
+ output = prompt
+ else:
+ output = ""
+ i = 0
+ for i, new_text in enumerate(streamer):
+ output += new_text
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+ output = output.strip()
+ if i == max_new_tokens - 1:
+ finish_reason = "length"
+ else:
+ finish_reason = "stop"
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+ gc.collect()
diff --git a/fastchat/model/monkey_patch_non_inplace.py b/fastchat/model/monkey_patch_non_inplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..413dd3b30500c788abb19e5742447237ba2b1738
--- /dev/null
+++ b/fastchat/model/monkey_patch_non_inplace.py
@@ -0,0 +1,119 @@
+"""
+Monkey patch the llama implementation in the huggingface/transformers library.
+Avoid bugs in mps backend by not using in-place operations.
+"""
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+import transformers
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2].clone()
+ x2 = x[..., x.shape[-1] // 2 :].clone()
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+ self.head_dim
+ )
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ query_states.dtype
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def replace_llama_attn_with_non_inplace_operations():
+ """Avoid bugs in mps backend by not using in-place operations."""
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
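+
+# Usage sketch (illustrative only; the checkpoint path is a placeholder): apply the
+# patch before building a LLaMA model so the patched attention forward is used,
+# e.g. when running on Apple's "mps" backend:
+#
+#   replace_llama_attn_with_non_inplace_operations()
+#   model = transformers.LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")
+#   model = model.to("mps")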
diff --git a/fastchat/model/rwkv_model.py b/fastchat/model/rwkv_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdbc14584bfd1ec90e8478b4e55f07e8ec89a967
--- /dev/null
+++ b/fastchat/model/rwkv_model.py
@@ -0,0 +1,76 @@
+import os
+from types import SimpleNamespace
+import warnings
+
+import torch
+
+os.environ["RWKV_JIT_ON"] = "1"
+os.environ["RWKV_CUDA_ON"] = "1"
+
+from rwkv.model import RWKV
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+
+class RwkvModel:
+ def __init__(self, model_path):
+ warnings.warn(
+ "Experimental support. Please use ChatRWKV if you want to chat with RWKV"
+ )
+ self.config = SimpleNamespace(is_encoder_decoder=False)
+ self.model = RWKV(model=model_path, strategy="cuda fp16")
+ # two GPUs
+ # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16")
+
+ self.tokenizer = None
+ self.model_path = model_path
+
+ def to(self, target):
+ assert target == "cuda"
+
+ def __call__(self, input_ids, use_cache, past_key_values=None):
+ assert use_cache == True
+ input_ids = input_ids[0].detach().cpu().numpy()
+ # print(input_ids)
+ logits, state = self.model.forward(input_ids, past_key_values)
+ # print(logits)
+ logits = logits.unsqueeze(0).unsqueeze(0)
+ out = SimpleNamespace(logits=logits, past_key_values=state)
+ return out
+
+ def generate(
+ self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
+ ):
+ # This function is used by fastchat.llm_judge.
+ # Because RWKV does not support the huggingface generation API,
+ # we reuse fastchat.serve.inference.generate_stream as a workaround.
+ from transformers import AutoTokenizer
+
+ from fastchat.serve.inference import generate_stream
+ from fastchat.conversation import get_conv_template
+
+ if self.tokenizer is None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ "EleutherAI/pythia-160m", use_fast=True
+ )
+ prompt = self.tokenizer.decode(input_ids[0].tolist())
+ conv = get_conv_template("rwkv")
+
+ gen_params = {
+ "model": self.model_path,
+ "prompt": prompt,
+ "temperature": temperature,
+ "repetition_penalty": repetition_penalty,
+ "max_new_tokens": max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+ res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda")
+
+ for res in res_iter:
+ pass
+
+ output = res["text"]
+ output_ids = self.tokenizer.encode(output)
+
+ return [input_ids[0].tolist() + output_ids]
diff --git a/fastchat/model/upload_hub.py b/fastchat/model/upload_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1519652e6d90479d60054008d8d7e371b16356e
--- /dev/null
+++ b/fastchat/model/upload_hub.py
@@ -0,0 +1,45 @@
+"""
+Upload weights to huggingface.
+
+Usage:
+python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3
+"""
+import argparse
+import tempfile
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def upload_hub(model_path, hub_repo_id, component, private):
+ if component == "all":
+ components = ["model", "tokenizer"]
+ else:
+ components = [component]
+
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private}
+
+ if "model" in components:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
+ with tempfile.TemporaryDirectory() as tmp_path:
+ model.save_pretrained(tmp_path, **kwargs)
+
+ if "tokenizer" in components:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ with tempfile.TemporaryDirectory() as tmp_path:
+ tokenizer.save_pretrained(tmp_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, required=True)
+ parser.add_argument(
+ "--component", type=str, choices=["all", "model", "tokenizer"], default="all"
+ )
+ parser.add_argument("--private", action="store_true")
+ args = parser.parse_args()
+
+ upload_hub(args.model_path, args.hub_repo_id, args.component, args.private)
diff --git a/fastchat/modules/__init__.py b/fastchat/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fastchat/modules/awq.py b/fastchat/modules/awq.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f27be85c09e2394bd821cc1ce236f46c429d4bc
--- /dev/null
+++ b/fastchat/modules/awq.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+import sys
+
+import torch
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils
+
+
+@dataclass
+class AWQConfig:
+ ckpt: str = field(
+ default=None,
+ metadata={
+ "help": "Load quantized model. The path to the local AWQ checkpoint."
+ },
+ )
+ wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
+ groupsize: int = field(
+ default=-1,
+ metadata={"help": "Groupsize to use for quantization; default uses full row."},
+ )
+
+
+def load_awq_quantized(model_name, awq_config: AWQConfig, device):
+ print("Loading AWQ quantized model...")
+
+ try:
+ from tinychat.utils import load_quant
+ from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+ except ImportError as e:
+ print(f"Error: Failed to import tinychat. {e}")
+ print("Please double check if you have successfully installed AWQ")
+ print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md")
+ sys.exit(-1)
+
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_name, use_fast=False, trust_remote_code=True
+ )
+
+ # Skip the default weight initialization; the real weights come from the AWQ checkpoint.
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.kaiming_normal_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+ modeling_utils._init_weights = False
+
+ torch.set_default_dtype(torch.half)
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+
+ if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]):
+ model = load_quant.load_awq_llama_fast(
+ model,
+ find_awq_ckpt(awq_config),
+ awq_config.wbits,
+ awq_config.groupsize,
+ device,
+ )
+ make_quant_attn(model, device)
+ make_quant_norm(model)
+ make_fused_mlp(model)
+ else:
+ model = load_quant.load_awq_model(
+ model,
+ find_awq_ckpt(awq_config),
+ awq_config.wbits,
+ awq_config.groupsize,
+ device,
+ )
+ return model, tokenizer
+
+
+def find_awq_ckpt(awq_config: AWQConfig):
+ if Path(awq_config.ckpt).is_file():
+ return awq_config.ckpt
+
+ for ext in ["*.pt", "*.safetensors"]:
+ matched_result = sorted(Path(awq_config.ckpt).glob(ext))
+ if len(matched_result) > 0:
+ return str(matched_result[-1])
+
+ print("Error: AWQ checkpoint not found")
+ sys.exit(1)
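+
+# Usage sketch (illustrative only; paths and device are placeholders): loading a
+# 4-bit AWQ checkpoint for a LLaMA/Vicuna-style model.
+#
+#   awq_config = AWQConfig(ckpt="/path/to/awq-quantized-ckpt", wbits=4, groupsize=128)
+#   model, tokenizer = load_awq_quantized("lmsys/vicuna-7b-v1.5", awq_config, "cuda:0")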
diff --git a/fastchat/modules/exllama.py b/fastchat/modules/exllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bddaa91d84c63b9b6478bc79e340191d4e0cc3a
--- /dev/null
+++ b/fastchat/modules/exllama.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass, field
+import sys
+
+
+@dataclass
+class ExllamaConfig:
+ max_seq_len: int
+ gpu_split: str = None
+
+
+class ExllamaModel:
+ def __init__(self, exllama_model, exllama_cache):
+ self.model = exllama_model
+ self.cache = exllama_cache
+ self.config = self.model.config
+
+
+def load_exllama_model(model_path, exllama_config: ExllamaConfig):
+ try:
+ from exllamav2 import (
+ ExLlamaV2Config,
+ ExLlamaV2Tokenizer,
+ ExLlamaV2,
+ ExLlamaV2Cache,
+ )
+ except ImportError as e:
+ print(f"Error: Failed to load Exllamav2. {e}")
+ sys.exit(-1)
+
+ exllamav2_config = ExLlamaV2Config()
+ exllamav2_config.model_dir = model_path
+ exllamav2_config.prepare()
+ exllamav2_config.max_seq_len = exllama_config.max_seq_len
+
+ exllama_model = ExLlamaV2(exllamav2_config)
+ tokenizer = ExLlamaV2Tokenizer(exllamav2_config)
+
+ split = None
+ if exllama_config.gpu_split:
+ split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
+ exllama_model.load(split)
+
+ exllama_cache = ExLlamaV2Cache(exllama_model)
+ model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)
+
+ return model, tokenizer
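+
+# Usage sketch (illustrative only; the model directory is a placeholder): loading an
+# ExLlamaV2 model, optionally split across two GPUs.
+#
+#   exllama_config = ExllamaConfig(max_seq_len=4096, gpu_split="17,24")
+#   model, tokenizer = load_exllama_model("/path/to/exl2-model-dir", exllama_config)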
diff --git a/fastchat/modules/gptq.py b/fastchat/modules/gptq.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0a220c0cfb227271fbb4d1e7c4eca636b10d1c
--- /dev/null
+++ b/fastchat/modules/gptq.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass, field
+import os
+from os.path import isdir, isfile
+from pathlib import Path
+import sys
+
+from transformers import AutoTokenizer
+
+
+@dataclass
+class GptqConfig:
+ ckpt: str = field(
+ default=None,
+ metadata={
+ "help": "Load quantized model. The path to the local GPTQ checkpoint."
+ },
+ )
+ wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
+ groupsize: int = field(
+ default=-1,
+ metadata={"help": "Groupsize to use for quantization; default uses full row."},
+ )
+ act_order: bool = field(
+ default=True,
+ metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
+ )
+
+
+def load_gptq_quantized(model_name, gptq_config: GptqConfig):
+ print("Loading GPTQ quantized model...")
+
+ try:
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+ module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
+
+ sys.path.insert(0, module_path)
+ from llama import load_quant
+ except ImportError as e:
+ print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
+ print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
+ sys.exit(-1)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+ # only `fastest-inference-4bit` branch cares about `act_order`
+ if gptq_config.act_order:
+ model = load_quant(
+ model_name,
+ find_gptq_ckpt(gptq_config),
+ gptq_config.wbits,
+ gptq_config.groupsize,
+ act_order=gptq_config.act_order,
+ )
+ else:
+ # other branches
+ model = load_quant(
+ model_name,
+ find_gptq_ckpt(gptq_config),
+ gptq_config.wbits,
+ gptq_config.groupsize,
+ )
+
+ return model, tokenizer
+
+
+def find_gptq_ckpt(gptq_config: GptqConfig):
+ if Path(gptq_config.ckpt).is_file():
+ return gptq_config.ckpt
+
+ for ext in ["*.pt", "*.safetensors"]:
+ matched_result = sorted(Path(gptq_config.ckpt).glob(ext))
+ if len(matched_result) > 0:
+ return str(matched_result[-1])
+
+ print("Error: gptq checkpoint not found")
+ sys.exit(1)
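+
+# Usage sketch (illustrative only; paths are placeholders): loading a 4-bit GPTQ
+# checkpoint with the GPTQ-for-LLaMa backend.
+#
+#   gptq_config = GptqConfig(ckpt="/path/to/gptq-ckpt-dir", wbits=4, groupsize=128)
+#   model, tokenizer = load_gptq_quantized("/path/to/original-model", gptq_config)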
diff --git a/fastchat/modules/xfastertransformer.py b/fastchat/modules/xfastertransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b49bea4cd5c9afd723318daaa5c10dcb309b776
--- /dev/null
+++ b/fastchat/modules/xfastertransformer.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass
+import sys
+
+
+@dataclass
+class XftConfig:
+ max_seq_len: int = 4096
+ beam_width: int = 1
+ eos_token_id: int = -1
+ pad_token_id: int = -1
+ num_return_sequences: int = 1
+ is_encoder_decoder: bool = False
+ padding: bool = True
+ early_stopping: bool = False
+ data_type: str = "bf16_fp16"
+
+
+class XftModel:
+ def __init__(self, xft_model, xft_config):
+ self.model = xft_model
+ self.config = xft_config
+
+
+def load_xft_model(model_path, xft_config: XftConfig):
+ try:
+ import xfastertransformer
+ from transformers import AutoTokenizer
+ except ImportError as e:
+ print(f"Error: Failed to load xFasterTransformer. {e}")
+ sys.exit(-1)
+
+ if xft_config.data_type is None or xft_config.data_type == "":
+ data_type = "bf16_fp16"
+ else:
+ data_type = xft_config.data_type
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=False, padding_side="left", trust_remote_code=True
+ )
+ xft_model = xfastertransformer.AutoModel.from_pretrained(
+ model_path, dtype=data_type
+ )
+ model = XftModel(xft_model=xft_model, xft_config=xft_config)
+ if model.model.rank > 0:
+ while True:
+ model.model.generate()
+ return model, tokenizer
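+
+# Usage sketch (illustrative only; the path is a placeholder): loading a model that
+# has been converted to the xFasterTransformer format.
+#
+#   xft_config = XftConfig(max_seq_len=4096, data_type="bf16_fp16")
+#   model, tokenizer = load_xft_model("/path/to/xft-converted-model", xft_config)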
diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dc99449dca6551f8eb6c51dfe86eca28ea6a6be
--- /dev/null
+++ b/fastchat/protocol/api_protocol.py
@@ -0,0 +1,172 @@
+from typing import Literal, Optional, List, Dict, Any, Union
+
+import time
+
+import shortuuid
+from pydantic import BaseModel, Field
+
+
+class ErrorResponse(BaseModel):
+ object: str = "error"
+ message: str
+ code: int
+
+
+class ModelPermission(BaseModel):
+ id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
+ object: str = "model_permission"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ allow_create_engine: bool = False
+ allow_sampling: bool = True
+ allow_logprobs: bool = True
+ allow_search_indices: bool = True
+ allow_view: bool = True
+ allow_fine_tuning: bool = False
+ organization: str = "*"
+ group: Optional[str] = None
+ is_blocking: bool = False
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: str = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: str = "fastchat"
+ root: Optional[str] = None
+ parent: Optional[str] = None
+ permission: List[ModelPermission] = []
+
+
+class ModelList(BaseModel):
+ object: str = "list"
+ data: List[ModelCard] = []
+
+
+class UsageInfo(BaseModel):
+ prompt_tokens: int = 0
+ total_tokens: int = 0
+ completion_tokens: Optional[int] = 0
+
+
+class APIChatCompletionRequest(BaseModel):
+ model: str
+ messages: Union[str, List[Dict[str, str]]]
+ temperature: Optional[float] = 0.7
+ top_p: Optional[float] = 1.0
+ top_k: Optional[int] = -1
+ n: Optional[int] = 1
+ max_tokens: Optional[int] = None
+ stop: Optional[Union[str, List[str]]] = None
+ stream: Optional[bool] = False
+ user: Optional[str] = None
+ repetition_penalty: Optional[float] = 1.0
+ frequency_penalty: Optional[float] = 0.0
+ presence_penalty: Optional[float] = 0.0
+
+
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatMessage
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
+ object: str = "chat.completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseChoice]
+ usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+ role: Optional[str] = None
+ content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
+ object: str = "chat.completion.chunk"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseStreamChoice]
+
+
+class APITokenCheckRequestItem(BaseModel):
+ model: str
+ prompt: str
+ max_tokens: int
+
+
+class APITokenCheckRequest(BaseModel):
+ prompts: List[APITokenCheckRequestItem]
+
+
+class APITokenCheckResponseItem(BaseModel):
+ fits: bool
+ tokenCount: int
+ contextLength: int
+
+
+class APITokenCheckResponse(BaseModel):
+ prompts: List[APITokenCheckResponseItem]
+
+
+class CompletionRequest(BaseModel):
+ model: str
+ prompt: Union[str, List[Any]]
+ suffix: Optional[str] = None
+ temperature: Optional[float] = 0.7
+ n: Optional[int] = 1
+ max_tokens: Optional[int] = 16
+ stop: Optional[Union[str, List[str]]] = None
+ stream: Optional[bool] = False
+ top_p: Optional[float] = 1.0
+ top_k: Optional[int] = -1
+ logprobs: Optional[int] = None
+ echo: Optional[bool] = False
+ presence_penalty: Optional[float] = 0.0
+ frequency_penalty: Optional[float] = 0.0
+ user: Optional[str] = None
+
+
+class CompletionResponseChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[int] = None
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseChoice]
+ usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[float] = None
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionStreamResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseStreamChoice]
diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a0063393ddd197ad2010fe2993f178aea8d3eda
--- /dev/null
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -0,0 +1,195 @@
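+"""Pydantic request/response schemas for the OpenAI-compatible API server."""
+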
+from typing import Literal, Optional, List, Dict, Any, Union
+
+import time
+
+import shortuuid
+from pydantic import BaseModel, Field
+
+
+class ErrorResponse(BaseModel):
+ object: str = "error"
+ message: str
+ code: int
+
+
+class ModelPermission(BaseModel):
+ id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
+ object: str = "model_permission"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ allow_create_engine: bool = False
+ allow_sampling: bool = True
+ allow_logprobs: bool = True
+ allow_search_indices: bool = True
+ allow_view: bool = True
+ allow_fine_tuning: bool = False
+ organization: str = "*"
+ group: Optional[str] = None
+ is_blocking: bool = False
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: str = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: str = "fastchat"
+ root: Optional[str] = None
+ parent: Optional[str] = None
+ permission: List[ModelPermission] = []
+
+
+class ModelList(BaseModel):
+ object: str = "list"
+ data: List[ModelCard] = []
+
+
+class UsageInfo(BaseModel):
+ prompt_tokens: int = 0
+ total_tokens: int = 0
+ completion_tokens: Optional[int] = 0
+
+
+class LogProbs(BaseModel):
+ text_offset: List[int] = Field(default_factory=list)
+ token_logprobs: List[Optional[float]] = Field(default_factory=list)
+ tokens: List[str] = Field(default_factory=list)
+ top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class ChatCompletionRequest(BaseModel):
+ model: str
+ messages: Union[str, List[Dict[str, str]]]
+ temperature: Optional[float] = 0.7
+ top_p: Optional[float] = 1.0
+ top_k: Optional[int] = -1
+ n: Optional[int] = 1
+ max_tokens: Optional[int] = None
+ stop: Optional[Union[str, List[str]]] = None
+ stream: Optional[bool] = False
+ presence_penalty: Optional[float] = 0.0
+ frequency_penalty: Optional[float] = 0.0
+ user: Optional[str] = None
+
+
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatMessage
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
+ object: str = "chat.completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseChoice]
+ usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+ role: Optional[str] = None
+ content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
+ object: str = "chat.completion.chunk"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseStreamChoice]
+
+
+class TokenCheckRequestItem(BaseModel):
+ model: str
+ prompt: str
+ max_tokens: int
+
+
+class TokenCheckRequest(BaseModel):
+ prompts: List[TokenCheckRequestItem]
+
+
+class TokenCheckResponseItem(BaseModel):
+ fits: bool
+ tokenCount: int
+ contextLength: int
+
+
+class TokenCheckResponse(BaseModel):
+ prompts: List[TokenCheckResponseItem]
+
+
+class EmbeddingsRequest(BaseModel):
+ model: Optional[str] = None
+ engine: Optional[str] = None
+ input: Union[str, List[Any]]
+ user: Optional[str] = None
+ encoding_format: Optional[str] = None
+
+
+class EmbeddingsResponse(BaseModel):
+ object: str = "list"
+ data: List[Dict[str, Any]]
+ model: str
+ usage: UsageInfo
+
+
+class CompletionRequest(BaseModel):
+ model: str
+ prompt: Union[str, List[Any]]
+ suffix: Optional[str] = None
+ temperature: Optional[float] = 0.7
+ n: Optional[int] = 1
+ max_tokens: Optional[int] = 16
+ stop: Optional[Union[str, List[str]]] = None
+ stream: Optional[bool] = False
+ top_p: Optional[float] = 1.0
+ top_k: Optional[int] = -1
+ logprobs: Optional[int] = None
+ echo: Optional[bool] = False
+ presence_penalty: Optional[float] = 0.0
+ frequency_penalty: Optional[float] = 0.0
+ user: Optional[str] = None
+ use_beam_search: Optional[bool] = False
+ best_of: Optional[int] = None
+
+
+class CompletionResponseChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[LogProbs] = None
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseChoice]
+ usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[LogProbs] = None
+ finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionStreamResponse(BaseModel):
+ id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseStreamChoice]
diff --git a/fastchat/serve/README.md b/fastchat/serve/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f01612b9c5e2a28f216e4d45f668d834b75c06f8
--- /dev/null
+++ b/fastchat/serve/README.md
@@ -0,0 +1,6 @@
+---
+title: demo_test
+app_file: gradio_web_server.py
+sdk: gradio
+sdk_version: 3.45.0
+---
diff --git a/fastchat/serve/__init__.py b/fastchat/serve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fastchat/serve/__pycache__/__init__.cpython-310.pyc b/fastchat/serve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d264ee86b93bee08578afcdc7208f94c2646f31
Binary files /dev/null and b/fastchat/serve/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/__init__.cpython-311.pyc b/fastchat/serve/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f388c132ae6e0ecca6b4fc940f10f593a6d0cfc
Binary files /dev/null and b/fastchat/serve/__pycache__/__init__.cpython-311.pyc differ
diff --git a/fastchat/serve/__pycache__/api_provider.cpython-310.pyc b/fastchat/serve/__pycache__/api_provider.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da748235a77b7a4b536f15110469a198967fc952
Binary files /dev/null and b/fastchat/serve/__pycache__/api_provider.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/base_model_worker.cpython-310.pyc b/fastchat/serve/__pycache__/base_model_worker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa6afe37df86561fd3d0f4a5bc472310916224bb
Binary files /dev/null and b/fastchat/serve/__pycache__/base_model_worker.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/cli.cpython-310.pyc b/fastchat/serve/__pycache__/cli.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19a460d5116b39b75cecf0909349e00af8c83c2f
Binary files /dev/null and b/fastchat/serve/__pycache__/cli.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/cli.cpython-311.pyc b/fastchat/serve/__pycache__/cli.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9bfe97259e32b96368cb05fe6928a420b3d963
Binary files /dev/null and b/fastchat/serve/__pycache__/cli.cpython-311.pyc differ
diff --git a/fastchat/serve/__pycache__/controller.cpython-310.pyc b/fastchat/serve/__pycache__/controller.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d476553e5f9bc9c7aa60ed6af5beb0111bdc44d
Binary files /dev/null and b/fastchat/serve/__pycache__/controller.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/gradio_web_server.cpython-310.pyc b/fastchat/serve/__pycache__/gradio_web_server.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7aa2fcdb21b9fd4940a345fc25d038d7e7e2b898
Binary files /dev/null and b/fastchat/serve/__pycache__/gradio_web_server.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/inference.cpython-310.pyc b/fastchat/serve/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1d49d8c4d9fdf6f6e05f8ffc97fb53286cfff4e
Binary files /dev/null and b/fastchat/serve/__pycache__/inference.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/model_worker.cpython-310.pyc b/fastchat/serve/__pycache__/model_worker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..373080e38561629f28c7fe09b6b3b04b633073bd
Binary files /dev/null and b/fastchat/serve/__pycache__/model_worker.cpython-310.pyc differ
diff --git a/fastchat/serve/__pycache__/test_message.cpython-310.pyc b/fastchat/serve/__pycache__/test_message.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..981fe6f340d6ec7f63f1c9c5e0d68e83793b3073
Binary files /dev/null and b/fastchat/serve/__pycache__/test_message.cpython-310.pyc differ
diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbb8a69032a5ac3378a8de34774a2dcdf304223
--- /dev/null
+++ b/fastchat/serve/api_provider.py
@@ -0,0 +1,130 @@
+"""Call API providers."""
+
+import os
+import random
+import time
+
+from fastchat.utils import build_logger
+from fastchat.constants import WORKER_API_TIMEOUT
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+
+def openai_api_stream_iter(
+ model_name,
+ messages,
+ temperature,
+ top_p,
+ max_new_tokens,
+ api_base=None,
+ api_key=None,
+):
+ import openai
+
+ openai.api_base = api_base or "https://api.openai.com/v1"
+ openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
+ if model_name == "gpt-4-turbo":
+ model_name = "gpt-4-1106-preview"
+
+ # Make requests
+ gen_params = {
+ "model": model_name,
+ "prompt": messages,
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ }
+ logger.info(f"==== request ====\n{gen_params}")
+
+ res = openai.ChatCompletion.create(
+ model=model_name,
+ messages=messages,
+ temperature=temperature,
+ max_tokens=max_new_tokens,
+ stream=True,
+ )
+ text = ""
+ for chunk in res:
+ text += chunk["choices"][0]["delta"].get("content", "")
+ data = {
+ "text": text,
+ "error_code": 0,
+ }
+ yield data
+
+
+def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
+ import anthropic
+
+ c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
+
+ # Make requests
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ }
+ logger.info(f"==== request ====\n{gen_params}")
+
+ res = c.completions.create(
+ prompt=prompt,
+ stop_sequences=[anthropic.HUMAN_PROMPT],
+ max_tokens_to_sample=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ model=model_name,
+ stream=True,
+ )
+ text = ""
+ for chunk in res:
+ text += chunk.completion
+ data = {
+ "text": text,
+ "error_code": 0,
+ }
+ yield data
+
+
+def init_palm_chat(model_name):
+ import vertexai # pip3 install google-cloud-aiplatform
+ from vertexai.preview.language_models import ChatModel
+
+ project_id = os.environ["GCP_PROJECT_ID"]
+ location = "us-central1"
+ vertexai.init(project=project_id, location=location)
+
+ chat_model = ChatModel.from_pretrained(model_name)
+ chat = chat_model.start_chat(examples=[])
+ return chat
+
+
+def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
+ parameters = {
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_output_tokens": max_new_tokens,
+ }
+ gen_params = {
+ "model": "palm-2",
+ "prompt": message,
+ }
+ gen_params.update(parameters)
+ logger.info(f"==== request ====\n{gen_params}")
+
+ response = chat.send_message(message, **parameters)
+ content = response.text
+
+ pos = 0
+ while pos < len(content):
+ # Simulate token-by-token streaming latency: emit 10-20 characters at a
+ # time, with exponentially distributed delays (i.e., a Poisson process).
+ pos += random.randint(10, 20)
+ time.sleep(random.expovariate(50))
+ data = {
+ "text": content[:pos],
+ "error_code": 0,
+ }
+ yield data
diff --git a/fastchat/serve/base_model_worker.py b/fastchat/serve/base_model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a297ab7e440cf727526a60e1481bbeac286544c9
--- /dev/null
+++ b/fastchat/serve/base_model_worker.py
@@ -0,0 +1,239 @@
+import asyncio
+import threading
+import time
+from typing import List
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import requests
+
+from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
+from fastchat.conversation import Conversation
+from fastchat.utils import pretty_print_semaphore, build_logger
+
+
+worker = None
+logger = None
+
+app = FastAPI()
+
+
+def heart_beat_worker(obj):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ obj.send_heart_beat()
+
+
+class BaseModelWorker:
+ def __init__(
+ self,
+ controller_addr: str,
+ worker_addr: str,
+ worker_id: str,
+ model_path: str,
+ model_names: List[str],
+ limit_worker_concurrency: int,
+ conv_template: str = None,
+ ):
+ global logger, worker
+
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ self.model_names = model_names or [model_path.split("/")[-1]]
+ self.limit_worker_concurrency = limit_worker_concurrency
+ self.conv = self.make_conv_template(conv_template, model_path)
+ self.conv.sep_style = int(self.conv.sep_style)
+ self.tokenizer = None
+ self.context_len = None
+ self.call_ct = 0
+ self.semaphore = None
+
+ self.heart_beat_thread = None
+
+ if logger is None:
+ logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
+ if worker is None:
+ worker = self
+
+ def make_conv_template(
+ self,
+ conv_template: str = None,
+ model_path: str = None,
+ ) -> Conversation:
+ """
+ Can be overridden to customize the conversation template for different model workers.
+ """
+ from fastchat.conversation import get_conv_template
+ from fastchat.model.model_adapter import get_conversation_template
+
+ if conv_template:
+ conv = get_conv_template(conv_template)
+ else:
+ conv = get_conversation_template(model_path)
+ print(conv)
+ return conv
+
+ def init_heart_beat(self):
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker,
+ args=(self,),
+ daemon=True,
+ )
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status(),
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(
+ f"Send heart beat. Models: {self.model_names}. "
+ f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
+ f"call_ct: {self.call_ct}. "
+ f"worker_id: {self.worker_id}. "
+ )
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(
+ url,
+ json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length(),
+ },
+ timeout=5,
+ )
+ exist = ret.json()["exist"]
+ break
+ except (requests.exceptions.RequestException, KeyError) as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
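+ # Estimate the queue length as in-flight requests (limit - free slots)
+ # plus requests waiting on the semaphore (relies on asyncio internals).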
+ if (
+ self.semaphore is None
+ or self.semaphore._value is None
+ or self.semaphore._waiters is None
+ ):
+ return 0
+ else:
+ return (
+ self.limit_worker_concurrency
+ - self.semaphore._value
+ + len(self.semaphore._waiters)
+ )
+
+ def get_status(self):
+ return {
+ "model_names": self.model_names,
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ def count_token(self, params):
+ prompt = params["prompt"]
+
+ try:
+ input_ids = self.tokenizer(prompt).input_ids
+ input_echo_len = len(input_ids)
+ except TypeError:
+ input_echo_len = self.tokenizer.num_tokens(prompt)
+
+ ret = {
+ "count": input_echo_len,
+ "error_code": 0,
+ }
+ return ret
+
+ def get_conv_template(self):
+ return {"conv": self.conv}
+
+ def generate_stream_gate(self, params):
+ raise NotImplementedError
+
+ def generate_gate(self, params):
+ raise NotImplementedError
+
+ def get_embeddings(self, params):
+ raise NotImplementedError
+
+
+def release_worker_semaphore():
+ worker.semaphore.release()
+
+
+def acquire_worker_semaphore():
+ if worker.semaphore is None:
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
+ return worker.semaphore.acquire()
+
+
+def create_background_tasks():
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(release_worker_semaphore)
+ return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = create_background_tasks()
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ output = await asyncio.to_thread(worker.generate_gate, params)
+ release_worker_semaphore()
+ return JSONResponse(output)
+
+
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ embedding = worker.get_embeddings(params)
+ release_worker_semaphore()
+ return JSONResponse(content=embedding)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+ return worker.get_status()
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+ params = await request.json()
+ return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+ return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+ return {"context_length": worker.context_len}
diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb815cde6a751e283307c6d0cd5bbbde2fd062ed
--- /dev/null
+++ b/fastchat/serve/cli.py
@@ -0,0 +1,313 @@
+"""
+Chat with a model with command line interface.
+
+Usage:
+python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
+python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
+
+Other commands:
+- Type "!!exit" or an empty line to exit.
+- Type "!!reset" to start a new conversation.
+- Type "!!remove" to remove the last prompt.
+- Type "!!regen" to regenerate the last message.
+- Type "!!save " to save the conversation history to a json file.
+- Type "!!load " to load a conversation history from a json file.
+"""
+import argparse
+import os
+import re
+import sys
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.history import InMemoryHistory
+from prompt_toolkit.key_binding import KeyBindings
+from rich.console import Console
+from rich.live import Live
+from rich.markdown import Markdown
+import torch
+
+from fastchat.model.model_adapter import add_model_args
+from fastchat.modules.awq import AWQConfig
+from fastchat.modules.exllama import ExllamaConfig
+from fastchat.modules.xfastertransformer import XftConfig
+from fastchat.modules.gptq import GptqConfig
+from fastchat.serve.inference import ChatIO, chat_loop
+from fastchat.utils import str_to_torch_dtype
+
+
+class SimpleChatIO(ChatIO):
+ def __init__(self, multiline: bool = False, prefix: str = ''):
+ self._multiline = multiline
+ self.prefix = prefix
+
+ def prompt_for_input(self, role) -> str:
+ if not self._multiline:
+ return input(f"{role}: {self.prefix}")
+
+ prompt_data = []
+ line = input(f"{role} [ctrl-d/z on empty line to end]: ")
+ while True:
+ prompt_data.append(line.strip())
+ try:
+ line = input()
+ except EOFError as e:
+ break
+ return f"\n{self.prefix}".join(prompt_data)
+
+ def prompt_for_output(self, role: str):
+ print(f"{role}: ", end="", flush=True)
+
+ def stream_output(self, output_stream):
+ pre = 0
+ for outputs in output_stream:
+ output_text = outputs["text"]
+ output_text = output_text.strip().split(" ")
+ now = len(output_text) - 1
+ if now > pre:
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
+ pre = now
+ print(" ".join(output_text[pre:]), flush=True)
+ return " ".join(output_text)
+
+ def print_output(self, text: str):
+ print(text)
+
+
+class RichChatIO(ChatIO):
+ bindings = KeyBindings()
+
+ @bindings.add("escape", "enter")
+ def _(event):
+ event.app.current_buffer.newline()
+
+ def __init__(self, multiline: bool = False, mouse: bool = False):
+ self._prompt_session = PromptSession(history=InMemoryHistory())
+ self._completer = WordCompleter(
+ words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
+ pattern=re.compile("$"),
+ )
+ self._console = Console()
+ self._multiline = multiline
+ self._mouse = mouse
+
+ def prompt_for_input(self, role) -> str:
+ self._console.print(f"[bold]{role}:")
+ # TODO(suquark): multiline input has some issues. fix it later.
+ prompt_input = self._prompt_session.prompt(
+ completer=self._completer,
+ multiline=False,
+ mouse_support=self._mouse,
+ auto_suggest=AutoSuggestFromHistory(),
+ key_bindings=self.bindings if self._multiline else None,
+ )
+ self._console.print()
+ return prompt_input
+
+ def prompt_for_output(self, role: str):
+ self._console.print(f"[bold]{role.replace('/', '|')}:")
+
+ def stream_output(self, output_stream):
+ """Stream output from a role."""
+ # TODO(suquark): the console flickers when there is a code block
+ # above it. We need to cut off "live" when a code block is done.
+
+ # Create a Live context for updating the console output
+ with Live(console=self._console, refresh_per_second=4) as live:
+ # Read lines from the stream
+ for outputs in output_stream:
+ if not outputs:
+ continue
+ text = outputs["text"]
+ # Render the accumulated text as Markdown
+ # NOTE: this is a workaround for the rendering "unstandard markdown"
+ # in rich. The chatbots output treat "\n" as a new line for
+ # better compatibility with real-world text. However, rendering
+ # in markdown would break the format. It is because standard markdown
+ # treat a single "\n" in normal text as a space.
+ # Our workaround is adding two spaces at the end of each line.
+ # This is not a perfect solution, as it would
+ # introduce trailing spaces (only) in code block, but it works well
+ # especially for console output, because in general the console does not
+ # care about trailing spaces.
+ lines = []
+ for line in text.splitlines():
+ lines.append(line)
+ if line.startswith("```"):
+ # Code block marker - do not add trailing spaces, as it would
+ # break the syntax highlighting
+ lines.append("\n")
+ else:
+ lines.append(" \n")
+ markdown = Markdown("".join(lines))
+ # Update the Live console output
+ live.update(markdown)
+ self._console.print()
+ return text
+
+ def print_output(self, text: str):
+ self.stream_output([{"text": text}])
+
+
+class ProgrammaticChatIO(ChatIO):
+ def prompt_for_input(self, role) -> str:
+ contents = ""
+ # `end_sequence` signals the end of a message. It is unlikely to occur in
+ # message content.
+ end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
+ len_end = len(end_sequence)
+ while True:
+ if len(contents) >= len_end:
+ last_chars = contents[-len_end:]
+ if last_chars == end_sequence:
+ break
+ try:
+ char = sys.stdin.read(1)
+ contents = contents + char
+ except EOFError:
+ continue
+ contents = contents[:-len_end]
+ print(f"[!OP:{role}]: {contents}", flush=True)
+ return contents
+
+ def prompt_for_output(self, role: str):
+ print(f"[!OP:{role}]: ", end="", flush=True)
+
+ def stream_output(self, output_stream):
+ pre = 0
+ for outputs in output_stream:
+ output_text = outputs["text"]
+ output_text = output_text.strip().split(" ")
+ now = len(output_text) - 1
+ if now > pre:
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
+ pre = now
+ print(" ".join(output_text[pre:]), flush=True)
+ return " ".join(output_text)
+
+ def print_output(self, text: str):
+ print(text)
+
+
+def main(args):
+ if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+ os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
+ if args.enable_exllama:
+ exllama_config = ExllamaConfig(
+ max_seq_len=args.exllama_max_seq_len,
+ gpu_split=args.exllama_gpu_split,
+ )
+ else:
+ exllama_config = None
+ if args.enable_xft:
+ xft_config = XftConfig(
+ max_seq_len=args.xft_max_seq_len,
+ data_type=args.xft_dtype,
+ )
+ if args.device != "cpu":
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
+ args.device = "cpu"
+ else:
+ xft_config = None
+ if args.style == "simple":
+ chatio = SimpleChatIO(args.multiline)
+ elif args.style == "rich":
+ chatio = RichChatIO(args.multiline, args.mouse)
+ elif args.style == "programmatic":
+ chatio = ProgrammaticChatIO()
+ else:
+ raise ValueError(f"Invalid style for console: {args.style}")
+ try:
+ if args.upload_file_path:
+ with open(args.upload_file_path, "r") as upload_file:
+ prefix = upload_file.read()
+ args.conv_system_msg = prefix[:20000]
+ chat_loop(
+ args.model_path,
+ args.device,
+ args.num_gpus,
+ args.max_gpu_memory,
+ str_to_torch_dtype(args.dtype),
+ args.load_8bit,
+ args.cpu_offloading,
+ args.conv_template,
+ args.conv_system_msg,
+ args.temperature,
+ args.repetition_penalty,
+ args.max_new_tokens,
+ chatio,
+ gptq_config=GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ ),
+ awq_config=AWQConfig(
+ ckpt=args.awq_ckpt or args.model_path,
+ wbits=args.awq_wbits,
+ groupsize=args.awq_groupsize,
+ ),
+ exllama_config=exllama_config,
+ xft_config=xft_config,
+ revision=args.revision,
+ judge_sent_end=args.judge_sent_end,
+ debug=args.debug,
+ history=not args.no_history,
+ )
+ except KeyboardInterrupt:
+ print("exit...")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ add_model_args(parser)
+ parser.add_argument(
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
+ )
+ parser.add_argument(
+ "--conv-system-msg", type=str, default=None, help="Conversation system message."
+ )
+ parser.add_argument("--temperature", type=float, default=0.7)
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--no-history", action="store_true")
+ parser.add_argument(
+ "--style",
+ type=str,
+ default="simple",
+ choices=["simple", "rich", "programmatic"],
+ help="Display style.",
+ )
+ parser.add_argument(
+ "--multiline",
+ action="store_true",
+ help="Enable multiline input. Use ESC+Enter for newline.",
+ )
+ parser.add_argument(
+ "--mouse",
+ action="store_true",
+ help="[Rich Style]: Enable mouse support for cursor positioning.",
+ )
+ parser.add_argument(
+ "--judge-sent-end",
+ action="store_true",
+ help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ help="Print useful debug information (e.g., prompts)",
+ )
+ parser.add_argument(
+ "--upload-file-path",
+ type=str,
+ default="",
+ help="upload long txt for summary.",
+ )
+ args = parser.parse_args()
+ main(args)
diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..a67da62c42d898c17c23c0cc7244770cfeb78746
--- /dev/null
+++ b/fastchat/serve/controller.py
@@ -0,0 +1,348 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
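+
+Example usage (defaults shown; see the arguments defined below):
+python3 -m fastchat.serve.controller --host localhost --port 21001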
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import os
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from fastchat.constants import (
+ CONTROLLER_HEART_BEAT_EXPIRATION,
+ WORKER_API_TIMEOUT,
+ ErrorCode,
+ SERVER_ERROR_MSG,
+)
+from fastchat.utils import build_logger
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+ last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+ controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_controller, args=(self,)
+ )
+ self.heart_beat_thread.start()
+
+ def register_worker(
+ self, worker_name: str, check_heart_beat: bool, worker_status: dict
+ ):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(
+ worker_status["model_names"],
+ worker_status["speed"],
+ worker_status["queue_length"],
+ check_heart_beat,
+ time.time(),
+ )
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
+ )
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+ def remove_stale_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def handle_no_worker(self, params):
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": SERVER_ERROR_MSG,
+ "error_code": ErrorCode.CONTROLLER_NO_WORKER,
+ }
+ return json.dumps(ret).encode() + b"\0"
+
+ def handle_worker_timeout(self, worker_address):
+ logger.info(f"worker timeout: {worker_address}")
+ ret = {
+ "text": SERVER_ERROR_MSG,
+ "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
+ }
+ return json.dumps(ret).encode() + b"\0"
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ model_names = sorted(list(model_names))
+ return {
+ "model_names": model_names,
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ yield self.handle_no_worker(params)
+ return
+
+ try:
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ json=params,
+ stream=True,
+ timeout=WORKER_API_TIMEOUT,
+ )
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ yield self.handle_worker_timeout(worker_addr)
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(
+ data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
+ )
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+@app.get("/test_connection")
+async def test_connection(request: Request):
+ return "success"
+
+
+def create_controller():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument(
+ "--dispatch-method",
+ type=str,
+ choices=["lottery", "shortest_queue"],
+ default="shortest_queue",
+ )
+ parser.add_argument(
+ "--ssl",
+ action="store_true",
+ required=False,
+ default=False,
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ return args, controller
+
+
+if __name__ == "__main__":
+ args, controller = create_controller()
+ if args.ssl:
+ uvicorn.run(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level="info",
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
+ ssl_certfile=os.environ["SSL_CERTFILE"],
+ )
+ else:
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/gateway/README.md b/fastchat/serve/gateway/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3afaf171bc38b232b68609585244c9e76489da7
--- /dev/null
+++ b/fastchat/serve/gateway/README.md
@@ -0,0 +1,57 @@
+# fastchat Nginx Gateway
+
+## Purpose of the Gateway
+
+The Nginx gateway serves the following purposes:
+
+1. Protects Gradio servers by acting as a firewall.
+2. Facilitates dynamic mounting and unmounting of Gradio servers.
+3. Provides load balancing for Gradio servers.
+4. Offers additional security features, such as a total connection limit.
+5. Reduces attack surface by requiring only a single public port to be exposed for serving.
+
+## Deployment and Updating of the Gateway
+
+### Installing Nginx
+
+On Debian-based distributions (e.g., Ubuntu):
+
+```bash
+sudo apt update
+sudo apt install nginx
+```
+
+On Red Hat-based distributions (e.g., CentOS, Fedora):
+
+```bash
+sudo yum install epel-release
+sudo yum install nginx
+```
+
+### Deployment
+
+Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission).
+
+Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server.
+
+Modify the `upstream websocket` block to list the Gradio web servers running behind the gateway.
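+
+For reference, the default `upstream` block in the provided `nginx.conf` is shown below; add one `server` line per Gradio server:
+
+```nginx
+upstream websocket {
+    ip_hash;                # load balancing by IP to guarantee session persistence
+    server localhost:7860;  # the Gradio web server port
+    # server localhost:7861;  # extra gradio server if more than one
+}
+```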
+
+Lastly, reload Nginx so the new configuration takes effect (see the Updating section below).
+
+
+### HTTPS Deployment with a Public Domain URL
+
+Make sure you have obtained an HTTPS certificate and the corresponding private key.
+
+Fill in the paths to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
+
+If you serve the chatbot from your own domain, replace the chat.lmsys.org URL with your own domain URL.
+
+### Updating
+
+Every time `/etc/nginx/nginx.conf` is modified, you need to reload the Nginx service:
+
+```bash
+sudo nginx -t # check `/etc/nginx/nginx.conf`
+sudo systemctl reload nginx # restart Nginx service to load the new config
+sudo systemctl status nginx # check the status of the Nginx service. It should be active (running).
+```
diff --git a/fastchat/serve/gateway/nginx.conf b/fastchat/serve/gateway/nginx.conf
new file mode 100644
index 0000000000000000000000000000000000000000..b88ca8c50772421fca91f33ff77ef75f4d23ad4d
--- /dev/null
+++ b/fastchat/serve/gateway/nginx.conf
@@ -0,0 +1,97 @@
+user www-data;
+worker_processes auto;
+pid /run/nginx.pid;
+include /etc/nginx/modules-enabled/*.conf;
+
+events {
+ worker_connections 1024; # maximum number of connections that a worker process can handle concurrently
+ # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle
+
+}
+
+http {
+ ##
+ # Basic Settings
+ ##
+
+ sendfile on; # enable sendfile for performance optimization
+ tcp_nopush on; # enable TCP no-pushing
+ tcp_nodelay on; # enable TCP no-delay
+ keepalive_timeout 65; # sets the timeout for keep-alive connections
+ types_hash_max_size 2048; # maximum size of the types hash table
+ # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security
+
+ # server_names_hash_bucket_size 64;
+ # server_name_in_redirect off;
+
+ include /etc/nginx/mime.types; # include MIME types file
+ default_type application/octet-stream; # default MIME type for unknown file types
+
+ ##
+ # SSL Settings
+ ##
+
+ ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use
+ ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers
+
+ ##
+ # Logging Settings
+ ##
+
+ access_log /var/log/nginx/access.log; # path to access log file
+ error_log /var/log/nginx/error.log; # path to error log file
+
+ ##
+ # Gzip Settings
+ ##
+ gzip on; # enable Gzip compression
+
+ ##
+ # Virtual Host Configs
+ ##
+
+ include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory
+ include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files
+
+ # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/
+ map $http_upgrade $connection_upgrade {
+ default upgrade;
+ '' close;
+ }
+
+ upstream websocket {
+ ip_hash; # load balancing by IP to guarantee session persistence
+ server localhost:7860; # The port should be the gradio web server port
+ # server localhost:7861; # extra gradio server if more than one
+ }
+
+ limit_conn_status 429;
+ limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP
+ limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server
+
+ server {
+ listen 443 ssl; # the listening port of our server
+ ssl_certificate [PATH_TO_SSL_CERT];
+ ssl_certificate_key [PATH_TO_PRIVATE_KEY];
+ server_name chat.lmsys.org; # replace the url with your own domain url
+ limit_conn perserver 1024; # connections per server
+ location / {
+ proxy_pass http://websocket; # proxy all requests to the defined upstream server
+ limit_conn perip 5; # connections per IP
+ proxy_set_header Host $host; # set the Host header for the upstream server
+ proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header
+ proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication
+ }
+ }
+
+ # the following block routes all HTTP traffic to HTTPS via nginx
+ server {
+ listen 80;
+ server_name chat.lmsys.org;
+ return 301 https://chat.lmsys.org$request_uri;
+ }
+
+}
diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py
new file mode 100644
index 0000000000000000000000000000000000000000..48e49deef8818595ec8e8104972955d6be889b5f
--- /dev/null
+++ b/fastchat/serve/gradio_block_arena_anony.py
@@ -0,0 +1,608 @@
+"""
+Chatbot Arena (battle) tab.
+Users chat with two anonymous models.
+"""
+
+import json
+import time
+
+import gradio as gr
+import numpy as np
+
+from fastchat.constants import (
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ SLOW_MODEL_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_TURN_LIMIT,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.serve.gradio_block_arena_named import flash_buttons
+from fastchat.serve.gradio_web_server import (
+ State,
+ bot_response,
+ get_conv_log_filename,
+ no_change_btn,
+ enable_btn,
+ disable_btn,
+ invisible_btn,
+ acknowledgment_md,
+ ip_expiration_dict,
+ get_ip,
+)
+from fastchat.utils import (
+ build_logger,
+ moderation_filter,
+)
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+num_sides = 2
+enable_moderation = False
+anony_names = ["", ""]
+models = []
+
+
+def set_global_vars_anony(enable_moderation_):
+ global enable_moderation
+ enable_moderation = enable_moderation_
+
+
+def load_demo_side_by_side_anony(models_, url_params):
+ global models
+ models = models_
+
+ states = (None,) * num_sides
+ selector_updates = (
+ gr.Markdown.update(visible=True),
+ gr.Markdown.update(visible=True),
+ )
+
+ return states + selector_updates
+
+
+def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "models": [x for x in model_selectors],
+ "states": [x.dict() for x in states],
+ "ip": get_ip(request),
+ }
+ fout.write(json.dumps(data) + "\n")
+
+ if ":" not in model_selectors[0]:
+ for i in range(15):
+ names = (
+ "### Model A: " + states[0].model_name,
+ "### Model B: " + states[1].model_name,
+ )
+ yield names + ("",) + (disable_btn,) * 4
+ time.sleep(0.2)
+ else:
+ names = (
+ "### Model A: " + states[0].model_name,
+ "### Model B: " + states[1].model_name,
+ )
+ yield names + ("",) + (disable_btn,) * 4
+
+
+def leftvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"leftvote (anony). ip: {get_ip(request)}")
+ for x in vote_last_response(
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def rightvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"rightvote (anony). ip: {get_ip(request)}")
+ for x in vote_last_response(
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def tievote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"tievote (anony). ip: {get_ip(request)}")
+ for x in vote_last_response(
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def bothbad_vote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
+ for x in vote_last_response(
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def regenerate(state0, state1, request: gr.Request):
+ logger.info(f"regenerate (anony). ip: {get_ip(request)}")
+ states = [state0, state1]
+ for i in range(num_sides):
+ states[i].conv.update_last_message(None)
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history (anony). ip: {get_ip(request)}")
+ return (
+ [None] * num_sides
+ + [None] * num_sides
+ + anony_names
+ + [""]
+ + [invisible_btn] * 4
+ + [disable_btn] * 2
+ + [""]
+ )
+
+
+def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
+ logger.info(f"share (anony). ip: {get_ip(request)}")
+ if state0 is not None and state1 is not None:
+ vote_last_response(
+ [state0, state1], "share", [model_selector0, model_selector1], request
+ )
+
+
+SAMPLING_WEIGHTS = {
+ # tier 0
+ "gpt-4": 4,
+ "gpt-4-turbo": 4,
+ "gpt-3.5-turbo": 2,
+ "gpt-3.5-turbo-1106": 2,
+ "claude-2": 8,
+ "claude-1": 2,
+ "claude-instant-1": 8,
+ "zephyr-7b-beta": 2,
+ "openchat-3.5": 2,
+ # tier 1
+ "deluxe-chat-v1.1": 2,
+ "palm-2": 1.5,
+ "llama-2-70b-chat": 1.5,
+ "llama-2-13b-chat": 1.5,
+ "codellama-34b-instruct": 1.5,
+ "vicuna-33b": 8,
+ "vicuna-13b": 1.5,
+ "wizardlm-70b": 1.5,
+ "wizardlm-13b": 1.5,
+ "qwen-14b-chat": 1.5,
+ "mistral-7b-instruct": 1.5,
+ # tier 2
+ "vicuna-7b": 1.0,
+ "llama-2-7b-chat": 1.0,
+ "chatglm2-6b": 1.0,
+ # deprecated
+ "zephyr-7b-alpha": 1.5,
+ "codellama-13b-instruct": 1.0,
+ "mpt-30b-chat": 1.5,
+ "guanaco-33b": 1.0,
+ "fastchat-t5-3b": 0.5,
+ "alpaca-13b": 0.5,
+ "mpt-7b-chat": 0.1,
+ "oasst-pythia-12b": 0.1,
+ "RWKV-4-Raven-14B": 0.1,
+ "gpt4all-13b-snoozy": 0.1,
+ "koala-13b": 0.1,
+ "stablelm-tuned-alpha-7b": 0.1,
+ "dolly-v2-12b": 0.1,
+ "llama-13b": 0.1,
+ "chatglm-6b": 0.5,
+ "deluxe-chat-v1": 4,
+}
+
+# target model sampling weights will be boosted.
+BATTLE_TARGETS = {
+ "gpt-4": {"claude-2"},
+ "gpt-4-turbo": {"gpt-4", "gpt-3.5-turbo"},
+ "gpt-3.5-turbo": {"claude-instant-1", "gpt-4", "claude-2"},
+ "claude-2": {"gpt-4", "gpt-3.5-turbo", "claude-1"},
+ "claude-1": {"claude-2", "gpt-4", "gpt-3.5-turbo"},
+ "claude-instant-1": {"gpt-3.5-turbo", "claude-2"},
+ "deluxe-chat-v1.1": {"gpt-4"},
+ "openchat-3.5": {"gpt-3.5-turbo", "llama-2-70b-chat", "zephyr-7b-beta"},
+ "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"},
+ "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"},
+ "zephyr-7b-beta": {
+ "mistral-7b-instruct",
+ "llama-2-13b-chat",
+ "llama-2-7b-chat",
+ "wizardlm-13b",
+ },
+ "llama-2-70b-chat": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
+ "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"},
+ "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"},
+ "mistral-7b-instruct": {
+ "llama-2-7b-chat",
+ "llama-2-13b-chat",
+ "llama-2-70b-chat",
+ },
+ "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo", "claude-instant-1"},
+ "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"},
+ "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"},
+ "wizardlm-70b": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
+ "palm-2": {"llama-2-13b-chat", "gpt-3.5-turbo"},
+}
+
+SAMPLING_BOOST_MODELS = ["openchat-3.5", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
+
+# outage models won't be sampled.
+OUTAGE_MODELS = []
+
+
+def get_sample_weight(model):
+ if model in OUTAGE_MODELS:
+ return 0
+ weight = SAMPLING_WEIGHTS.get(model, 1.0)
+ if model in SAMPLING_BOOST_MODELS:
+ weight *= 5
+ return weight
+
+
+def get_battle_pair():
+ if len(models) == 1:
+ return models[0], models[0]
+
+ model_weights = []
+ for model in models:
+ weight = get_sample_weight(model)
+ model_weights.append(weight)
+ total_weight = np.sum(model_weights)
+ model_weights = model_weights / total_weight
+ chosen_idx = np.random.choice(len(models), p=model_weights)
+ chosen_model = models[chosen_idx]
+
+ rival_models = []
+ rival_weights = []
+ for model in models:
+ if model == chosen_model:
+ continue
+ weight = get_sample_weight(model)
+ if (
+ weight != 0
+ and chosen_model in BATTLE_TARGETS
+ and model in BATTLE_TARGETS[chosen_model]
+ ):
+ # Boost targeted rivals so that, collectively, they are picked about 50% of the time
+ weight = total_weight / len(BATTLE_TARGETS[chosen_model])
+ rival_models.append(model)
+ rival_weights.append(weight)
+ # for p, w in zip(rival_models, rival_weights):
+ # print(p, w)
+ rival_weights = rival_weights / np.sum(rival_weights)
+ rival_idx = np.random.choice(len(rival_models), p=rival_weights)
+ rival_model = rival_models[rival_idx]
+
+ swap = np.random.randint(2)
+ if swap == 0:
+ return chosen_model, rival_model
+ else:
+ return rival_model, chosen_model
+
+
+def add_text(
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
+):
+ ip = get_ip(request)
+ logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
+ states = [state0, state1]
+ model_selectors = [model_selector0, model_selector1]
+
+ # Init states if necessary
+ if states[0] is None:
+ assert states[1] is None
+
+ model_left, model_right = get_battle_pair()
+ states = [
+ State(model_left),
+ State(model_right),
+ ]
+
+ if len(text) <= 0:
+ for i in range(num_sides):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ + [""]
+ )
+
+ model_list = [states[i].model_name for i in range(num_sides)]
+ flagged = moderation_filter(text, model_list)
+ if flagged:
+ logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
+ # overwrite the original text
+ text = MODERATION_MSG
+
+ conv = states[0].conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
+ logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
+ for i in range(num_sides):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [CONVERSATION_LIMIT_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ + [""]
+ )
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ for i in range(num_sides):
+ states[i].conv.append_message(states[i].conv.roles[0], text)
+ states[i].conv.append_message(states[i].conv.roles[1], None)
+ states[i].skip_next = False
+
+ slow_model_msg = ""
+ for i in range(num_sides):
+ if "deluxe" in states[i].model_name:
+ slow_model_msg = SLOW_MODEL_MSG
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ disable_btn,
+ ]
+ * 6
+ + [slow_model_msg]
+ )
+
+
+def bot_response_multi(
+ state0,
+ state1,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request: gr.Request,
+):
+ logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
+
+ if state0 is None or state0.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (
+ state0,
+ state1,
+ state0.to_gradio_chatbot(),
+ state1.to_gradio_chatbot(),
+ ) + (no_change_btn,) * 6
+ return
+
+ states = [state0, state1]
+ gen = []
+ for i in range(num_sides):
+ gen.append(
+ bot_response(
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+ )
+
+ chatbots = [None] * num_sides
+ while True:
+ stop = True
+ for i in range(num_sides):
+ try:
+ ret = next(gen[i])
+ states[i], chatbots[i] = ret[0], ret[1]
+ stop = False
+ except StopIteration:
+ pass
+ yield states + chatbots + [disable_btn] * 6
+ if stop:
+ break
+
+
+def build_side_by_side_ui_anony(models):
+ notice_markdown = """
+# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
+| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
+
+## 📜 Rules
+- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one!
+- You can continue chatting until you identify a winner.
+- Votes won't be counted if a model's identity is revealed during the conversation.
+
+## 🏆 Arena Elo [Leaderboard](https://huggingface.co./spaces/lmsys/chatbot-arena-leaderboard)
+We use **100K** human votes to compile an Elo-based LLM leaderboard.
+Find out who is the 🥇LLM Champion!
+
+## 👇 Chat now!
+
+"""
+
+ states = [gr.State() for _ in range(num_sides)]
+ model_selectors = [None] * num_sides
+ chatbots = [None] * num_sides
+
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+ with gr.Box(elem_id="share-region-anony"):
+ with gr.Row():
+ for i in range(num_sides):
+ label = "Model A" if i == 0 else "Model B"
+ with gr.Column():
+ chatbots[i] = gr.Chatbot(
+ label=label, elem_id=f"chatbot", height=550
+ )
+
+ with gr.Row():
+ for i in range(num_sides):
+ with gr.Column():
+ model_selectors[i] = gr.Markdown(anony_names[i])
+ with gr.Row():
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
+
+ with gr.Row():
+ leftvote_btn = gr.Button(
+ value="👈 A is better", visible=False, interactive=False
+ )
+ rightvote_btn = gr.Button(
+ value="👉 B is better", visible=False, interactive=False
+ )
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
+ bothbad_btn = gr.Button(
+ value="👎 Both are bad", visible=False, interactive=False
+ )
+
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="👉 Enter your prompt and press ENTER",
+ container=False,
+ elem_id="input_box",
+ )
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", variant="primary")
+
+ with gr.Row() as button_row:
+ clear_btn = gr.Button(value="🎲 New Round", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ share_btn = gr.Button(value="📷 Share")
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ gr.Markdown(acknowledgment_md)
+
+ # Register listeners
+ btn_list = [
+ leftvote_btn,
+ rightvote_btn,
+ tie_btn,
+ bothbad_btn,
+ regenerate_btn,
+ clear_btn,
+ ]
+ leftvote_btn.click(
+ leftvote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ rightvote_btn.click(
+ rightvote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ tie_btn.click(
+ tievote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ bothbad_btn.click(
+ bothbad_vote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ regenerate_btn.click(
+ regenerate, states, states + chatbots + [textbox] + btn_list
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons, [], btn_list
+ )
+ clear_btn.click(
+ clear_history,
+ None,
+ states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning],
+ )
+
+ share_js = """
+function (a, b, c, d) {
+ const captureElement = document.querySelector('#share-region-anony');
+ html2canvas(captureElement)
+ .then(canvas => {
+ canvas.style.display = 'none'
+ document.body.appendChild(canvas)
+ return canvas
+ })
+ .then(canvas => {
+ const image = canvas.toDataURL('image/png')
+ const a = document.createElement('a')
+ a.setAttribute('download', 'chatbot-arena.png')
+ a.setAttribute('href', image)
+ a.click()
+ canvas.remove()
+ });
+ return [a, b, c, d];
+}
+"""
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
+
+ textbox.submit(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list + [slow_warning],
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons,
+ [],
+ btn_list,
+ )
+
+ send_btn.click(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons, [], btn_list
+ )
+
+ return states + model_selectors
diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13283495aec8262de92d47fbd758e47204172d0
--- /dev/null
+++ b/fastchat/serve/gradio_block_arena_named.py
@@ -0,0 +1,458 @@
+"""
+Chatbot Arena (side-by-side) tab.
+Users chat with two chosen models.
+"""
+
+import json
+import time
+
+import gradio as gr
+import numpy as np
+
+from fastchat.constants import (
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_TURN_LIMIT,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.serve.gradio_web_server import (
+ State,
+ bot_response,
+ get_conv_log_filename,
+ no_change_btn,
+ enable_btn,
+ disable_btn,
+ invisible_btn,
+ acknowledgment_md,
+ get_model_description_md,
+ ip_expiration_dict,
+ get_ip,
+)
+from fastchat.utils import (
+ build_logger,
+ moderation_filter,
+)
+
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+num_sides = 2
+enable_moderation = False
+
+
+def set_global_vars_named(enable_moderation_):
+ global enable_moderation
+ enable_moderation = enable_moderation_
+
+
+def load_demo_side_by_side_named(models, url_params):
+ states = (None,) * num_sides
+
+ model_left = models[0] if len(models) > 0 else ""
+ if len(models) > 1:
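+        # Sample the right-hand model with decreasing weights so earlier-listed models appear more often.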
+ weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1]
+ weights = weights / np.sum(weights)
+ model_right = np.random.choice(models[1:], p=weights)
+ else:
+ model_right = model_left
+
+ selector_updates = (
+ gr.Dropdown.update(choices=models, value=model_left, visible=True),
+ gr.Dropdown.update(choices=models, value=model_right, visible=True),
+ )
+
+ return states + selector_updates
+
+
+def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "models": [x for x in model_selectors],
+ "states": [x.dict() for x in states],
+ "ip": get_ip(request),
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def leftvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"leftvote (named). ip: {get_ip(request)}")
+ vote_last_response(
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def rightvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"rightvote (named). ip: {get_ip(request)}")
+ vote_last_response(
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def tievote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"tievote (named). ip: {get_ip(request)}")
+ vote_last_response(
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def bothbad_vote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
+ vote_last_response(
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def regenerate(state0, state1, request: gr.Request):
+ logger.info(f"regenerate (named). ip: {get_ip(request)}")
+ states = [state0, state1]
+ for i in range(num_sides):
+ states[i].conv.update_last_message(None)
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history (named). ip: {get_ip(request)}")
+ return (
+ [None] * num_sides
+ + [None] * num_sides
+ + [""]
+ + [invisible_btn] * 4
+ + [disable_btn] * 2
+ )
+
+
+def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
+ logger.info(f"share (named). ip: {get_ip(request)}")
+ if state0 is not None and state1 is not None:
+ vote_last_response(
+ [state0, state1], "share", [model_selector0, model_selector1], request
+ )
+
+
+def add_text(
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
+):
+ ip = get_ip(request)
+ logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
+ states = [state0, state1]
+ model_selectors = [model_selector0, model_selector1]
+
+ # Init states if necessary
+ for i in range(num_sides):
+ if states[i] is None:
+ states[i] = State(model_selectors[i])
+
+ if len(text) <= 0:
+ for i in range(num_sides):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ model_list = [states[i].model_name for i in range(num_sides)]
+ flagged = moderation_filter(text, model_list)
+ if flagged:
+ logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
+ # overwrite the original text
+ text = MODERATION_MSG
+
+ conv = states[0].conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
+ for i in range(num_sides):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [CONVERSATION_LIMIT_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ for i in range(num_sides):
+ states[i].conv.append_message(states[i].conv.roles[0], text)
+ states[i].conv.append_message(states[i].conv.roles[1], None)
+ states[i].skip_next = False
+
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ disable_btn,
+ ]
+ * 6
+ )
+
+
+def bot_response_multi(
+ state0,
+ state1,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request: gr.Request,
+):
+ logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
+
+ if state0.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (
+ state0,
+ state1,
+ state0.to_gradio_chatbot(),
+ state1.to_gradio_chatbot(),
+ ) + (no_change_btn,) * 6
+ return
+
+ states = [state0, state1]
+ gen = []
+ for i in range(num_sides):
+ gen.append(
+ bot_response(
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+ )
+
+ chatbots = [None] * num_sides
+ while True:
+ stop = True
+ for i in range(num_sides):
+ try:
+ ret = next(gen[i])
+ states[i], chatbots[i] = ret[0], ret[1]
+ stop = False
+ except StopIteration:
+ pass
+ yield states + chatbots + [disable_btn] * 6
+ if stop:
+ break
+
+
+def flash_buttons():
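+    # Briefly toggle the vote buttons to draw attention, ending with all buttons enabled.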
+ btn_updates = [
+ [disable_btn] * 4 + [enable_btn] * 2,
+ [enable_btn] * 6,
+ ]
+ for i in range(4):
+ yield btn_updates[i % 2]
+ time.sleep(0.5)
+
+
+def build_side_by_side_ui_named(models):
+ notice_markdown = """
+# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
+| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
+
+## 📜 Rules
+- Chat with any two models side-by-side and vote!
+- You can continue chatting for multiple rounds.
+- Click "Clear history" to start a new round.
+
+## 🤖 Choose two models to compare
+"""
+
+ states = [gr.State() for _ in range(num_sides)]
+ model_selectors = [None] * num_sides
+ chatbots = [None] * num_sides
+
+ model_description_md = get_model_description_md(models)
+ notice = gr.Markdown(
+ notice_markdown + model_description_md, elem_id="notice_markdown"
+ )
+
+ with gr.Box(elem_id="share-region-named"):
+ with gr.Row():
+ for i in range(num_sides):
+ with gr.Column():
+ model_selectors[i] = gr.Dropdown(
+ choices=models,
+ value=models[i] if len(models) > i else "",
+ interactive=True,
+ show_label=False,
+ container=False,
+ )
+
+ with gr.Row():
+ for i in range(num_sides):
+ label = "Model A" if i == 0 else "Model B"
+ with gr.Column():
+ chatbots[i] = gr.Chatbot(
+ label=label, elem_id=f"chatbot", height=550
+ )
+
+ with gr.Row():
+ leftvote_btn = gr.Button(
+ value="👈 A is better", visible=False, interactive=False
+ )
+ rightvote_btn = gr.Button(
+ value="👉 B is better", visible=False, interactive=False
+ )
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
+ bothbad_btn = gr.Button(
+ value="👎 Both are bad", visible=False, interactive=False
+ )
+
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="Enter your prompt here and press ENTER",
+ container=False,
+ elem_id="input_box",
+ )
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", variant="primary")
+
+ with gr.Row() as button_row:
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+ share_btn = gr.Button(value="📷 Share")
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ gr.Markdown(acknowledgment_md)
+
+ # Register listeners
+ btn_list = [
+ leftvote_btn,
+ rightvote_btn,
+ tie_btn,
+ bothbad_btn,
+ regenerate_btn,
+ clear_btn,
+ ]
+ leftvote_btn.click(
+ leftvote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ rightvote_btn.click(
+ rightvote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ tie_btn.click(
+ tievote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ bothbad_btn.click(
+ bothbad_vote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ regenerate_btn.click(
+ regenerate, states, states + chatbots + [textbox] + btn_list
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons, [], btn_list
+ )
+ clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
+
+ share_js = """
+function (a, b, c, d) {
+ const captureElement = document.querySelector('#share-region-named');
+ html2canvas(captureElement)
+ .then(canvas => {
+ canvas.style.display = 'none'
+ document.body.appendChild(canvas)
+ return canvas
+ })
+ .then(canvas => {
+ const image = canvas.toDataURL('image/png')
+ const a = document.createElement('a')
+ a.setAttribute('download', 'chatbot-arena.png')
+ a.setAttribute('href', image)
+ a.click()
+ canvas.remove()
+ });
+ return [a, b, c, d];
+}
+"""
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
+
+ for i in range(num_sides):
+ model_selectors[i].change(
+ clear_history, None, states + chatbots + [textbox] + btn_list
+ )
+
+ textbox.submit(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons, [], btn_list
+ )
+ send_btn.click(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ bot_response_multi,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ ).then(
+ flash_buttons, [], btn_list
+ )
+
+ return states + model_selectors
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..adab6e9063f6b60b8f446cc6df009a643fc910ad
--- /dev/null
+++ b/fastchat/serve/gradio_web_server.py
@@ -0,0 +1,883 @@
+"""
+The gradio demo server for chatting with a single model.
+"""
+
+import argparse
+from collections import defaultdict
+import datetime
+import json
+import os
+import random
+import time
+import uuid
+
+import gradio as gr
+import requests
+
+from fastchat.conversation import SeparatorStyle
+from fastchat.constants import (
+ LOGDIR,
+ WORKER_API_TIMEOUT,
+ ErrorCode,
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ SERVER_ERROR_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_TURN_LIMIT,
+ SESSION_EXPIRATION_TIME,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.conversation import get_conv_template
+from fastchat.model.model_registry import get_model_info, model_info
+from fastchat.serve.api_provider import (
+ anthropic_api_stream_iter,
+ openai_api_stream_iter,
+ palm_api_stream_iter,
+ init_palm_chat,
+)
+from fastchat.utils import (
+ build_logger,
+ moderation_filter,
+ get_window_url_params_js,
+ get_window_url_params_with_tos_js,
+ parse_gradio_auth_creds,
+)
+
+CONV_TEMPLATE = ''
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "FastChat Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True, visible=True)
+disable_btn = gr.Button.update(interactive=False)
+invisible_btn = gr.Button.update(interactive=False, visible=False)
+
+controller_url = None
+enable_moderation = False
+
+acknowledgment_md = """
+### Acknowledgment
+
+"""
+
+ip_expiration_dict = defaultdict(lambda: 0)
+
+# Information about custom OpenAI compatible API models.
+# JSON file format:
+# {
+# "vicuna-7b": {
+# "model_name": "vicuna-7b-v1.5",
+# "api_base": "http://8.8.8.55:5555/v1",
+# "api_key": "password"
+# },
+# }
+openai_compatible_models_info = {}
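+# Illustrative usage (the JSON file name is hypothetical; the flag is defined in the argparse below):
+#   python3 -m fastchat.serve.gradio_web_server --register-openai-compatible-models openai_models.json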
+
+
+class State:
+ def __init__(self, model_name):
+        # Ask the serving worker which conversation template it expects,
+        # then build that template locally for this session.
+        ret = requests.post(
+            controller_url + "/get_worker_address", json={"model": model_name}
+        )
+        worker_addr = ret.json()["address"]
+        conv_name = requests.post(
+            worker_addr + "/worker_get_conv_template"
+        ).json()["conv"]["name"]
+        self.conv = get_conv_template(conv_name)
+ self.conv_id = uuid.uuid4().hex
+ self.skip_next = False
+ self.model_name = model_name
+
+ if model_name == "palm-2":
+ # According to release note, "chat-bison@001" is PaLM 2 for chat.
+ # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
+ self.palm_chat = init_palm_chat("chat-bison@001")
+
+ def to_gradio_chatbot(self):
+ return self.conv.to_gradio_chatbot()
+
+ def dict(self):
+ base = self.conv.dict()
+ base.update(
+ {
+ "conv_id": self.conv_id,
+ "model_name": self.model_name,
+ }
+ )
+ return base
+
+
+def set_global_vars(controller_url_, enable_moderation_):
+ global controller_url, enable_moderation
+ controller_url = controller_url_
+ enable_moderation = enable_moderation_
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list(
+ controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm
+):
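+    """Collect model names from the controller and any extra API-backed models, then sort by registry priority."""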
+ if controller_url:
+ ret = requests.post(controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(controller_url + "/list_models")
+ # ret = requests.post(controller_url + "/get_worker_address")
+ # ret = requests.post(controller_url + "/worker_get_status")
+ models = ret.json()["models"]
+ else:
+ models = []
+
+ # Add API providers
+ if register_openai_compatible_models:
+ global openai_compatible_models_info
+ openai_compatible_models_info = json.load(
+ open(register_openai_compatible_models)
+ )
+ models += list(openai_compatible_models_info.keys())
+
+ if add_chatgpt:
+ models += ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
+ if add_claude:
+ models += ["claude-2", "claude-instant-1"]
+ if add_palm:
+ models += ["palm-2"]
+ models = list(set(models))
+
+ if "deluxe-chat-v1" in models:
+ del models[models.index("deluxe-chat-v1")]
+ if "deluxe-chat-v1.1" in models:
+ del models[models.index("deluxe-chat-v1.1")]
+
+ priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+def load_demo_single(models, url_params):
+ selected_model = models[0] if len(models) > 0 else ""
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ selected_model = model
+
+ dropdown_update = gr.Dropdown.update(
+ choices=models, value=selected_model, visible=True
+ )
+
+ state = None
+ return state, dropdown_update
+
+
+def load_demo(url_params, request: gr.Request):
+ global models
+
+ ip = get_ip(request)
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
+ ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
+
+ if args.model_list_mode == "reload":
+ models = get_model_list(
+ controller_url,
+ args.register_openai_compatible_models,
+ args.add_chatgpt,
+ args.add_claude,
+ args.add_palm,
+ )
+
+ return load_demo_single(models, url_params)
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+    # Log the voted conversation to a local JSONL file in a simple chat format.
+    state_dict = state.dict()
+    conversations = []
+    for i, turn in enumerate(state_dict["messages"]):
+        role = "user" if i % 2 == 0 else "assistant"
+        conversations.append({"role": role, "content": turn[1]})
+    data = {
+        "conversations": conversations,
+        "idx": state_dict["conv_id"],
+        "tinder": "badcase",
+        "model": state_dict["model_name"],
+        "tokens_in": -1,
+        "tokens_out": -1,
+    }
+    with open("./web_chat_downvote.jsonl", "a+") as fout:
+        fout.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"upvote. ip: {ip}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"downvote. ip: {ip}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"flag. ip: {ip}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"regenerate. ip: {ip}")
+ state.conv.update_last_message(None)
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"clear_history. ip: {ip}")
+ state = None
+ return (state, [], "") + (disable_btn,) * 5
+
+
+def get_ip(request: gr.Request):
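+    """Prefer the Cloudflare "cf-connecting-ip" header when present; otherwise use the socket peer address."""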
+ if "cf-connecting-ip" in request.headers:
+ ip = request.headers["cf-connecting-ip"]
+ else:
+ ip = request.client.host
+ return ip
+
+
+def add_text(state, model_selector, text, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"add_text. ip: {ip}. len: {len(text)}")
+
+ if state is None:
+ state = State(model_selector)
+
+ if len(text) <= 0:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+
+ flagged = moderation_filter(text, [state.model_name])
+ if flagged:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+ # overwrite the original text
+ text = MODERATION_MSG
+
+ conv = state.conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
+ no_change_btn,
+ ) * 5
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ conv.append_message(conv.roles[0], text)
+ conv.append_message(conv.roles[1], None)
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def post_process_code(code):
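+    """Unescape "\\_" back to "_" inside fenced code blocks (a Markdown-escaping artifact)."""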
+ sep = "\n```"
+ if sep in code:
+ blocks = code.split(sep)
+ if len(blocks) % 2 == 1:
+ for i in range(1, len(blocks), 2):
+ blocks[i] = blocks[i].replace("\\_", "_")
+ code = sep.join(blocks)
+ return code
+
+
+def model_worker_stream_iter(
+ conv,
+ model_name,
+ worker_addr,
+ prompt,
+ temperature,
+ repetition_penalty,
+ top_p,
+ max_new_tokens,
+):
+ # Make requests
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": temperature,
+ "repetition_penalty": repetition_penalty,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+ logger.info(f"==== request ====\n{gen_params}")
+
+ # Stream output
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ stream=True,
+ timeout=WORKER_API_TIMEOUT,
+ )
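+    # The worker streams null-delimited JSON messages; decode each chunk as it arrives.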
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ yield data
+
+
+def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
+ ip = get_ip(request)
+ logger.info(f"bot_response. ip: {ip}")
+ start_tstamp = time.time()
+ temperature = float(temperature)
+ top_p = float(top_p)
+ max_new_tokens = int(max_new_tokens)
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ state.skip_next = False
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ conv, model_name = state.conv, state.model_name
+ if model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]:
+ prompt = conv.to_openai_api_messages()
+ stream_iter = openai_api_stream_iter(
+ model_name, prompt, temperature, top_p, max_new_tokens
+ )
+ elif model_name in ["claude-2", "claude-1", "claude-instant-1"]:
+ prompt = conv.get_prompt()
+ stream_iter = anthropic_api_stream_iter(
+ model_name, prompt, temperature, top_p, max_new_tokens
+ )
+ elif model_name == "palm-2":
+ stream_iter = palm_api_stream_iter(
+ state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
+ )
+ elif model_name in openai_compatible_models_info:
+ model_info = openai_compatible_models_info[model_name]
+ prompt = conv.to_openai_api_messages()
+ stream_iter = openai_api_stream_iter(
+ model_info["model_name"],
+ prompt,
+ temperature,
+ top_p,
+ max_new_tokens,
+ api_base=model_info["api_base"],
+ api_key=model_info["api_key"],
+ )
+ else:
+ # Query worker address
+ ret = requests.post(
+ controller_url + "/get_worker_address", json={"model": model_name}
+ )
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ conv.update_last_message(SERVER_ERROR_MSG)
+ yield (
+ state,
+ state.to_gradio_chatbot(),
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+
+ # Construct prompt.
+ # We need to call it here, so it will not be affected by "▌".
+ prompt = conv.get_prompt()
+ # Set repetition_penalty
+ if "t5" in model_name:
+ repetition_penalty = 1.2
+ else:
+ repetition_penalty = 1.0
+
+ stream_iter = model_worker_stream_iter(
+ conv,
+ model_name,
+ worker_addr,
+ prompt,
+ temperature,
+ repetition_penalty,
+ top_p,
+ max_new_tokens,
+ )
+
+ conv.update_last_message("▌")
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ for i, data in enumerate(stream_iter):
+ if data["error_code"] == 0:
+ output = data["text"].strip()
+ conv.update_last_message(output + "▌")
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f"\n\n(error_code: {data['error_code']})"
+ conv.update_last_message(output)
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+ output = data["text"].strip()
+ if "vicuna" in model_name:
+ output = post_process_code(output)
+ conv.update_last_message(output)
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+ except requests.exceptions.RequestException as e:
+ conv.update_last_message(
+ f"{SERVER_ERROR_MSG}\n\n"
+ f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
+ )
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+ except Exception as e:
+ conv.update_last_message(
+ f"{SERVER_ERROR_MSG}\n\n"
+ f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
+ )
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "gen_params": {
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ },
+ "start": round(start_tstamp, 4),
+ "finish": round(finish_tstamp, 4),
+ "state": state.dict(),
+ "ip": get_ip(request),
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+block_css = """
+#notice_markdown {
+ font-size: 110%
+}
+#notice_markdown th {
+ display: none;
+}
+#notice_markdown td {
+ padding-top: 6px;
+ padding-bottom: 6px;
+}
+#leaderboard_markdown {
+ font-size: 110%
+}
+#leaderboard_markdown td {
+ padding-top: 6px;
+ padding-bottom: 6px;
+}
+#leaderboard_dataframe td {
+ line-height: 0.1em;
+}
+#about_markdown {
+ font-size: 110%
+}
+#input_box textarea {
+}
+footer {
+ display:none !important
+}
+.image-container {
+ display: flex;
+ align-items: center;
+ padding: 1px;
+}
+.image-container img {
+ margin: 0 30px;
+ height: 20px;
+ max-height: 100%;
+ width: auto;
+ max-width: 20%;
+}
+.image-about img {
+ margin: 0 30px;
+ margin-top: 30px;
+ height: 60px;
+ max-height: 100%;
+ width: auto;
+ float: left;
+}
+"""
+
+
+def get_model_description_md(models):
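+    """Build a three-column Markdown table describing each unique model."""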
+ model_description_md = """
+| | | |
+| ---- | ---- | ---- |
+"""
+ ct = 0
+ visited = set()
+ for i, name in enumerate(models):
+ minfo = get_model_info(name)
+ if minfo.simple_name in visited:
+ continue
+ visited.add(minfo.simple_name)
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
+
+ if ct % 3 == 0:
+ model_description_md += "|"
+ model_description_md += f" {one_model_md} |"
+ if ct % 3 == 2:
+ model_description_md += "\n"
+ ct += 1
+ return model_description_md
+
+
+def build_about():
+ about_markdown = f"""
+# About Us
+Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our code at [GitHub](https://github.com/lm-sys/FastChat) and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
+
+## Read More
+- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
+- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
+
+## Core Members
+[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
+
+## Advisors
+[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
+
+## Contact Us
+- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com
+- File issues on [GitHub](https://github.com/lm-sys/FastChat)
+- Download our datasets and models on [HuggingFace](https://huggingface.co./lmsys)
+
+## Sponsors
+We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co./) for their generous sponsorship.
+Learn more about partnership [here](https://lmsys.org/donations/).
+
+
+"""
+
+ # state = gr.State()
+ gr.Markdown(about_markdown, elem_id="about_markdown")
+
+ # return [state]
+
+
+def build_single_model_ui(models, add_promotion_links=False):
+ promotion = (
+ """
+- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
+- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
+- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
+"""
+ if add_promotion_links
+ else ""
+ )
+
+ notice_markdown = f"""
+# 🏔️ Chat with Open Large Language Models
+{promotion}
+
+## 👉 Choose any model to chat
+"""
+
+ state = gr.State()
+ model_description_md = get_model_description_md(models)
+ gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown")
+
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=False,
+ container=False,
+ )
+
+ chatbot = gr.Chatbot(
+ elem_id="chatbot",
+ label="Scroll down and start chatting",
+ height=550,
+ )
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="Enter your prompt here and press ENTER",
+ container=False,
+ elem_id="input_box",
+ )
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", variant="primary")
+
+ with gr.Row() as button_row:
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=3072,
+ value=2048,
+ step=1,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ if add_promotion_links:
+ gr.Markdown(acknowledgment_md)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(
+ upvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ downvote_btn.click(
+ downvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ flag_btn.click(
+ flag_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+ bot_response,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+ model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+ textbox.submit(
+ add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
+ ).then(
+ bot_response,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ send_btn.click(
+ add_text,
+ [state, model_selector, textbox],
+ [state, chatbot, textbox] + btn_list,
+ ).then(
+ bot_response,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+
+ return [state, model_selector]
+
+
+def build_demo(models):
+ with gr.Blocks(
+ title="Chat with Open Large Language Models",
+ theme=gr.themes.Default(),
+ css=block_css,
+ ) as demo:
+ url_params = gr.JSON(visible=False)
+
+ state, model_selector = build_single_model_ui(models)
+
+ if args.model_list_mode not in ["once", "reload"]:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ if args.show_terms_of_use:
+ load_js = get_window_url_params_with_tos_js
+ else:
+ load_js = get_window_url_params_js
+
+ demo.load(
+ load_demo,
+ [url_params],
+ [
+ state,
+ model_selector,
+ ],
+ _js=load_js,
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument(
+ "--conv-template",
+ type=str,
+ default="megrez",
+ help="The address of the controller",
+ )
+ parser.add_argument(
+ "--share",
+ action="store_true",
+ help="Whether to generate a public, shareable link",
+ )
+ parser.add_argument(
+ "--controller-url",
+ type=str,
+ default="http://localhost:21001",
+ help="The address of the controller",
+ )
+ parser.add_argument(
+ "--concurrency-count",
+ type=int,
+ default=10,
+ help="The concurrency count of the gradio queue",
+ )
+ parser.add_argument(
+ "--model-list-mode",
+ type=str,
+ default="once",
+ choices=["once", "reload"],
+ help="Whether to load the model list once or reload the model list every time",
+ )
+ parser.add_argument(
+ "--moderate",
+ action="store_true",
+ help="Enable content moderation to block unsafe inputs",
+ )
+ parser.add_argument(
+ "--show-terms-of-use",
+ action="store_true",
+ help="Shows term of use before loading the demo",
+ )
+ parser.add_argument(
+ "--add-chatgpt",
+ action="store_true",
+ help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
+ )
+ parser.add_argument(
+ "--add-claude",
+ action="store_true",
+ help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
+ )
+ parser.add_argument(
+ "--add-palm",
+ action="store_true",
+ help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
+ )
+ parser.add_argument(
+ "--register-openai-compatible-models",
+ type=str,
+ help="Register custom OpenAI API compatible models by loading them from a JSON file",
+ )
+ parser.add_argument(
+ "--gradio-auth-path",
+ type=str,
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+ CONV_TEMPLATE = args.conv_template
+ # Set global variables
+ set_global_vars(args.controller_url, args.moderate)
+ models = get_model_list(
+ args.controller_url,
+ args.register_openai_compatible_models,
+ args.add_chatgpt,
+ args.add_claude,
+ args.add_palm,
+ )
+ # Set authorization credentials
+ auth = None
+ if args.gradio_auth_path is not None:
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
+
+ # Launch the demo
+ demo = build_demo(models)
+    demo.queue(
+        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
+        auth=auth,
+    )
diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py
new file mode 100644
index 0000000000000000000000000000000000000000..b918f9d6b65c91a2d0a8d1488e9290c207c39b30
--- /dev/null
+++ b/fastchat/serve/gradio_web_server_multi.py
@@ -0,0 +1,270 @@
+"""
+The gradio demo server with multiple tabs.
+It supports chatting with a single model or chatting with two models side-by-side.
+"""
+
+import argparse
+import pickle
+import time
+
+import gradio as gr
+
+from fastchat.constants import (
+ SESSION_EXPIRATION_TIME,
+)
+from fastchat.serve.gradio_block_arena_anony import (
+ build_side_by_side_ui_anony,
+ load_demo_side_by_side_anony,
+ set_global_vars_anony,
+)
+from fastchat.serve.gradio_block_arena_named import (
+ build_side_by_side_ui_named,
+ load_demo_side_by_side_named,
+ set_global_vars_named,
+)
+from fastchat.serve.gradio_web_server import (
+ set_global_vars,
+ block_css,
+ build_single_model_ui,
+ build_about,
+ get_model_list,
+ load_demo_single,
+ ip_expiration_dict,
+ get_ip,
+)
+from fastchat.serve.monitor.monitor import build_leaderboard_tab
+from fastchat.utils import (
+ build_logger,
+ get_window_url_params_js,
+ get_window_url_params_with_tos_js,
+ parse_gradio_auth_creds,
+)
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+
+def load_demo(url_params, request: gr.Request):
+ global models
+
+ ip = get_ip(request)
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
+ ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
+
+ selected = 0
+ if "arena" in url_params:
+ selected = 0
+ elif "compare" in url_params:
+ selected = 1
+ elif "single" in url_params:
+ selected = 2
+ elif "leaderboard" in url_params:
+ selected = 3
+
+ if args.model_list_mode == "reload":
+ if args.anony_only_for_proprietary_model:
+ models = get_model_list(
+ args.controller_url,
+ args.register_openai_compatible_models,
+ False,
+ False,
+ False,
+ )
+ else:
+ models = get_model_list(
+ args.controller_url,
+ args.register_openai_compatible_models,
+ args.add_chatgpt,
+ args.add_claude,
+ args.add_palm,
+ )
+
+ single_updates = load_demo_single(models, url_params)
+
+ models_anony = list(models)
+ if args.anony_only_for_proprietary_model:
+ # Only enable these models in anony battles.
+ if args.add_chatgpt:
+ models_anony += [
+ "gpt-4",
+ "gpt-3.5-turbo",
+ "gpt-4-turbo",
+ "gpt-3.5-turbo-1106",
+ ]
+ if args.add_claude:
+ models_anony += ["claude-2", "claude-1", "claude-instant-1"]
+ if args.add_palm:
+ models_anony += ["palm-2"]
+ models_anony = list(set(models_anony))
+
+ side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params)
+ side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
+ return (
+ (gr.Tabs.update(selected=selected),)
+ + single_updates
+ + side_by_side_anony_updates
+ + side_by_side_named_updates
+ )
+
+
+def build_demo(models, elo_results_file, leaderboard_table_file):
+ text_size = gr.themes.sizes.text_md
+ with gr.Blocks(
+ title="Chat with Open Large Language Models",
+ theme=gr.themes.Default(text_size=text_size),
+ css=block_css,
+ ) as demo:
+ with gr.Tabs() as tabs:
+ with gr.Tab("Arena (battle)", id=0):
+ side_by_side_anony_list = build_side_by_side_ui_anony(models)
+
+ with gr.Tab("Arena (side-by-side)", id=1):
+ side_by_side_named_list = build_side_by_side_ui_named(models)
+
+ with gr.Tab("Direct Chat", id=2):
+ single_model_list = build_single_model_ui(
+ models, add_promotion_links=True
+ )
+ if elo_results_file:
+ with gr.Tab("Leaderboard", id=3):
+ build_leaderboard_tab(elo_results_file, leaderboard_table_file)
+ with gr.Tab("About Us", id=4):
+ about = build_about()
+
+ url_params = gr.JSON(visible=False)
+
+ if args.model_list_mode not in ["once", "reload"]:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ if args.show_terms_of_use:
+ load_js = get_window_url_params_with_tos_js
+ else:
+ load_js = get_window_url_params_js
+
+ demo.load(
+ load_demo,
+ [url_params],
+ [tabs]
+ + single_model_list
+ + side_by_side_anony_list
+ + side_by_side_named_list,
+ _js=load_js,
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument(
+ "--share",
+ action="store_true",
+ help="Whether to generate a public, shareable link",
+ )
+ parser.add_argument(
+ "--controller-url",
+ type=str,
+ default="http://localhost:21001",
+ help="The address of the controller",
+ )
+ parser.add_argument(
+ "--concurrency-count",
+ type=int,
+ default=10,
+ help="The concurrency count of the gradio queue",
+ )
+ parser.add_argument(
+ "--model-list-mode",
+ type=str,
+ default="once",
+ choices=["once", "reload"],
+ help="Whether to load the model list once or reload the model list every time.",
+ )
+ parser.add_argument(
+ "--moderate",
+ action="store_true",
+ help="Enable content moderation to block unsafe inputs",
+ )
+ parser.add_argument(
+ "--show-terms-of-use",
+ action="store_true",
+ help="Shows term of use before loading the demo",
+ )
+ parser.add_argument(
+ "--add-chatgpt",
+ action="store_true",
+ help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
+ )
+ parser.add_argument(
+ "--add-claude",
+ action="store_true",
+ help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
+ )
+ parser.add_argument(
+ "--add-palm",
+ action="store_true",
+ help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
+ )
+ parser.add_argument(
+ "--anony-only-for-proprietary-model",
+ action="store_true",
+ help="Only add ChatGPT, Claude, Bard under anony battle tab",
+ )
+ parser.add_argument(
+ "--register-openai-compatible-models",
+ type=str,
+ help="Register custom OpenAI API compatible models by loading them from a JSON file",
+ )
+ parser.add_argument(
+ "--gradio-auth-path",
+ type=str,
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
+ default=None,
+ )
+ parser.add_argument(
+ "--elo-results-file", type=str, help="Load leaderboard results and plots"
+ )
+ parser.add_argument(
+ "--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ # Set global variables
+ set_global_vars(args.controller_url, args.moderate)
+ set_global_vars_named(args.moderate)
+ set_global_vars_anony(args.moderate)
+ if args.anony_only_for_proprietary_model:
+ models = get_model_list(
+ args.controller_url,
+ args.register_openai_compatible_models,
+ False,
+ False,
+ False,
+ )
+ else:
+ models = get_model_list(
+ args.controller_url,
+ args.register_openai_compatible_models,
+ args.add_chatgpt,
+ args.add_claude,
+ args.add_palm,
+ )
+
+ # Set authorization credentials
+ auth = None
+ if args.gradio_auth_path is not None:
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
+
+ # Launch the demo
+ demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+ server_name=args.host,
+ server_port=args.port,
+ share=args.share,
+ max_threads=200,
+ auth=auth,
+ )
diff --git a/fastchat/serve/huggingface_api.py b/fastchat/serve/huggingface_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a49bf5f175995d20c339906a84c71b97a49fae6
--- /dev/null
+++ b/fastchat/serve/huggingface_api.py
@@ -0,0 +1,73 @@
+"""
+Use FastChat with Hugging Face generation APIs.
+
+Usage:
+python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
+python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
+"""
+import argparse
+
+import torch
+
+from fastchat.model import load_model, get_conversation_template, add_model_args
+
+
+@torch.inference_mode()
+def main(args):
+ # Load model
+ model, tokenizer = load_model(
+ args.model_path,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ revision=args.revision,
+ debug=args.debug,
+ )
+
+ # Build the prompt with a conversation template
+ msg = args.message
+ conv = get_conversation_template(args.model_path)
+ conv.append_message(conv.roles[0], msg)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ # Run inference
+ inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
+ output_ids = model.generate(
+ **inputs,
+ do_sample=True if args.temperature > 1e-5 else False,
+ temperature=args.temperature,
+ repetition_penalty=args.repetition_penalty,
+ max_new_tokens=args.max_new_tokens,
+ )
+
+ if model.config.is_encoder_decoder:
+ output_ids = output_ids[0]
+ else:
+ output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
+ outputs = tokenizer.decode(
+ output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+ )
+
+ # Print results
+ print(f"{conv.roles[0]}: {msg}")
+ print(f"{conv.roles[1]}: {outputs}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ add_model_args(parser)
+ parser.add_argument("--temperature", type=float, default=0.7)
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--message", type=str, default="Hello! Who are you?")
+ args = parser.parse_args()
+
+ # Reset default repetition penalty for T5 models.
+ if "t5" in args.model_path and args.repetition_penalty == 1.0:
+ args.repetition_penalty = 1.2
+
+ main(args)
diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eef50e472ea739c17fac7ff586f775f8a9d1f4f
--- /dev/null
+++ b/fastchat/serve/huggingface_api_worker.py
@@ -0,0 +1,391 @@
+"""
+A model worker that calls huggingface inference endpoint.
+
+Register models in a JSON file with the following format:
+{
+ "falcon-180b-chat": {
+ "model_path": "tiiuae/falcon-180B-chat",
+ "api_base": "https://api-inference.huggingface.co/models",
+ "token": "hf_xxx",
+ "context_length": 2048,
+ "model_names": "falcon-180b-chat",
+ "conv_template": null
+ }
+}
+
+"model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
+"""
+import argparse
+import asyncio
+import json
+import uuid
+from typing import List, Optional
+
+import requests
+import uvicorn
+from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+from huggingface_hub import InferenceClient
+
+from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
+from fastchat.serve.base_model_worker import BaseModelWorker
+from fastchat.utils import build_logger
+
+worker_id = str(uuid.uuid4())[:8]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+
+workers = []
+worker_map = {}
+app = FastAPI()
+
+
+# reference to
+# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
+def get_gen_kwargs(
+ params,
+ seed: Optional[int] = None,
+):
+ stop = params.get("stop", None)
+ if isinstance(stop, list):
+ stop_sequences = stop
+ elif isinstance(stop, str):
+ stop_sequences = [stop]
+ else:
+ stop_sequences = []
+ gen_kwargs = {
+ "do_sample": True,
+ "return_full_text": bool(params.get("echo", False)),
+ "max_new_tokens": int(params.get("max_new_tokens", 256)),
+ "top_p": float(params.get("top_p", 1.0)),
+ "temperature": float(params.get("temperature", 1.0)),
+ "stop_sequences": stop_sequences,
+ "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
+ "top_k": params.get("top_k", None),
+ "seed": seed,
+ }
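+    # Hosted inference endpoints typically reject top_p == 1.0 and temperature == 0.0, so adjust them here.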
+ if gen_kwargs["top_p"] == 1:
+ gen_kwargs["top_p"] = 0.9999999
+ if gen_kwargs["top_p"] == 0:
+ gen_kwargs.pop("top_p")
+ if gen_kwargs["temperature"] == 0:
+ gen_kwargs.pop("temperature")
+ gen_kwargs["do_sample"] = False
+ return gen_kwargs
+
+
+def could_be_stop(text, stop):
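+    """Return True if the text ends with a prefix of any stop sequence (so more tokens should be buffered)."""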
+ for s in stop:
+ if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
+ return True
+ return False
+
+
+class HuggingfaceApiWorker(BaseModelWorker):
+ def __init__(
+ self,
+ controller_addr: str,
+ worker_addr: str,
+ worker_id: str,
+ model_path: str,
+ api_base: str,
+ token: str,
+ context_length: int,
+ model_names: List[str],
+ limit_worker_concurrency: int,
+ no_register: bool,
+ conv_template: Optional[str] = None,
+ seed: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ controller_addr,
+ worker_addr,
+ worker_id,
+ model_path,
+ model_names,
+ limit_worker_concurrency,
+ conv_template=conv_template,
+ )
+
+ self.model_path = model_path
+ self.api_base = api_base
+ self.token = token
+ self.context_len = context_length
+ self.seed = seed
+
+ logger.info(
+ f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
+ )
+
+ if not no_register:
+ self.init_heart_beat()
+
+ def count_token(self, params):
+ # No tokenizer here
+ ret = {
+ "count": 0,
+ "error_code": 0,
+ }
+ return ret
+
+ def generate_stream_gate(self, params):
+ self.call_ct += 1
+
+ prompt = params["prompt"]
+ gen_kwargs = get_gen_kwargs(params, seed=self.seed)
+ stop = gen_kwargs["stop_sequences"]
+ if "falcon" in self.model_path and "chat" in self.model_path:
+ stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
+ stop = list(set(stop))
+ gen_kwargs["stop_sequences"] = stop
+
+ logger.info(f"prompt: {prompt}")
+ logger.info(f"gen_kwargs: {gen_kwargs}")
+
+ try:
+ if self.model_path == "":
+ url = f"{self.api_base}"
+ else:
+ url = f"{self.api_base}/{self.model_path}"
+ client = InferenceClient(url, token=self.token)
+ res = client.text_generation(
+ prompt, stream=True, details=True, **gen_kwargs
+ )
+
+ reason = None
+ text = ""
+ for chunk in res:
+ if chunk.token.special:
+ continue
+ text += chunk.token.text
+
+ s = next((x for x in stop if text.endswith(x)), None)
+ if s is not None:
+ text = text[: -len(s)]
+ reason = "stop"
+ break
+ if could_be_stop(text, stop):
+ continue
+ if (
+ chunk.details is not None
+ and chunk.details.finish_reason is not None
+ ):
+ reason = chunk.details.finish_reason
+ if reason not in ["stop", "length"]:
+ reason = None
+ ret = {
+ "text": text,
+ "error_code": 0,
+ "finish_reason": reason,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ def generate_gate(self, params):
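+        # Drain the streaming generator and return only the final, complete message.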
+ for x in self.generate_stream_gate(params):
+ pass
+ return json.loads(x[:-1].decode())
+
+ def get_embeddings(self, params):
+ raise NotImplementedError()
+
+
+def release_worker_semaphore(worker):
+ worker.semaphore.release()
+
+
+def acquire_worker_semaphore(worker):
+ if worker.semaphore is None:
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
+ return worker.semaphore.acquire()
+
+
+def create_background_tasks(worker):
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(lambda: release_worker_semaphore(worker))
+ return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ await acquire_worker_semaphore(worker)
+ generator = worker.generate_stream_gate(params)
+ background_tasks = create_background_tasks(worker)
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ await acquire_worker_semaphore(worker)
+ output = worker.generate_gate(params)
+ release_worker_semaphore(worker)
+ return JSONResponse(output)
+
+
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ await acquire_worker_semaphore(worker)
+ embedding = worker.get_embeddings(params)
+ release_worker_semaphore(worker)
+ return JSONResponse(content=embedding)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+ return {
+ "model_names": [m for w in workers for m in w.model_names],
+ "speed": 1,
+ "queue_length": sum([w.get_queue_length() for w in workers]),
+ }
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return {"context_length": worker.context_len}
+
+
+def create_huggingface_api_worker():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ # all model-related parameters are listed in --model-info-file
+ parser.add_argument(
+ "--model-info-file",
+ type=str,
+ required=True,
+ help="Huggingface API model's info file path",
+ )
+
+ parser.add_argument(
+ "--limit-worker-concurrency",
+ type=int,
+ default=5,
+ help="Limit the model concurrency to prevent OOM.",
+ )
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Overwrite the random seed for each generation.",
+ )
+ args = parser.parse_args()
+
+ with open(args.model_info_file, "r", encoding="UTF-8") as f:
+ model_info = json.load(f)
+
+ logger.info(f"args: {args}")
+
+ model_path_list = []
+ api_base_list = []
+ token_list = []
+ context_length_list = []
+ model_names_list = []
+ conv_template_list = []
+
+ for m in model_info:
+ model_path_list.append(model_info[m]["model_path"])
+ api_base_list.append(model_info[m]["api_base"])
+ token_list.append(model_info[m]["token"])
+
+ context_length = model_info[m]["context_length"]
+ model_names = model_info[m].get("model_names", [m.split("/")[-1]])
+ if isinstance(model_names, str):
+ model_names = [model_names]
+ conv_template = model_info[m].get("conv_template", None)
+
+ context_length_list.append(context_length)
+ model_names_list.append(model_names)
+ conv_template_list.append(conv_template)
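+
+ # Illustrative sketch of the expected --model-info-file layout (keys inferred from
+ # the fields read above; the repo id, endpoint, and token below are placeholders):
+ # {
+ #   "HuggingFaceH4/zephyr-7b-beta": {
+ #     "model_path": "HuggingFaceH4/zephyr-7b-beta",
+ #     "api_base": "https://api-inference.huggingface.co/models",
+ #     "token": "hf_xxx",
+ #     "context_length": 4096,
+ #     "model_names": "zephyr-7b-beta",
+ #     "conv_template": null
+ #   }
+ # }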
+
+ logger.info(f"Model paths: {model_path_list}")
+ logger.info(f"API bases: {api_base_list}")
+ logger.info(f"Tokens: {token_list}")
+ logger.info(f"Context lengths: {context_length_list}")
+ logger.info(f"Model names: {model_names_list}")
+ logger.info(f"Conv templates: {conv_template_list}")
+
+ for (
+ model_names,
+ conv_template,
+ model_path,
+ api_base,
+ token,
+ context_length,
+ ) in zip(
+ model_names_list,
+ conv_template_list,
+ model_path_list,
+ api_base_list,
+ token_list,
+ context_length_list,
+ ):
+ m = HuggingfaceApiWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ model_path,
+ api_base,
+ token,
+ context_length,
+ model_names,
+ args.limit_worker_concurrency,
+ no_register=args.no_register,
+ conv_template=conv_template,
+ seed=args.seed,
+ )
+ workers.append(m)
+ for name in model_names:
+ worker_map[name] = m
+
+ # register all the models
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": workers[0].worker_addr,
+ "check_heart_beat": not args.no_register,
+ "worker_status": {
+ "model_names": [m for w in workers for m in w.model_names],
+ "speed": 1,
+ "queue_length": sum([w.get_queue_length() for w in workers]),
+ },
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ return args, workers
+
+
+if __name__ == "__main__":
+ args, workers = create_huggingface_api_worker()
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..888a13d00c0f8a723cf417035dcf9d0ab65e0076
--- /dev/null
+++ b/fastchat/serve/inference.py
@@ -0,0 +1,596 @@
+"""Inference for FastChat models."""
+import abc
+import gc
+import json
+import math
+import os
+import sys
+import time
+from typing import Iterable, Optional, Dict
+import warnings
+
+import psutil
+import torch
+from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LlamaTokenizer,
+ LlamaForCausalLM,
+ AutoModel,
+ AutoModelForSeq2SeqLM,
+ T5Tokenizer,
+ AutoConfig,
+)
+from transformers.generation.logits_process import (
+ LogitsProcessorList,
+ RepetitionPenaltyLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+)
+
+from fastchat.conversation import get_conv_template, SeparatorStyle
+from fastchat.model.model_adapter import (
+ load_model,
+ get_conversation_template,
+ get_generate_stream_function,
+)
+from fastchat.modules.awq import AWQConfig
+from fastchat.modules.gptq import GptqConfig
+from fastchat.modules.exllama import ExllamaConfig
+from fastchat.modules.xfastertransformer import XftConfig
+from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
+
+
+def prepare_logits_processor(
+ temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+ processor_list = LogitsProcessorList()
+ # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op, so we skip both cases.
+ if temperature >= 1e-5 and temperature != 1.0:
+ processor_list.append(TemperatureLogitsWarper(temperature))
+ if repetition_penalty > 1.0:
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+ if 1e-8 <= top_p < 1.0:
+ processor_list.append(TopPLogitsWarper(top_p))
+ if top_k > 0:
+ processor_list.append(TopKLogitsWarper(top_k))
+ return processor_list
+
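+# Minimal usage sketch (illustrative only): with repetition_penalty == 1.0 no
+# RepetitionPenaltyLogitsProcessor is added, so the input_ids argument may be None,
+# mirroring how generate_stream calls the processor list below:
+#   processors = prepare_logits_processor(0.7, 1.0, 0.9, 50)
+#   last_token_logits = processors(None, logits[:, -1, :])[0]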
+
+@torch.inference_mode()
+def generate_stream(
+ model,
+ tokenizer,
+ params: Dict,
+ device: str,
+ context_len: int,
+ stream_interval: int = 2,
+ judge_sent_end: bool = False,
+):
+ if hasattr(model, "device"):
+ device = model.device
+
+ # Read parameters
+ prompt = params["prompt"]
+ len_prompt = len(prompt)
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", -1)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+ logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
+ echo = bool(params.get("echo", True))
+ stop_str = params.get("stop", None)
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ if tokenizer.eos_token_id not in stop_token_ids:
+ stop_token_ids.append(tokenizer.eos_token_id)
+ if params.get("none_stop"):
+ stop_token_ids = []
+ skip_special_tokens = params.get("skip_special_tokens")
+
+ logits_processor = prepare_logits_processor(
+ temperature, repetition_penalty, top_p, top_k
+ )
+ input_ids = tokenizer(prompt).input_ids
+
+ if model.config.is_encoder_decoder:
+ max_src_len = context_len
+ else: # truncate
+ max_src_len = context_len - max_new_tokens - 1
+
+ input_ids = input_ids[-max_src_len:]
+ output_ids = list(input_ids)
+ input_echo_len = len(input_ids)
+
+ if model.config.is_encoder_decoder:
+ if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
+ raise NotImplementedError
+ encoder_output = model.encoder(
+ input_ids=torch.as_tensor([input_ids], device=device)
+ )[0]
+ start_ids = torch.as_tensor(
+ [[model.generation_config.decoder_start_token_id]],
+ dtype=torch.int64,
+ device=device,
+ )
+ else:
+ start_ids = torch.as_tensor([input_ids], device=device)
+
+ past_key_values = out = None
+ token_logprobs = [None] # The first token has no logprobs.
+ sent_interrupt = False
+ finish_reason = None
+ for i in range(max_new_tokens):
+ if i == 0: # prefill
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=start_ids,
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ )
+ logits = model.lm_head(out[0])
+ else:
+ out = model(input_ids=start_ids, use_cache=True)
+ logits = out.logits
+ past_key_values = out.past_key_values
+
+ if logprobs is not None:
+ # Prefill logprobs for the prompt.
+ shift_input_ids = start_ids[..., 1:].contiguous()
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
+ for label_id, logit in zip(
+ shift_input_ids[0].tolist(), shift_logits[0]
+ ):
+ token_logprobs.append(logit[label_id])
+ else: # decoding
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=torch.as_tensor(
+ [[token] if not sent_interrupt else output_ids],
+ device=device,
+ ),
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ past_key_values=past_key_values if not sent_interrupt else None,
+ )
+ sent_interrupt = False
+
+ logits = model.lm_head(out[0])
+ else:
+ out = model(
+ input_ids=torch.as_tensor(
+ [[token] if not sent_interrupt else output_ids],
+ device=device,
+ ),
+ use_cache=True,
+ past_key_values=past_key_values if not sent_interrupt else None,
+ )
+ sent_interrupt = False
+ logits = out.logits
+ past_key_values = out.past_key_values
+
+ if logits_processor:
+ if repetition_penalty > 1.0:
+ tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
+ else:
+ tmp_output_ids = None
+ last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
+ else:
+ last_token_logits = logits[0, -1, :]
+
+ if device == "mps":
+ # Switch to CPU to avoid some bugs in the mps backend.
+ last_token_logits = last_token_logits.float().to("cpu")
+
+ if temperature < 1e-5 or top_p < 1e-8: # greedy
+ _, indices = torch.topk(last_token_logits, 2)
+ tokens = [int(index) for index in indices.tolist()]
+ else:
+ probs = torch.softmax(last_token_logits, dim=-1)
+ indices = torch.multinomial(probs, num_samples=2)
+ tokens = [int(token) for token in indices.tolist()]
+ token = tokens[0]
+ output_ids.append(token)
+ if logprobs is not None:
+ # Cannot use last_token_logits because logprobs is based on raw logits.
+ token_logprobs.append(
+ torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
+ )
+
+ if token in stop_token_ids:
+ stopped = True
+ else:
+ stopped = False
+
+ # Yield the output tokens
+ if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+ if echo:
+ tmp_output_ids = output_ids
+ rfind_start = len_prompt
+ else:
+ tmp_output_ids = output_ids[input_echo_len:]
+ rfind_start = 0
+
+ output = tokenizer.decode(
+ tmp_output_ids,
+ skip_special_tokens=skip_special_tokens,
+ spaces_between_special_tokens=False,
+ clean_up_tokenization_spaces=True,
+ )
+ ret_logprobs = None
+ if logprobs is not None:
+ ret_logprobs = {
+ "text_offset": [],
+ "tokens": [
+ tokenizer.decode(token)
+ for token in (
+ output_ids if echo else output_ids[input_echo_len:]
+ )
+ ],
+ "token_logprobs": token_logprobs
+ if echo
+ else token_logprobs[input_echo_len:],
+ "top_logprobs": [{}]
+ * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
+ }
+ # Compute text_offset
+ curr_pos = 0
+ for text in ret_logprobs["tokens"]:
+ ret_logprobs["text_offset"].append(curr_pos)
+ curr_pos += len(text)
+
+ # TODO: Patch for output being cut off mid-sentence when a stop token fires; this could be handled more elegantly.
+ if judge_sent_end and stopped and not is_sentence_complete(output):
+ if len(tokens) > 1:
+ token = tokens[1]
+ output_ids[-1] = token
+ else:
+ output_ids.pop()
+ stopped = False
+ sent_interrupt = True
+
+ partially_stopped = False
+ if stop_str:
+ if isinstance(stop_str, str):
+ pos = output.rfind(stop_str, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ else:
+ partially_stopped = is_partial_stop(output, stop_str)
+ elif isinstance(stop_str, Iterable):
+ for each_stop in stop_str:
+ pos = output.rfind(each_stop, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ break
+ else:
+ partially_stopped = is_partial_stop(output, each_stop)
+ if partially_stopped:
+ break
+ else:
+ raise ValueError("Invalid stop field type.")
+
+ # Prevent yielding partial stop sequence
+ if not partially_stopped:
+ yield {
+ "text": output,
+ "logprobs": ret_logprobs,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+
+ if stopped:
+ break
+
+ # Finish stream event, which contains finish reason
+ else:
+ finish_reason = "length"
+
+ if stopped:
+ finish_reason = "stop"
+
+ yield {
+ "text": output,
+ "logprobs": ret_logprobs,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+
+ # Clean
+ del past_key_values, out
+ gc.collect()
+ torch.cuda.empty_cache()
+ if device == "xpu":
+ torch.xpu.empty_cache()
+ if device == "npu":
+ torch.npu.empty_cache()
+
+
+class ChatIO(abc.ABC):
+ @abc.abstractmethod
+ def prompt_for_input(self, role: str) -> str:
+ """Prompt for input from a role."""
+
+ @abc.abstractmethod
+ def prompt_for_output(self, role: str):
+ """Prompt for output from a role."""
+
+ @abc.abstractmethod
+ def stream_output(self, output_stream):
+ """Stream output."""
+
+ @abc.abstractmethod
+ def print_output(self, text: str):
+ """Print output."""
+
+
+def convert_message_format(message):
+ formatted_message = []
+ for i, turn in enumerate(message):
+ role = "user" if i % 2 == 0 else "assistant"
+ formatted_message.append({"role": role, "content": turn[1]})
+
+ data = {
+ "conversations": formatted_message,
+ "idx": -1,
+ "tinder": "badcase",
+ "model": "",
+ "tokens_in": 0,
+ "tokens_out": 0,
+ }
+
+ return data
+
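+# Illustrative example: conv.messages such as [("USER", "hi"), ("ASSISTANT", "hello")]
+# are re-labeled purely by position, producing
+#   [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
+# inside the "conversations" field; the stored role names themselves are ignored.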
+
+def chat_loop(
+ model_path: str,
+ device: str,
+ num_gpus: int,
+ max_gpu_memory: str,
+ dtype: Optional[torch.dtype],
+ load_8bit: bool,
+ cpu_offloading: bool,
+ conv_template: Optional[str],
+ conv_system_msg: Optional[str],
+ temperature: float,
+ repetition_penalty: float,
+ max_new_tokens: int,
+ chatio: ChatIO,
+ gptq_config: Optional[GptqConfig] = None,
+ awq_config: Optional[AWQConfig] = None,
+ exllama_config: Optional[ExllamaConfig] = None,
+ xft_config: Optional[XftConfig] = None,
+ revision: str = "main",
+ judge_sent_end: bool = True,
+ debug: bool = True,
+ history: bool = True,
+):
+ # Model
+ model, tokenizer = load_model(
+ model_path,
+ device=device,
+ num_gpus=num_gpus,
+ max_gpu_memory=max_gpu_memory,
+ dtype=dtype,
+ load_8bit=load_8bit,
+ cpu_offloading=cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ exllama_config=exllama_config,
+ xft_config=xft_config,
+ revision=revision,
+ debug=debug,
+ )
+ generate_stream_func = get_generate_stream_function(model, model_path)
+
+ model_type = str(type(model)).lower()
+ is_t5 = "t5" in model_type
+ is_codet5p = "codet5p" in model_type
+ is_xft = "xft" in model_type
+
+ # Hardcode T5's default repetition penalty to be 1.2
+ if is_t5 and repetition_penalty == 1.0:
+ repetition_penalty = 1.2
+
+ # Set context length
+ context_len = get_context_length(model.config)
+
+ # Chat
+ def new_chat():
+ if conv_template:
+ conv = get_conv_template(conv_template)
+ else:
+ conv = get_conversation_template(model_path)
+ if conv_system_msg is not None:
+ conv.set_system_message(conv_system_msg)
+ return conv
+
+ def reload_conv(conv):
+ """
+ Reprints the conversation from the start.
+ """
+ for message in conv.messages[conv.offset :]:
+ chatio.prompt_for_output(message[0])
+ chatio.print_output(message[1])
+
+ conv = None
+
+ while True:
+ if not history or not conv:
+ conv = new_chat()
+
+ try:
+ inp = chatio.prompt_for_input(conv.roles[0])
+ except EOFError:
+ inp = ""
+
+ if inp == "!!exit":# or not inp:
+ print("exit...")
+ break
+ elif inp == "!!reset":
+ print("resetting...")
+ conv = new_chat()
+ continue
+ elif inp == "!!remove":
+ print("removing last message...")
+ if len(conv.messages) > conv.offset:
+ # Assistant
+ if conv.messages[-1][0] == conv.roles[1]:
+ conv.messages.pop()
+ # User
+ if conv.messages[-1][0] == conv.roles[0]:
+ conv.messages.pop()
+ reload_conv(conv)
+ else:
+ print("No messages to remove.")
+ continue
+ elif inp == "!!regen":
+ print("regenerating last message...")
+ if len(conv.messages) > conv.offset:
+ # Assistant
+ if conv.messages[-1][0] == conv.roles[1]:
+ conv.messages.pop()
+ # User
+ if conv.messages[-1][0] == conv.roles[0]:
+ reload_conv(conv)
+ # Set inp to previous message
+ inp = conv.messages.pop()[1]
+ else:
+ # Shouldn't happen in normal circumstances
+ print("No user message to regenerate from.")
+ continue
+ else:
+ print("No messages to regenerate.")
+ continue
+ elif inp.startswith("!!save"):
+ args = inp.split(" ", 1)
+
+ if len(args) != 2:
+ print("usage: !!save ")
+ continue
+ else:
+ filename = args[1]
+
+ # Add .json if extension not present
+ if not "." in filename:
+ filename += ".json"
+
+ print("saving...", filename)
+ with open(filename, "w", encoding="utf-8") as outfile:
+ json.dump(conv.dict(), outfile, ensure_ascii=False)
+ continue
+ elif inp.startswith("!!badcase"):
+ args = inp.split(" ", 1)
+
+ if len(args) != 2:
+ print("usage: !!save ")
+ continue
+ else:
+ filename = args[1]
+
+ # Add .json if extension not present
+ if not "." in filename:
+ filename += ".jsonl"
+
+ print("saving...", filename)
+ with open(filename, "a+", encoding="utf-8") as outfile:
+ data = convert_message_format(conv.messages)
+ json.dump(data, outfile, ensure_ascii=False)
+ outfile.write('\n')
+ continue
+ elif inp.startswith("!!load"):
+ args = inp.split(" ", 1)
+
+ if len(args) != 2:
+ print("usage: !!load ")
+ continue
+ else:
+ filename = args[1]
+
+ # Check if file exists and add .json if needed
+ if not os.path.exists(filename):
+ if (not filename.endswith(".json")) and os.path.exists(
+ filename + ".json"
+ ):
+ filename += ".json"
+ else:
+ print("file not found:", filename)
+ continue
+
+ print("loading...", filename)
+ with open(filename, "r") as infile:
+ new_conv = json.load(infile)
+
+ conv = get_conv_template(new_conv["template_name"])
+ conv.set_system_message(new_conv["system_message"])
+ conv.messages = new_conv["messages"]
+ reload_conv(conv)
+ continue
+
+ conv.append_message(conv.roles[0], inp)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt(tokenizer)
+
+ if is_codet5p: # codet5p is a code completion model.
+ prompt = inp
+
+ gen_params = {
+ "model": model_path,
+ "prompt": prompt,
+ "temperature": temperature,
+ "repetition_penalty": repetition_penalty,
+ "max_new_tokens": max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "none_stop": conv.none_stop,
+ "skip_special_tokens": conv.skip_special_tokens,
+ "echo": False,
+ }
+
+ try:
+ chatio.prompt_for_output(conv.roles[1])
+ output_stream = generate_stream_func(
+ model,
+ tokenizer,
+ gen_params,
+ device,
+ context_len=context_len,
+ judge_sent_end=judge_sent_end,
+ )
+ t = time.time()
+ outputs = chatio.stream_output(output_stream)
+ duration = time.time() - t
+ conv.update_last_message(outputs.strip())
+
+ if debug:
+ num_tokens = len(tokenizer.encode(outputs))
+ msg = {
+ "conv_template": conv.name,
+ "prompt": prompt,
+ "outputs": outputs,
+ "speed (token/s)": round(num_tokens / duration, 2),
+ }
+ print(f"\n{msg}\n")
+
+ except KeyboardInterrupt:
+ print("stopped generation.")
+ # If generation didn't finish
+ if conv.messages[-1][1] is None:
+ conv.messages.pop()
+ # Remove last user message, so there isn't a double up
+ if conv.messages[-1][0] == conv.roles[0]:
+ conv.messages.pop()
+
+ reload_conv(conv)
diff --git a/fastchat/serve/launch_all_serve.py b/fastchat/serve/launch_all_serve.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f4ad7b0b134d1699ff8ba0d95d8039ec3c1f204
--- /dev/null
+++ b/fastchat/serve/launch_all_serve.py
@@ -0,0 +1,284 @@
+"""
+Usage: python launch_all_serve.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
+
+Workers are listed in the format `model-path`@`host`@`port`.
+
+The key mechanism behind this script is:
+ 1. execute shell commands to launch the controller/worker/openai-api-server;
+ 2. check the logs of the controller/worker/openai-api-server to ensure each service launched properly.
+Note that a few non-critical `fastchat.serve` command-line options are not supported currently.
+"""
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+import subprocess
+import re
+import argparse
+
+LOGDIR = "./logs/"
+
+if not os.path.exists(LOGDIR):
+ os.makedirs(LOGDIR)
+
+parser = argparse.ArgumentParser()
+# ------multi worker-----------------
+parser.add_argument(
+ "--model-path-address",
+ default="THUDM/chatglm2-6b@localhost@20002",
+ nargs="+",
+ type=str,
+ help="model path, host, and port, formatted as model-path@host@port",
+)
+# ---------------controller-------------------------
+
+parser.add_argument("--controller-host", type=str, default="localhost")
+parser.add_argument("--controller-port", type=int, default=21001)
+parser.add_argument(
+ "--dispatch-method",
+ type=str,
+ choices=["lottery", "shortest_queue"],
+ default="shortest_queue",
+)
+controller_args = ["controller-host", "controller-port", "dispatch-method"]
+
+# ----------------------worker------------------------------------------
+
+parser.add_argument("--worker-host", type=str, default="localhost")
+parser.add_argument("--worker-port", type=int, default=21002)
+# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+# parser.add_argument(
+# "--controller-address", type=str, default="http://localhost:21001"
+# )
+parser.add_argument(
+ "--model-path",
+ type=str,
+ default="lmsys/vicuna-7b-v1.5",
+ help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
+)
+parser.add_argument(
+ "--revision",
+ type=str,
+ default="main",
+ help="Hugging Face Hub model revision identifier",
+)
+parser.add_argument(
+ "--device",
+ type=str,
+ choices=["cpu", "cuda", "mps", "xpu", "npu"],
+ default="cuda",
+ help="The device type",
+)
+parser.add_argument(
+ "--gpus",
+ type=str,
+ default="0",
+ help="A single GPU like 1 or multiple GPUs like 0,2",
+)
+parser.add_argument("--num-gpus", type=int, default=1)
+parser.add_argument(
+ "--max-gpu-memory",
+ type=str,
+ help="The maximum memory per gpu. Use a string like '13Gib'",
+)
+parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
+parser.add_argument(
+ "--cpu-offloading",
+ action="store_true",
+ help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
+)
+parser.add_argument(
+ "--gptq-ckpt",
+ type=str,
+ default=None,
+ help="Load quantized model. The path to the local GPTQ checkpoint.",
+)
+parser.add_argument(
+ "--gptq-wbits",
+ type=int,
+ default=16,
+ choices=[2, 3, 4, 8, 16],
+ help="#bits to use for quantization",
+)
+parser.add_argument(
+ "--gptq-groupsize",
+ type=int,
+ default=-1,
+ help="Groupsize to use for quantization; default uses full row.",
+)
+parser.add_argument(
+ "--gptq-act-order",
+ action="store_true",
+ help="Whether to apply the activation order GPTQ heuristic",
+)
+parser.add_argument(
+ "--model-names",
+ type=lambda s: s.split(","),
+ help="Optional display comma separated names",
+)
+parser.add_argument(
+ "--limit-worker-concurrency",
+ type=int,
+ default=5,
+ help="Limit the model concurrency to prevent OOM.",
+)
+parser.add_argument("--stream-interval", type=int, default=2)
+parser.add_argument("--no-register", action="store_true")
+
+worker_args = [
+ "worker-host",
+ "worker-port",
+ "model-path",
+ "revision",
+ "device",
+ "gpus",
+ "num-gpus",
+ "max-gpu-memory",
+ "load-8bit",
+ "cpu-offloading",
+ "gptq-ckpt",
+ "gptq-wbits",
+ "gptq-groupsize",
+ "gptq-act-order",
+ "model-names",
+ "limit-worker-concurrency",
+ "stream-interval",
+ "no-register",
+ "controller-address",
+]
+# -----------------openai server---------------------------
+
+parser.add_argument("--server-host", type=str, default="localhost", help="host name")
+parser.add_argument("--server-port", type=int, default=8001, help="port number")
+parser.add_argument(
+ "--allow-credentials", action="store_true", help="allow credentials"
+)
+# parser.add_argument(
+# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
+# )
+# parser.add_argument(
+# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
+# )
+# parser.add_argument(
+# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
+# )
+parser.add_argument(
+ "--api-keys",
+ type=lambda s: s.split(","),
+ help="Optional list of comma separated API keys",
+)
+server_args = [
+ "server-host",
+ "server-port",
+ "allow-credentials",
+ "api-keys",
+ "controller-address",
+]
+
+args = parser.parse_args()
+
+args = argparse.Namespace(
+ **vars(args),
+ **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
+)
+
+if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+# Placeholders in base_launch_sh:
+# 0: controller, model_worker, or openai_api_server
+# 1: cmd options
+# 2: LOGDIR
+# 3: log file name
+base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
+
+# Placeholders in base_check_sh:
+# 0: LOGDIR
+# 1: log file name
+# 2: controller, worker, or openai_api_server
+base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
+ sleep 1s;
+ echo "wait {2} running"
+ done
+ echo '{2} running' """
+
+
+def string_args(args, args_list):
+ args_str = ""
+ for key, value in args._get_kwargs():
+ key = key.replace("_", "-")
+ if key not in args_list:
+ continue
+
+ key = key.split("-")[-1] if re.search("port|host", key) else key
+ if not value:
+ pass
+ # a True bool becomes a bare flag (e.g. store_true options)
+ elif isinstance(value, bool) and value is True:
+ args_str += f" --{key} "
+ elif (
+ isinstance(value, list)
+ or isinstance(value, tuple)
+ or isinstance(value, set)
+ ):
+ value = " ".join(value)
+ args_str += f" --{key} {value} "
+ else:
+ args_str += f" --{key} {value} "
+
+ return args_str
+
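+# Illustrative only: with the defaults above, string_args(args, controller_args)
+# yields roughly " --host localhost  --port 21001  --dispatch-method shortest_queue ",
+# so the controller launch command expands to something like
+#   nohup python3 -m fastchat.serve.controller --host localhost --port 21001 \
+#       --dispatch-method shortest_queue >./logs//controller.log 2>&1 &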
+
+def launch_worker(item):
+ log_name = (
+ item.split("/")[-1]
+ .split("\\")[-1]
+ .replace("-", "_")
+ .replace("@", "_")
+ .replace(".", "_")
+ )
+
+ args.model_path, args.worker_host, args.worker_port = item.split("@")
+ print("*" * 80)
+ worker_str_args = string_args(args, worker_args)
+ print(worker_str_args)
+ worker_sh = base_launch_sh.format(
+ "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
+ )
+ worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
+ subprocess.run(worker_sh, shell=True, check=True)
+ subprocess.run(worker_check_sh, shell=True, check=True)
+
+
+def launch_all():
+ controller_str_args = string_args(args, controller_args)
+ controller_sh = base_launch_sh.format(
+ "controller", controller_str_args, LOGDIR, "controller"
+ )
+ controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
+ subprocess.run(controller_sh, shell=True, check=True)
+ subprocess.run(controller_check_sh, shell=True, check=True)
+
+ if isinstance(args.model_path_address, str):
+ launch_worker(args.model_path_address)
+ else:
+ for idx, item in enumerate(args.model_path_address):
+ print(f"loading {idx}th model:{item}")
+ launch_worker(item)
+
+ server_str_args = string_args(args, server_args)
+ server_sh = base_launch_sh.format(
+ "openai_api_server", server_str_args, LOGDIR, "openai_api_server"
+ )
+ server_check_sh = base_check_sh.format(
+ LOGDIR, "openai_api_server", "openai_api_server"
+ )
+ subprocess.run(server_sh, shell=True, check=True)
+ subprocess.run(server_check_sh, shell=True, check=True)
+
+
+if __name__ == "__main__":
+ launch_all()
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..056c2f8954195d57a4d940c185b1504ae2862664
--- /dev/null
+++ b/fastchat/serve/model_worker.py
@@ -0,0 +1,363 @@
+"""
+A model worker that executes the model.
+"""
+import argparse
+import base64
+import gc
+import json
+import os
+from typing import List, Optional
+import uuid
+
+import torch
+import torch.nn.functional as F
+from transformers import set_seed
+import uvicorn
+
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
+from fastchat.model.model_adapter import (
+ load_model,
+ add_model_args,
+ get_generate_stream_function,
+)
+from fastchat.modules.awq import AWQConfig
+from fastchat.modules.exllama import ExllamaConfig
+from fastchat.modules.xfastertransformer import XftConfig
+from fastchat.modules.gptq import GptqConfig
+from fastchat.serve.base_model_worker import BaseModelWorker, app
+from fastchat.utils import (
+ build_logger,
+ get_context_length,
+ str_to_torch_dtype,
+)
+
+
+worker_id = str(uuid.uuid4())[:8]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+
+
+class ModelWorker(BaseModelWorker):
+ def __init__(
+ self,
+ controller_addr: str,
+ worker_addr: str,
+ worker_id: str,
+ model_path: str,
+ model_names: List[str],
+ limit_worker_concurrency: int,
+ no_register: bool,
+ device: str,
+ num_gpus: int,
+ max_gpu_memory: str,
+ dtype: Optional[torch.dtype] = None,
+ load_8bit: bool = False,
+ cpu_offloading: bool = False,
+ gptq_config: Optional[GptqConfig] = None,
+ awq_config: Optional[AWQConfig] = None,
+ exllama_config: Optional[ExllamaConfig] = None,
+ xft_config: Optional[XftConfig] = None,
+ stream_interval: int = 2,
+ conv_template: Optional[str] = None,
+ embed_in_truncate: bool = False,
+ seed: Optional[int] = None,
+ debug: bool = False,
+ **kwargs,
+ ):
+ super().__init__(
+ controller_addr,
+ worker_addr,
+ worker_id,
+ model_path,
+ model_names,
+ limit_worker_concurrency,
+ conv_template=conv_template,
+ )
+
+ logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
+ self.model, self.tokenizer = load_model(
+ model_path,
+ device=device,
+ num_gpus=num_gpus,
+ max_gpu_memory=max_gpu_memory,
+ dtype=dtype,
+ load_8bit=load_8bit,
+ cpu_offloading=cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ exllama_config=exllama_config,
+ xft_config=xft_config,
+ debug=debug,
+ model_name=model_names[0],
+ )
+ self.device = device
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.context_len = get_context_length(self.model.config)
+ self.generate_stream_func = get_generate_stream_function(self.model, model_path)
+ self.stream_interval = stream_interval
+ self.embed_in_truncate = embed_in_truncate
+ self.seed = seed
+
+ if not no_register:
+ self.init_heart_beat()
+
+ def generate_stream_gate(self, params):
+ self.call_ct += 1
+
+ try:
+ if self.seed is not None:
+ set_seed(self.seed)
+ for output in self.generate_stream_func(
+ self.model,
+ self.tokenizer,
+ params,
+ self.device,
+ self.context_len,
+ self.stream_interval,
+ ):
+ ret = {
+ "text": output["text"],
+ "error_code": 0,
+ }
+ if "usage" in output:
+ ret["usage"] = output["usage"]
+ if "finish_reason" in output:
+ ret["finish_reason"] = output["finish_reason"]
+ if "logprobs" in output:
+ ret["logprobs"] = output["logprobs"]
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.OutOfMemoryError as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except (ValueError, RuntimeError) as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ def generate_gate(self, params):
+ for x in self.generate_stream_gate(params):
+ pass
+ return json.loads(x[:-1].decode())
+
+ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
+ if model_type_dict.get("is_bert"):
+ model_output = self.model(input_ids)
+ if model_type_dict.get("is_robert"):
+ data = model_output.last_hidden_state
+ else:
+ data = model_output[0]
+ elif model_type_dict.get("is_t5"):
+ model_output = self.model(input_ids, decoder_input_ids=input_ids)
+ data = model_output.encoder_last_hidden_state
+ else:
+ model_output = self.model(input_ids, output_hidden_states=True)
+ if model_type_dict.get("is_chatglm"):
+ data = model_output.hidden_states[-1].transpose(0, 1)
+ else:
+ data = model_output.hidden_states[-1]
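+ # Mean-pool: zero out padding positions with the attention mask, then sum over the
+ # sequence dimension; the caller divides sum_embeddings by token_num to average.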
+ mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
+ masked_embeddings = data * mask
+ sum_embeddings = torch.sum(masked_embeddings, dim=1)
+ token_num = torch.sum(attention_mask).item()
+
+ return sum_embeddings, token_num
+
+ def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
+ embeddings = embeddings.cpu()
+ return [
+ base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
+ ]
+
+ @torch.inference_mode()
+ def get_embeddings(self, params):
+ self.call_ct += 1
+
+ try:
+ tokenizer = self.tokenizer
+ ret = {"embedding": [], "token_num": 0}
+
+ model_type_dict = {
+ "is_llama": "llama" in str(type(self.model)),
+ "is_t5": "t5" in str(type(self.model)),
+ "is_chatglm": "chatglm" in str(type(self.model)),
+ "is_bert": "bert" in str(type(self.model)),
+ "is_robert": "robert" in str(type(self.model)),
+ }
+
+ if self.embed_in_truncate:
+ encoding = tokenizer.batch_encode_plus(
+ params["input"],
+ padding=True,
+ truncation="longest_first",
+ return_tensors="pt",
+ max_length=self.context_len,
+ )
+ else:
+ encoding = tokenizer.batch_encode_plus(
+ params["input"], padding=True, return_tensors="pt"
+ )
+ input_ids = encoding["input_ids"].to(self.device)
+ attention_mask = input_ids != tokenizer.pad_token_id
+
+ base64_encode = params.get("encoding_format", None)
+
+ if self.embed_in_truncate:
+ chunk_embeddings, token_num = self.__process_embed_chunk(
+ input_ids, attention_mask, **model_type_dict
+ )
+ embedding = chunk_embeddings / token_num
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+ ret["token_num"] = token_num
+ else:
+ all_embeddings = []
+ all_token_num = 0
+ for i in range(0, input_ids.size(1), self.context_len):
+ chunk_input_ids = input_ids[:, i : i + self.context_len]
+ chunk_attention_mask = attention_mask[:, i : i + self.context_len]
+
+ chunk_embeddings, token_num = self.__process_embed_chunk(
+ chunk_input_ids, chunk_attention_mask, **model_type_dict
+ )
+ all_embeddings.append(chunk_embeddings)
+ all_token_num += token_num
+
+ all_embeddings_tensor = torch.stack(all_embeddings)
+ embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+
+ ret["token_num"] = all_token_num
+
+ if base64_encode == "base64":
+ out_embeddings = self.__encode_base64(normalized_embeddings)
+ else:
+ out_embeddings = normalized_embeddings.tolist()
+ ret["embedding"] = out_embeddings
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ if self.device == "xpu":
+ torch.xpu.empty_cache()
+ if self.device == "npu":
+ torch.npu.empty_cache()
+ except torch.cuda.OutOfMemoryError as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+ }
+ except (ValueError, RuntimeError) as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ return ret
+
+
+def create_model_worker():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ add_model_args(parser)
+ parser.add_argument(
+ "--model-names",
+ type=lambda s: s.split(","),
+ help="Optional display comma separated names",
+ )
+ parser.add_argument(
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
+ )
+ parser.add_argument("--embed-in-truncate", action="store_true")
+ parser.add_argument(
+ "--limit-worker-concurrency",
+ type=int,
+ default=5,
+ help="Limit the model concurrency to prevent OOM.",
+ )
+ parser.add_argument("--stream-interval", type=int, default=2)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Overwrite the random seed for each generation.",
+ )
+ parser.add_argument(
+ "--debug", type=bool, default=False, help="Print debugging messages"
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+ gptq_config = GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ )
+ awq_config = AWQConfig(
+ ckpt=args.awq_ckpt or args.model_path,
+ wbits=args.awq_wbits,
+ groupsize=args.awq_groupsize,
+ )
+ if args.enable_exllama:
+ exllama_config = ExllamaConfig(
+ max_seq_len=args.exllama_max_seq_len,
+ gpu_split=args.exllama_gpu_split,
+ )
+ else:
+ exllama_config = None
+ if args.enable_xft:
+ xft_config = XftConfig(
+ max_seq_len=args.xft_max_seq_len,
+ data_type=args.xft_dtype,
+ )
+ if args.device != "cpu":
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
+ args.device = "cpu"
+ else:
+ xft_config = None
+
+ worker = ModelWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.model_path,
+ args.model_names,
+ args.limit_worker_concurrency,
+ no_register=args.no_register,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ dtype=str_to_torch_dtype(args.dtype),
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ exllama_config=exllama_config,
+ xft_config=xft_config,
+ stream_interval=args.stream_interval,
+ conv_template=args.conv_template,
+ embed_in_truncate=args.embed_in_truncate,
+ seed=args.seed,
+ debug=args.debug,
+ )
+ return args, worker
+
+
+if __name__ == "__main__":
+ args, worker = create_model_worker()
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/monitor/basic_stats.py b/fastchat/serve/monitor/basic_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1934bb07863ddedb4b5dedba0b3e4724c78a765
--- /dev/null
+++ b/fastchat/serve/monitor/basic_stats.py
@@ -0,0 +1,210 @@
+import argparse
+import code
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+import pandas as pd # pandas>=2.0.3
+import plotly.express as px
+import plotly.graph_objects as go
+from tqdm import tqdm
+
+
+NUM_SERVERS = 14
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in range(4, 12):
+ for day in range(1, 33):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ filenames = []
+ for d in dates:
+ for i in range(NUM_SERVERS):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def load_log_files(log_files):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+
+ data.append(
+ dict(
+ type=row["type"],
+ tstamp=row["tstamp"],
+ model=row.get("model", ""),
+ models=row.get("models", ["", ""]),
+ )
+ )
+
+ return data
+
+
+def get_anony_vote_df(df):
+ anony_vote_df = df[
+ df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
+ ]
+ anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")]
+ return anony_vote_df
+
+
+def merge_counts(series, on, names):
+ ret = pd.merge(series[0], series[1], on=on)
+ for i in range(2, len(series)):
+ ret = pd.merge(ret, series[i], on=on)
+ ret = ret.reset_index()
+ old_names = list(ret.columns)[-len(series) :]
+ rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
+ ret = ret.rename(columns=rename)
+ return ret
+
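+# Illustrative: merge_counts([counts_all, counts_day, counts_hour], on="model",
+# names=["All", "Last Day", "Last Hour"]) joins the value_counts Series on the model
+# name and renames the count columns, yielding one row per model with three counts.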
+
+def report_basic_stats(log_files):
+ df_all = load_log_files(log_files)
+ df_all = pd.DataFrame(df_all)
+ now_t = df_all["tstamp"].max()
+ df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
+ df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
+ anony_vote_df_all = get_anony_vote_df(df_all)
+
+ # Chat trends
+ chat_dates = [
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
+ "%Y-%m-%d"
+ )
+ for x in df_all[df_all["type"] == "chat"]["tstamp"]
+ ]
+ chat_dates_counts = pd.value_counts(chat_dates)
+ vote_dates = [
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
+ "%Y-%m-%d"
+ )
+ for x in anony_vote_df_all["tstamp"]
+ ]
+ vote_dates_counts = pd.value_counts(vote_dates)
+ chat_dates_bar = go.Figure(
+ data=[
+ go.Bar(
+ name="Anony. Vote",
+ x=vote_dates_counts.index,
+ y=vote_dates_counts,
+ text=[f"{val:.0f}" for val in vote_dates_counts],
+ textposition="auto",
+ ),
+ go.Bar(
+ name="Chat",
+ x=chat_dates_counts.index,
+ y=chat_dates_counts,
+ text=[f"{val:.0f}" for val in chat_dates_counts],
+ textposition="auto",
+ ),
+ ]
+ )
+ chat_dates_bar.update_layout(
+ barmode="stack",
+ xaxis_title="Dates",
+ yaxis_title="Count",
+ height=300,
+ width=1200,
+ )
+
+ # Model call counts
+ model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
+ model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
+ model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
+ model_hist = merge_counts(
+ [model_hist_all, model_hist_1_day, model_hist_1_hour],
+ on="model",
+ names=["All", "Last Day", "Last Hour"],
+ )
+ model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
+
+ # Action counts
+ action_hist_all = df_all["type"].value_counts()
+ action_hist_1_day = df_1_day["type"].value_counts()
+ action_hist_1_hour = df_1_hour["type"].value_counts()
+ action_hist = merge_counts(
+ [action_hist_all, action_hist_1_day, action_hist_1_hour],
+ on="type",
+ names=["All", "Last Day", "Last Hour"],
+ )
+ action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
+
+ # Anony vote counts
+ anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
+ anony_vote_df_1_day = get_anony_vote_df(df_1_day)
+ anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
+ # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
+ # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
+ anony_vote_hist = merge_counts(
+ [anony_vote_hist_all, anony_vote_hist_1_day],
+ on="type",
+ names=["All", "Last Day"],
+ )
+ anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
+
+ # Last 24 hours
+ chat_1_day = df_1_day[df_1_day["type"] == "chat"]
+ num_chats_last_24_hours = []
+ base = df_1_day["tstamp"].min()
+ for i in range(24, 0, -1):
+ left = base + (i - 1) * 3600
+ right = base + i * 3600
+ num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
+ num_chats_last_24_hours.append(num)
+ times = [
+ datetime.datetime.fromtimestamp(
+ base + i * 3600, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+ for i in range(24, 0, -1)
+ ]
+ last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
+ last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
+
+ # Last update datetime
+ last_updated_tstamp = now_t
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ # code.interact(local=locals())
+
+ return {
+ "chat_dates_bar": chat_dates_bar,
+ "model_hist_md": model_hist_md,
+ "action_hist_md": action_hist_md,
+ "anony_vote_hist_md": anony_vote_hist_md,
+ "num_chats_last_24_hours": last_24_hours_md,
+ "last_updated_datetime": last_updated_datetime,
+ }
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ basic_stats = report_basic_stats(log_files)
+
+ print(basic_stats["action_hist_md"] + "\n")
+ print(basic_stats["model_hist_md"] + "\n")
+ print(basic_stats["anony_vote_hist_md"] + "\n")
+ print(basic_stats["num_chats_last_24_hours"] + "\n")
diff --git a/fastchat/serve/monitor/clean_battle_data.py b/fastchat/serve/monitor/clean_battle_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..23357d08cd2b24ca2bbecdd5c1434d1b7203a2f9
--- /dev/null
+++ b/fastchat/serve/monitor/clean_battle_data.py
@@ -0,0 +1,269 @@
+"""
+Clean chatbot arena battle log.
+
+Usage:
+python3 clean_battle_data.py --mode conv_release
+"""
+import argparse
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+from tqdm import tqdm
+
+from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS
+from fastchat.utils import detect_language
+
+
+VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
+IDENTITY_WORDS = [
+ "vicuna",
+ "lmsys",
+ "koala",
+ "uc berkeley",
+ "open assistant",
+ "laion",
+ "chatglm",
+ "chatgpt",
+ "openai",
+ "anthropic",
+ "claude",
+ "bard",
+ "palm",
+ "lamda",
+ "google",
+ "llama",
+ "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
+ "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
+]
+
+for i in range(len(IDENTITY_WORDS)):
+ IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in range(4, 12):
+ for day in range(1, 33):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ filenames = []
+ for d in dates:
+ for i in range(NUM_SERVERS):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def remove_html(raw):
+ if raw.startswith(""):
+ return raw[raw.find(": ") + 2 : -len("
\n")]
+ return raw
+
+
+def to_openai_format(messages):
+ roles = ["user", "assistant"]
+ ret = []
+ for i, x in enumerate(messages):
+ ret.append({"role": roles[i % 2], "content": x[1]})
+ return ret
+
+
+def replace_model_name(old_name):
+ return (
+ old_name.replace("bard", "palm-2")
+ .replace("claude-v1", "claude-1")
+ .replace("claude-instant-v1", "claude-instant-1")
+ .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b")
+ )
+
+
+def clean_battle_data(log_files, exclude_model_names):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+ if row["type"] in VOTES:
+ data.append(row)
+
+ convert_type = {
+ "leftvote": "model_a",
+ "rightvote": "model_b",
+ "tievote": "tie",
+ "bothbad_vote": "tie (bothbad)",
+ }
+
+ all_models = set()
+ all_ips = dict()
+ ct_anony = 0
+ ct_invalid = 0
+ ct_leaked_identity = 0
+ battles = []
+ for row in data:
+ if row["models"][0] is None or row["models"][1] is None:
+ continue
+
+ # Resolve model names
+ models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
+ if "model_name" in row["states"][0]:
+ models_hidden = [
+ row["states"][0]["model_name"],
+ row["states"][1]["model_name"],
+ ]
+ if models_hidden[0] is None:
+ models_hidden = models_public
+ else:
+ models_hidden = models_public
+
+ if (models_public[0] == "" and models_public[1] != "") or (
+ models_public[1] == "" and models_public[0] != ""
+ ):
+ ct_invalid += 1
+ continue
+
+ if models_public[0] == "" or models_public[0] == "Model A":
+ anony = True
+ models = models_hidden
+ ct_anony += 1
+ else:
+ anony = False
+ models = models_public
+ if not models_public == models_hidden:
+ ct_invalid += 1
+ continue
+
+ # Detect language
+ state = row["states"][0]
+ if state["offset"] >= len(state["messages"]):
+ ct_invalid += 1
+ continue
+ lang_code = detect_language(state["messages"][state["offset"]][1])
+
+ # Drop conversations if the model names are leaked
+ leaked_identity = False
+ messages = ""
+ for i in range(2):
+ state = row["states"][i]
+ for role, msg in state["messages"][state["offset"] :]:
+ if msg:
+ messages += msg.lower()
+ for word in IDENTITY_WORDS:
+ if word in messages:
+ leaked_identity = True
+ break
+
+ if leaked_identity:
+ ct_leaked_identity += 1
+ continue
+
+ # Replace bard with palm
+ models = [replace_model_name(m) for m in models]
+
+ # Exclude certain models
+ if any(x in exclude_model_names for x in models):
+ ct_invalid += 1
+ continue
+
+ question_id = row["states"][0]["conv_id"]
+ conversation_a = to_openai_format(
+ row["states"][0]["messages"][row["states"][0]["offset"] :]
+ )
+ conversation_b = to_openai_format(
+ row["states"][1]["messages"][row["states"][1]["offset"] :]
+ )
+
+ ip = row["ip"]
+ if ip not in all_ips:
+ all_ips[ip] = len(all_ips)
+ user_id = all_ips[ip]
+
+ # Save the results
+ battles.append(
+ dict(
+ question_id=question_id,
+ model_a=models[0],
+ model_b=models[1],
+ winner=convert_type[row["type"]],
+ judge=f"arena_user_{user_id}",
+ conversation_a=conversation_a,
+ conversation_b=conversation_b,
+ turn=len(conversation_a) // 2,
+ anony=anony,
+ language=lang_code,
+ tstamp=row["tstamp"],
+ )
+ )
+
+ all_models.update(models_hidden)
+ battles.sort(key=lambda x: x["tstamp"])
+ last_updated_tstamp = battles[-1]["tstamp"]
+
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ print(
+ f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
+ f"#leaked_identity: {ct_leaked_identity}"
+ )
+ print(f"#battles: {len(battles)}, #anony: {ct_anony}")
+ print(f"#models: {len(all_models)}, {all_models}")
+ print(f"last-updated: {last_updated_datetime}")
+
+ return battles
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ parser.add_argument(
+ "--mode", type=str, choices=["simple", "conv_release"], default="simple"
+ )
+ parser.add_argument("--exclude-model-names", type=str, nargs="+")
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ battles = clean_battle_data(log_files, args.exclude_model_names or [])
+ last_updated_tstamp = battles[-1]["tstamp"]
+ cutoff_date = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y%m%d")
+
+ if args.mode == "simple":
+ for x in battles:
+ for key in [
+ "conversation_a",
+ "conversation_b",
+ "question_id",
+ ]:
+ del x[key]
+ print("Samples:")
+ for i in range(4):
+ print(battles[i])
+ output = f"clean_battle_{cutoff_date}.json"
+ elif args.mode == "conv_release":
+ new_battles = []
+ for x in battles:
+ if not x["anony"]:
+ continue
+ for key in []:
+ del x[key]
+ new_battles.append(x)
+ battles = new_battles
+ output = f"clean_battle_conv_{cutoff_date}.json"
+
+ with open(output, "w") as fout:
+ json.dump(battles, fout, indent=2, ensure_ascii=False)
+ print(f"Write cleaned data to {output}")
diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0c9bd4fa4cce17af03597dab12a9cdcc9453c5
--- /dev/null
+++ b/fastchat/serve/monitor/clean_chat_data.py
@@ -0,0 +1,171 @@
+"""
+Clean chatbot arena chat log.
+
+Usage:
+python3 clean_chat_data.py --mode conv_release
+"""
+import argparse
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+from tqdm import tqdm
+
+from fastchat.serve.monitor.basic_stats import NUM_SERVERS
+from fastchat.serve.monitor.clean_battle_data import (
+ to_openai_format,
+ replace_model_name,
+)
+from fastchat.utils import detect_language
+
+
+NETWORK_ERROR_MSG = (
+ "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
+)
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in range(4, 12):
+ for day in range(1, 33):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ filenames = []
+ for d in dates:
+ for i in range(NUM_SERVERS):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ # filenames = list(reversed(filenames))
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def clean_chat_data(log_files, action_type):
+ raw_data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+ if row["type"] == action_type:
+ raw_data.append(row)
+
+ all_models = set()
+ all_ips = dict()
+ chats = []
+ ct_invalid_conv_id = 0
+ ct_invalid = 0
+ ct_network_error = 0
+ for row in raw_data:
+ try:
+ if action_type in ["chat", "upvote", "downvote"]:
+ state = row["state"]
+ model = row["model"]
+ elif action_type == "leftvote":
+ state = row["states"][0]
+ model = row["states"][0]["model_name"]
+ elif action_type == "rightvote":
+ state = row["states"][1]
+ model = row["states"][1]["model_name"]
+ conversation_id = state["conv_id"]
+ except KeyError:
+ ct_invalid_conv_id += 1
+ continue
+
+ if conversation_id is None:
+ ct_invalid_conv_id += 1
+ continue
+
+ conversation = to_openai_format(state["messages"][state["offset"] :])
+ if not isinstance(model, str):
+ ct_invalid += 1
+ continue
+ model = replace_model_name(model)
+
+ try:
+ lang_code = detect_language(state["messages"][state["offset"]][1])
+ except IndexError:
+ ct_invalid += 1
+ continue
+
+ if not all(isinstance(x["content"], str) for x in conversation):
+ ct_invalid += 1
+ continue
+
+ messages = "".join([x["content"] for x in conversation]).lower()
+ if NETWORK_ERROR_MSG in messages:
+ ct_network_error += 1
+ continue
+
+ ip = row["ip"]
+ if ip not in all_ips:
+ all_ips[ip] = len(all_ips)
+ user_id = all_ips[ip]
+
+ chats.append(
+ dict(
+ conversation_id=conversation_id,
+ model=model,
+ conversation=conversation,
+ turn=len(conversation) // 2,
+ language=lang_code,
+ user_id=user_id,
+ tstamp=row["tstamp"],
+ )
+ )
+
+ all_models.update([model])
+
+ chats.sort(key=lambda x: x["tstamp"])
+ last_updated_tstamp = chats[-1]["tstamp"]
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ # Deduplication
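+ # Iterate from newest to oldest so only the latest record per conversation_id is
+ # kept; the list is reversed back to chronological order on return.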
+ dedup_chats = []
+ visited_conv_ids = set()
+ for i in reversed(range(len(chats))):
+ if chats[i]["conversation_id"] in visited_conv_ids:
+ continue
+ visited_conv_ids.add(chats[i]["conversation_id"])
+ dedup_chats.append(chats[i])
+
+ print(
+ f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
+ )
+ print(
+ f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
+ )
+ print(f"#models: {len(all_models)}, {all_models}")
+ print(f"last-updated: {last_updated_datetime}")
+
+ return list(reversed(dedup_chats))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--action-type", type=str, default="chat")
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ chats = clean_chat_data(log_files, args.action_type)
+ last_updated_tstamp = chats[-1]["tstamp"]
+ cutoff_date = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y%m%d")
+
+ output = f"clean_{args.action_type}_conv_{cutoff_date}.json"
+ with open(output, "w") as fout:
+ json.dump(chats, fout, indent=2, ensure_ascii=False)
+ print(f"Write cleaned data to {output}")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e94cf2756203f207e82cc7f31ff544ecdcc80f0
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py
@@ -0,0 +1,25 @@
+"""Count the unique users in a battle log file."""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str)
+ args = parser.parse_args()
+
+ lines = json.load(open(args.input))
+ ct_anony_votes = 0
+ all_users = set()
+ all_models = set()
+ for l in lines:
+ if not l["anony"]:
+ continue
+ all_users.add(l["judge"])
+ all_models.add(l["model_a"])
+ all_models.add(l["model_b"])
+ ct_anony_votes += 1
+
+ print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}")
+ print(f"#model: {len(all_models)}")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12d7c652bc02bb7b5c9f65bce0e1644f739c1b
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py
@@ -0,0 +1,155 @@
+"""
+Filter conversations for release.
+
+Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json
+"""
+import argparse
+from collections import defaultdict
+from enum import Enum, auto
+import json
+import os
+import random
+
+from tqdm import tqdm
+
+BLOCKED_WORDS_FILENAME = "blocked_words.json"
+blocked_words = []
+frequency = defaultdict(lambda: 0)
+
+
+class TypeCode(Enum):
+ CORRECT = auto()
+ ANONYMIZED = auto()
+ REDACTED = auto()
+ BAD_FORMAT = auto()
+ BLOCKED_WORD = auto()
+ BLOCKED_MODEL = auto()
+ TOO_SHORT = auto()
+ TOO_FREQUENT = auto()
+
+
+def detect_type(conv):
+ for key in ["conversation_a", "conversation_b"]:
+ messages = [row["content"] for row in conv[key]]
+ for msg in messages:
+ if not isinstance(msg, str):
+ return TypeCode.BAD_FORMAT
+
+ user_prompts = [
+ row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
+ ]
+ if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts):
+ return TypeCode.TOO_SHORT
+
+ if all(x in frequent_prompts for x in user_prompts):
+ return TypeCode.TOO_FREQUENT
+
+ for msg in messages:
+ msg = msg.lower()
+ if "" in msg:
+ return TypeCode.ANONYMIZED
+ if "" in msg:
+ return TypeCode.REDACTED
+
+ for w in blocked_words:
+ if w in msg:
+ return TypeCode.BLOCKED_WORD
+
+ for key in ["model_a", "model_b"]:
+ if conv[key] in ["vicuna-33b", "mpt-30b-chat"]:
+ return TypeCode.BLOCKED_MODEL
+
+ return TypeCode.CORRECT
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--sample", type=int)
+ args = parser.parse_args()
+
+ # Read conversations
+ convs = json.load(open(args.in_file))
+ print(f"#conv: {len(convs)}")
+
+ # Read blocked words
+ if os.path.exists(BLOCKED_WORDS_FILENAME):
+ blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
+
+ # Count frequency
+ for conv in convs:
+ for key in ["conversation_a", "conversation_b"]:
+ messages = [row["content"] for row in conv[key] if row["role"] == "user"]
+ for msg in messages:
+ if not isinstance(msg, str):
+ continue
+ msg = msg.lower().strip()
+ frequency[msg] += 1
+
+ keys = list(frequency.keys())
+ keys.sort(key=lambda x: -frequency[x])
+ frequent_prompts = keys[:10]
+ frequent_prompts = set(frequent_prompts)
+ frequent_prompts.add("")
+
+ # Start filter
+ ct_bad_format = 0
+ ct_anonymized = 0
+ ct_redacted = 0
+ ct_error = 0
+ ct_lang_filter = 0
+ ct_flagged = 0
+ ct_blocked_word = 0
+ ct_blocked_model = 0
+ ct_too_short = 0
+ ct_too_frequent = 0
+
+ new_convs = []
+ for conv in tqdm(convs):
+ type_code = detect_type(conv)
+
+ if type_code == TypeCode.BAD_FORMAT:
+ ct_bad_format += 1
+ continue
+
+ if type_code == TypeCode.ANONYMIZED:
+ ct_anonymized += 1
+ continue
+ elif type_code == TypeCode.REDACTED:
+ ct_redacted += 1
+ continue
+ elif type_code == TypeCode.BLOCKED_WORD:
+ ct_blocked_word += 1
+ continue
+ elif type_code == TypeCode.BLOCKED_MODEL:
+ ct_blocked_model += 1
+ continue
+ elif type_code == TypeCode.TOO_SHORT:
+ ct_too_short += 1
+ continue
+ elif type_code == TypeCode.TOO_FREQUENT:
+ ct_too_frequent += 1
+ continue
+
+ if conv["openai_moderation"]["flagged"]:
+ ct_flagged += 1
+ continue
+
+ if type_code in [TypeCode.CORRECT]:
+ new_convs.append(conv)
+
+ if args.sample:
+ # random.seed(0)
+ # random.shuffle(new_convs)
+ new_convs = new_convs[: args.sample]
+
+ print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
+ print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
+ print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
+ print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}")
+ print(f"new_conv: {len(new_convs)}")
+
+ out_file = args.in_file.replace(".json", ".out.json")
+ print(f"Output to {out_file}")
+ with open(out_file, "w") as fout:
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
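The TOO_FREQUENT rule above is driven by a plain frequency count: the ten most common user prompts (plus the empty string) form `frequent_prompts`, and a conversation is dropped only when every one of its user turns falls in that set. A toy sketch of the same logic, using the single most common prompt instead of the top ten because the toy corpus is tiny:

```python
from collections import Counter

# Each inner list holds the user prompts of one conversation.
user_prompts_per_conv = [["hi"], ["hi"], ["hi"], ["write a poem about the sea"]]

freq = Counter(p for prompts in user_prompts_per_conv for p in prompts)
frequent_prompts = {p for p, _ in freq.most_common(1)} | {""}   # the script keeps the top 10

for prompts in user_prompts_per_conv:
    flagged = all(p in frequent_prompts for p in prompts)
    print(prompts, "-> TOO_FREQUENT" if flagged else "-> kept")
```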
diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a88209bfcb58cb2131ce94d6eba03c899e74a0a
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py
@@ -0,0 +1,25 @@
+"""Count the unique users in a battle log file."""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str)
+ parser.add_argument("--tag-file", type=str)
+ args = parser.parse_args()
+
+ # build index
+ objs = json.load(open(args.tag_file))
+ new_field_dict = {}
+ for obj in objs:
+ new_field_dict[obj["question_id"]] = obj["toxic_chat"]
+
+ objs = json.load(open(args.input))
+ for obj in objs:
+ obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]]
+
+ output = args.input.replace(".json", "_added.json")
+ with open(output, "w") as fout:
+ json.dump(objs, fout, indent=2, ensure_ascii=False)
diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd78b71e95a3034bf3440aee3557a38426d0244
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py
@@ -0,0 +1,32 @@
+"""
+Count the unique users in a battle log file.
+
+Usage:
+python3 -input in.json --number 1000
+"""
+
+import argparse
+import json
+import random
+
+K = 1000
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str)
+ parser.add_argument("--number", type=int, nargs="+")
+ args = parser.parse_args()
+
+ convs = json.load(open(args.input))
+ random.seed(0)
+ random.shuffle(convs)
+
+ for number in args.number:
+ new_convs = convs[:number]
+
+ output = args.input.replace(".json", f"_{number//K}k.json")
+ with open(output, "w") as fout:
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
+
+ print(f"#in: {len(convs)}, #out: {len(new_convs)}")
+ print(f"Write to file: {output}")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e37aadcea65df7ca605369b88c068aa57c8f35f2
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py
@@ -0,0 +1,9 @@
+"""
+Upload to huggingface.
+"""
+import json
+from datasets import Dataset, DatasetDict, load_dataset
+
+objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json"))
+data = Dataset.from_list(objs)
+data.push_to_hub("lmsys/chatbot_arena_conversations", private=True)
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7084207309907dcb8fa37eccf55fd2a6b62ca48
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py
@@ -0,0 +1,13 @@
+import requests
+
+headers = {"authorization": "Bearer hf_XXX"}
+
+url = "https://huggingface.co./api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending"
+a = requests.get(url, headers=headers)
+
+for u in a.json():
+ user = u["user"]["user"]
+ url = "https://huggingface.co./api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant"
+ ret = requests.post(url, headers=headers, json={"user": user})
+ print(user, ret.status_code)
+ assert ret.status_code == 200
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..97abaaa0df053c93c3adb655f1b5c41af0aab00d
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py
@@ -0,0 +1,119 @@
+"""
+From colab:
+https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing
+"""
+import argparse
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+import kaleido
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from tqdm import tqdm
+
+import plotly.io as pio
+
+pio.kaleido.scope.mathjax = None
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--in-file", type=str, required=True)
+parser.add_argument("--scale", type=int, required=True)
+args = parser.parse_args()
+
+filename = args.in_file
+scale = args.scale
+convs = json.load(open(filename))
+df = pd.DataFrame(convs)
+df
+
+print(f"#ips: {df['user_id'].nunique() * scale}")
+print(f"#models: {df['model'].nunique()}")
+print(f"#language: {df['language'].nunique()}")
+print(f"#turns: {df['turn'].mean()}")
+
+model_counts = df["model"].value_counts() * scale
+# print("model counts", model_counts)
+fig = px.bar(x=model_counts.index, y=model_counts)
+fig.update_layout(
+ xaxis_title=None,
+ yaxis_title="Count",
+ height=200,
+ width=950,
+ margin=dict(l=0, r=0, t=0, b=0),
+)
+fig.show()
+fig.write_image("model_count.pdf")
+
+
+model_counts = df["language"].value_counts().head(25) * scale
+fig = px.bar(x=model_counts.index, y=model_counts)
+fig.update_layout(
+ xaxis_title=None,
+ yaxis_title="Count",
+ height=200,
+ width=950,
+ margin=dict(l=0, r=0, t=0, b=0),
+)
+fig.show()
+fig.write_image("language_count.pdf")
+
+chat_dates = [
+ datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d")
+ for x in df["tstamp"]
+]
+
+
+def to_remove(x):
+ for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]:
+ if d in x:
+ return True
+ return False
+
+
+chat_dates = [x for x in chat_dates if not to_remove(x)]
+
+chat_dates_counts = pd.value_counts(chat_dates) * scale
+print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}")
+
+fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts)
+fig.update_layout(
+ xaxis_title="Dates",
+ yaxis_title="Count",
+ height=200,
+ width=950,
+ margin=dict(l=0, r=0, t=0, b=0),
+)
+fig.show()
+fig.write_image("daily_conversation_count.pdf")
+
+import transformers
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(
+ "lmsys/vicuna-7b-v1.5", use_fast=False
+)
+
+prompts = []
+responses = []
+for conv in df["conversation"]:
+ for row in conv:
+ if row["role"] == "user":
+ prompts.append(row["content"])
+ else:
+ responses.append(row["content"])
+
+print(f"#prompts: {len(prompts)}")
+print(f"#responses: {len(responses)}")
+
+
+prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)]
+print()
+print(f"mean prompt len: {np.mean(prompt_lens):.2f}")
+
+response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)]
+print()
+print(f"mean response len: {np.mean(response_lens):.2f}")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ccde1ca57546acf5d1131cae14a499f1228a02c
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py
@@ -0,0 +1,148 @@
+"""
+Filter conversations for release.
+
+Dependency:
+pip install opencc-python-reimplemented
+
+Usage:
+python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json
+"""
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+from collections import defaultdict
+from enum import Enum, auto
+import json
+import os
+import random
+
+from tqdm import tqdm
+import opencc
+
+BLOCKED_WORDS_FILENAME = "blocked_words.json"
+blocked_words = []
+frequency = defaultdict(lambda: 0)
+
+cc_converter = opencc.OpenCC("t2s")
+
+
+class TypeCode(Enum):
+ CORRECT = auto()
+ ANONYMIZED = auto()
+ REDACTED = auto()
+ BAD_FORMAT = auto()
+ BLOCKED_WORD = auto()
+ BLOCKED_MODEL = auto()
+ TOO_SHORT = auto()
+ TOO_FREQUENT = auto()
+
+
+def detect_type(conv):
+ for key in ["conversation_a", "conversation_b", "conversation"]:
+ if key not in conv:
+ continue
+
+ messages = [row["content"] for row in conv[key]]
+ for msg in messages:
+ if not isinstance(msg, str):
+ return TypeCode.BAD_FORMAT
+
+ if len(messages) == 0:
+ return TypeCode.BAD_FORMAT
+
+ user_prompts = [
+ row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
+ ]
+
+ for msg in messages:
+ msg = cc_converter.convert(msg.lower())
+ if "" in msg:
+ return TypeCode.ANONYMIZED
+ if "" in msg:
+ return TypeCode.REDACTED
+
+ for w in blocked_words:
+ if w in msg:
+ return TypeCode.BLOCKED_WORD
+
+ return TypeCode.CORRECT
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--sample", type=int)
+ args = parser.parse_args()
+
+ # Read conversations
+ convs = json.load(open(args.in_file))
+ print(f"#conv: {len(convs)}")
+
+ # Read blocked words
+ if os.path.exists(BLOCKED_WORDS_FILENAME):
+ blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
+ blocked_words = [cc_converter.convert(w) for w in blocked_words]
+
+ # Start filter
+ ct_bad_format = 0
+ ct_anonymized = 0
+ ct_redacted = 0
+ ct_error = 0
+ ct_lang_filter = 0
+ ct_flagged = 0
+ ct_blocked_word = 0
+ ct_blocked_model = 0
+ ct_too_short = 0
+ ct_too_frequent = 0
+
+ type_codes = []
+ with ProcessPoolExecutor() as executor:
+ for result in tqdm(executor.map(detect_type, convs), total=len(convs)):
+ type_codes.append(result)
+
+ new_convs = []
+ for conv, type_code in zip(convs, type_codes):
+ if type_code == TypeCode.BAD_FORMAT:
+ ct_bad_format += 1
+ continue
+
+ if type_code == TypeCode.ANONYMIZED:
+ ct_anonymized += 1
+ continue
+ elif type_code == TypeCode.REDACTED:
+ ct_redacted += 1
+ continue
+ elif type_code == TypeCode.BLOCKED_WORD:
+ ct_blocked_word += 1
+ continue
+ elif type_code == TypeCode.BLOCKED_MODEL:
+ ct_blocked_model += 1
+ continue
+ elif type_code == TypeCode.TOO_SHORT:
+ ct_too_short += 1
+ continue
+ elif type_code == TypeCode.TOO_FREQUENT:
+ ct_too_frequent += 1
+ continue
+
+ if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]:
+ ct_flagged += 1
+ continue
+
+ if type_code in [TypeCode.CORRECT]:
+ new_convs.append(conv)
+
+ if args.sample:
+ random.seed(42)
+ random.shuffle(new_convs)
+ new_convs = new_convs[: args.sample]
+
+ print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
+ print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
+ print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
+ print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}")
+ print(f"new_conv: {len(new_convs)}")
+
+ out_file = args.in_file.replace(".json", ".s1.json")
+ print(f"Output to {out_file}")
+ with open(out_file, "w") as fout:
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
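One detail worth calling out: `cc_converter = opencc.OpenCC("t2s")` normalizes Traditional Chinese to Simplified before blocked-word matching, so a single Simplified wordlist covers both scripts. A small self-contained example of that behavior (the word below is only a placeholder, not an actual blocked word):

```python
import opencc

cc = opencc.OpenCC("t2s")                       # Traditional -> Simplified
blocked = [cc.convert(w) for w in ["电脑"]]      # wordlist entry, already Simplified

msg = "我的電腦壞了"                              # Traditional form of the same word
normalized = cc.convert(msg.lower())
print(any(w in normalized for w in blocked))     # True
```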
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e368e92a1dcf260ecb5b175b77e85c6971809a3c
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py
@@ -0,0 +1,27 @@
+import argparse
+import json
+
+from tqdm import tqdm
+import numpy as np
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ args = parser.parse_args()
+
+ # Read conversations
+ convs = json.load(open(args.in_file))
+ print(f"#conv: {len(convs)}")
+
+ # Delete some fields
+ for c in convs:
+ del c["tstamp"]
+ del c["user_id"]
+
+ # Write
+ print(f"#out conv: {len(convs)}")
+ out_file = args.in_file.replace(".json", ".s2.json")
+ print(f"Output to {out_file}")
+ with open(out_file, "w") as fout:
+ json.dump(convs, fout, indent=2, ensure_ascii=False)
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md
new file mode 100644
index 0000000000000000000000000000000000000000..4c439731f6aee43bd29e1a65576c5ae04ff59cfa
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md
@@ -0,0 +1,23 @@
+```
+export BASE=clean_conv_20230809_100k_pii
+export SCALE=10
+
+# filter words
+python3 filter_bad_conv.py --in $BASE.json
+
+# Clean up some fields (e.g., timestamps)
+python3 final_post_processing.py --in $BASE.s1.json
+
+# upload to hf
+python3 upload_hf_dataset.py --in $BASE.s1.s2.json
+
+# Make another version with openai moderation tag
+python3 merge_oai_tag.py --in $BASE.s1.s2.json
+
+# Make visualizations
+python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
+
+# Copy figures
+scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" .
+```
+
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py
new file mode 100644
index 0000000000000000000000000000000000000000..18bef5f1962384d80f174aa22a7b6dcc867fe7c0
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py
@@ -0,0 +1,45 @@
+import argparse
+import json
+import time
+
+from tqdm import tqdm
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ parser.add_argument("--sample", type=int)
+ args = parser.parse_args()
+
+ tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json"
+ # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json"
+ in_file = args.in_file
+ tic = time.time()
+
+ # Load tags
+ print("Load tags...")
+ tag_data = json.load(open(tag_file))
+ tag_dict = {}
+ for c in tqdm(tag_data):
+ tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]]
+ print(f"elapsed: {time.time() - tic:.2f} s")
+
+ # Append to input_file
+ print("Load inputs...")
+ input_data = json.load(open(in_file))
+ for c in tqdm(input_data):
+ cid = c["conversation_id"]
+ if cid in tag_dict:
+ c["openai_moderation"] = tag_dict[cid]
+ else:
+ print(f"missing tag for conv {cid}")
+ exit()
+ print(f"elapsed: {time.time() - tic:.2f} s")
+
+ # Write output
+ print("Write outputs...")
+ out_file = in_file.replace(".json", ".with_tag.json")
+ print(f"Output to {out_file}")
+ with open(out_file, "w") as fout:
+ json.dump(input_data, fout, indent=2, ensure_ascii=False)
+ print(f"elapsed: {time.time() - tic:.2f} s")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5bae9fbad221c57eba8f2cf5b7eb2779a6f040a8
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh
@@ -0,0 +1,18 @@
+export BASE=clean_conv_20230809_1.5M_pii
+#export BASE=clean_conv_20230809_100k_pii
+export SCALE=1
+
+# Filter words
+python3 filter_bad_conv.py --in $BASE.json --sample 1000000
+
+# Clean up some fields (e.g., timestamps)
+python3 final_post_processing.py --in $BASE.s1.json
+
+# Upload to hf
+python3 upload_hf_dataset.py --in $BASE.s1.s2.json
+
+# Make another version with openai moderation tag
+python3 merge_oai_tag.py --in $BASE.s1.s2.json
+
+# Make visualizations
+python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b6da455fc7bf8af1ce473f80440bff280c9366e
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py
@@ -0,0 +1,32 @@
+"""
+Count the unique users in a battle log file.
+
+Usage:
+python3 -input in.json --number 1000
+"""
+
+import argparse
+import json
+import random
+
+K = 1000
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str)
+ parser.add_argument("--number", type=int, nargs="+")
+ args = parser.parse_args()
+
+ convs = json.load(open(args.input))
+ random.seed(42)
+ random.shuffle(convs)
+
+ for number in args.number:
+ new_convs = convs[:number]
+
+ output = args.input.replace(".json", f"_{number//K}k.json")
+ with open(output, "w") as fout:
+ json.dump(new_convs, fout, indent=2, ensure_ascii=False)
+
+ print(f"#in: {len(convs)}, #out: {len(new_convs)}")
+ print(f"Write to file: {output}")
diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d0fbdb59b4c7dc8385bef87a1bf0c8ea6e7401
--- /dev/null
+++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py
@@ -0,0 +1,17 @@
+"""
+Upload to huggingface.
+"""
+import argparse
+import json
+from datasets import Dataset, DatasetDict, load_dataset
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--in-file", type=str, required=True)
+ args = parser.parse_args()
+
+ objs = json.load(open(args.in_file))
+ print(f"#convs: {len(objs)}")
+ data = Dataset.from_list(objs)
+ data.push_to_hub("lmsys/lmsys-chat-1m", private=True)
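`load_dataset` is imported above but unused; one way to use it is a quick sanity check that the push succeeded. A sketch, assuming you are logged in via `huggingface-cli login` with access to the private repo (`train` is the split name `push_to_hub` produces by default):

```python
from datasets import load_dataset

# Read the private dataset back and confirm size and schema.
ds = load_dataset("lmsys/lmsys-chat-1m", split="train")
print(len(ds), ds.column_names)
```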
diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95f157c87e31297f7193ef9cecc21d5a90b8b01
--- /dev/null
+++ b/fastchat/serve/monitor/elo_analysis.py
@@ -0,0 +1,303 @@
+import argparse
+from collections import defaultdict
+import datetime
+import json
+import math
+import pickle
+from pytz import timezone
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+from tqdm import tqdm
+
+from fastchat.model.model_registry import get_model_info
+from fastchat.serve.monitor.basic_stats import get_log_files
+from fastchat.serve.monitor.clean_battle_data import clean_battle_data
+
+
+pd.options.display.float_format = "{:.2f}".format
+
+
+def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
+ rating = defaultdict(lambda: INIT_RATING)
+
+ for rd, model_a, model_b, winner in battles[
+ ["model_a", "model_b", "winner"]
+ ].itertuples():
+ ra = rating[model_a]
+ rb = rating[model_b]
+ ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
+ eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
+ if winner == "model_a":
+ sa = 1
+ elif winner == "model_b":
+ sa = 0
+ elif winner == "tie" or winner == "tie (bothbad)":
+ sa = 0.5
+ else:
+ raise Exception(f"unexpected vote {winner}")
+ rating[model_a] += K * (sa - ea)
+ rating[model_b] += K * (1 - sa - eb)
+
+ return dict(rating)
+
+
+def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
+ rows = []
+ for i in tqdm(range(num_round), desc="bootstrap"):
+ tmp_battles = battles.sample(frac=1.0, replace=True)
+ rows.append(func_compute_elo(tmp_battles))
+ df = pd.DataFrame(rows)
+ return df[df.median().sort_values(ascending=False).index]
+
+
+def get_median_elo_from_bootstrap(bootstrap_df):
+ median = dict(bootstrap_df.quantile(0.5))
+ median = {k: int(v + 0.5) for k, v in median.items()}
+ return median
+
+
+def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None):
+ # Times each model wins as Model A
+ a_win_ptbl = pd.pivot_table(
+ battles[battles["winner"] == "model_a"],
+ index="model_a",
+ columns="model_b",
+ aggfunc="size",
+ fill_value=0,
+ )
+
+ # Table counting times each model wins as Model B
+ b_win_ptbl = pd.pivot_table(
+ battles[battles["winner"] == "model_b"],
+ index="model_a",
+ columns="model_b",
+ aggfunc="size",
+ fill_value=0,
+ )
+
+ # Table counting number of A-B pairs
+ num_battles_ptbl = pd.pivot_table(
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
+ )
+
+ # Computing the proportion of wins for each model as A and as B
+ # against all other models
+ row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
+ num_battles_ptbl + num_battles_ptbl.T
+ )
+
+ if model_order is None:
+ prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
+ model_order = list(prop_wins.keys())
+
+ if limit_show_number is not None:
+ model_order = model_order[:limit_show_number]
+
+ # Arrange ordering according to proportion of wins
+ row_beats_col = row_beats_col_freq.loc[model_order, model_order]
+ return row_beats_col
+
+
+def visualize_leaderboard_table(rating):
+ models = list(rating.keys())
+ models.sort(key=lambda k: -rating[k])
+
+ emoji_dict = {
+ 1: "🥇",
+ 2: "🥈",
+ 3: "🥉",
+ }
+
+ md = ""
+ md += "| Rank | Model | Elo Rating | Description |\n"
+ md += "| --- | --- | --- | --- |\n"
+ for i, model in enumerate(models):
+ rank = i + 1
+ minfo = get_model_info(model)
+ emoji = emoji_dict.get(rank, "")
+ md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
+
+ return md
+
+
+def visualize_pairwise_win_fraction(battles, model_order):
+ row_beats_col = compute_pairwise_win_fraction(battles, model_order)
+ fig = px.imshow(
+ row_beats_col,
+ color_continuous_scale="RdBu",
+ text_auto=".2f",
+ height=700,
+ width=700,
+ )
+ fig.update_layout(
+ xaxis_title="Model B",
+ yaxis_title="Model A",
+ xaxis_side="top",
+ title_y=0.07,
+ title_x=0.5,
+ )
+ fig.update_traces(
+ hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}"
+ )
+
+ return fig
+
+
+def visualize_battle_count(battles, model_order):
+ ptbl = pd.pivot_table(
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
+ )
+ battle_counts = ptbl + ptbl.T
+ fig = px.imshow(
+ battle_counts.loc[model_order, model_order],
+ text_auto=True,
+ height=700,
+ width=700,
+ )
+ fig.update_layout(
+ xaxis_title="Model B",
+ yaxis_title="Model A",
+ xaxis_side="top",
+ title_y=0.07,
+ title_x=0.5,
+ )
+ fig.update_traces(
+ hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}"
+ )
+ return fig
+
+
+def visualize_average_win_rate(battles, limit_show_number):
+ row_beats_col_freq = compute_pairwise_win_fraction(
+ battles, None, limit_show_number=limit_show_number
+ )
+ fig = px.bar(
+ row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
+ text_auto=".2f",
+ height=500,
+ width=700,
+ )
+ fig.update_layout(
+ yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
+ )
+ return fig
+
+
+def visualize_bootstrap_elo_rating(df, limit_show_number):
+ bars = (
+ pd.DataFrame(
+ dict(
+ lower=df.quantile(0.025),
+ rating=df.quantile(0.5),
+ upper=df.quantile(0.975),
+ )
+ )
+ .reset_index(names="model")
+ .sort_values("rating", ascending=False)
+ )
+ bars = bars[:limit_show_number]
+ bars["error_y"] = bars["upper"] - bars["rating"]
+ bars["error_y_minus"] = bars["rating"] - bars["lower"]
+ bars["rating_rounded"] = np.round(bars["rating"], 2)
+ fig = px.scatter(
+ bars,
+ x="model",
+ y="rating",
+ error_y="error_y",
+ error_y_minus="error_y_minus",
+ text="rating_rounded",
+ height=500,
+ width=700,
+ )
+ fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
+ return fig
+
+
+def report_elo_analysis_results(battles_json):
+ battles = pd.DataFrame(battles_json)
+ battles = battles.sort_values(ascending=True, by=["tstamp"])
+ # Only use anonymous votes
+ battles = battles[battles["anony"]].reset_index(drop=True)
+ battles_no_ties = battles[~battles["winner"].str.contains("tie")]
+
+ # Online update
+ elo_rating_online = compute_elo(battles)
+
+ # Bootstrap
+ bootstrap_df = get_bootstrap_result(battles, compute_elo)
+ elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
+ model_order = list(elo_rating_median.keys())
+ model_order.sort(key=lambda k: -elo_rating_median[k])
+
+ limit_show_number = 25 # limit show number to make plots smaller
+ model_order = model_order[:limit_show_number]
+
+ # Plots
+ leaderboard_table = visualize_leaderboard_table(elo_rating_median)
+ win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
+ battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
+ average_win_rate_bar = visualize_average_win_rate(
+ battles_no_ties, limit_show_number
+ )
+ bootstrap_elo_rating = visualize_bootstrap_elo_rating(
+ bootstrap_df, limit_show_number
+ )
+
+ last_updated_tstamp = battles["tstamp"].max()
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ return {
+ "elo_rating_online": elo_rating_online,
+ "elo_rating_median": elo_rating_median,
+ "leaderboard_table": leaderboard_table,
+ "win_fraction_heatmap": win_fraction_heatmap,
+ "battle_count_heatmap": battle_count_heatmap,
+ "average_win_rate_bar": average_win_rate_bar,
+ "bootstrap_elo_rating": bootstrap_elo_rating,
+ "last_updated_datetime": last_updated_datetime,
+ "last_updated_tstamp": last_updated_tstamp,
+ }
+
+
+def pretty_print_elo_rating(rating):
+ model_order = list(rating.keys())
+ model_order.sort(key=lambda k: -rating[k])
+ for i, model in enumerate(model_order):
+ print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--clean-battle-file", type=str)
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ np.random.seed(42)
+
+ if args.clean_battle_file:
+ # Read data from a cleaned battle files
+ battles = pd.read_json(args.clean_battle_file)
+ else:
+ # Read data from all log files
+ log_files = get_log_files(args.max_num_files)
+ battles = clean_battle_data(log_files)
+
+ results = report_elo_analysis_results(battles)
+
+ print("# Online")
+ pretty_print_elo_rating(results["elo_rating_online"])
+ print("# Median")
+ pretty_print_elo_rating(results["elo_rating_median"])
+ print(f"last update : {results['last_updated_datetime']}")
+
+ last_updated_tstamp = results["last_updated_tstamp"]
+ cutoff_date = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y%m%d")
+
+ with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
+ pickle.dump(results, fout)
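For reference, here is the arithmetic `compute_elo` performs for a single battle with the default constants (K=4, SCALE=400, BASE=10, INIT_RATING=1000), for two models that both sit at the initial rating and `model_a` winning:

```python
K, SCALE, BASE = 4, 400, 10
ra = rb = 1000.0

ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))   # expected score of model_a = 0.5
eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))   # expected score of model_b = 0.5
sa = 1                                       # "winner" == "model_a"

ra += K * (sa - ea)                          # 1000 + 4 * 0.5 = 1002.0
rb += K * (1 - sa - eb)                      # 1000 - 4 * 0.5 =  998.0
print(ra, rb)                                # 1002.0 998.0
```

A tie (`sa = 0.5`) leaves two equally rated models unchanged, which is why the small K mostly matters when ratings have already drifted apart.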
diff --git a/fastchat/serve/monitor/inspect_conv.py b/fastchat/serve/monitor/inspect_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a680a419bd9d11d0db85afbc21c0063a2ae36df7
--- /dev/null
+++ b/fastchat/serve/monitor/inspect_conv.py
@@ -0,0 +1,87 @@
+import argparse
+import code
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+import pandas as pd
+from tqdm import tqdm
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in [4, 5]:
+ for day in range(1, 32):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ num_servers = 14
+ filenames = []
+ for d in dates:
+ for i in range(num_servers):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def pretty_print_conversation(messages):
+ for role, msg in messages:
+ print(f"[[{role}]]: {msg}")
+
+
+def inspect_convs(log_files):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+
+ if "states" not in row:
+ continue
+ if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]:
+ continue
+
+ model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
+ if row["type"] == "leftvote":
+ winner, loser = model_names[0], model_names[1]
+ winner_conv, loser_conv = row["states"][0], row["states"][1]
+ elif row["type"] == "rightvote":
+ loser, winner = model_names[0], model_names[1]
+ loser_conv, winner_conv = row["states"][0], row["states"][1]
+
+ if loser == "bard" and winner == "vicuna-13b":
+ print("=" * 20)
+ print(f"Winner: {winner}")
+ pretty_print_conversation(winner_conv["messages"])
+ print(f"Loser: {loser}")
+ pretty_print_conversation(loser_conv["messages"])
+ print("=" * 20)
+ input()
+
+ # if row["type"] == "bothbad_vote" and "gpt-4" in model_names:
+ # print("=" * 20)
+ # print(f"Model A: {model_names[0]}")
+ # pretty_print_conversation(row["states"][0]["messages"])
+ # print(f"Model B: {model_names[1]}")
+ # pretty_print_conversation(row["states"][1]["messages"])
+ # print("=" * 20)
+ # input()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ inspect_convs(log_files)
diff --git a/fastchat/serve/monitor/intersect_conv_file.py b/fastchat/serve/monitor/intersect_conv_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eadd7cd57510ecbbd23798d55b079c69aac1a12
--- /dev/null
+++ b/fastchat/serve/monitor/intersect_conv_file.py
@@ -0,0 +1,25 @@
+"""
+Take the intersection of two conversation files.
+
+Usage: python3 -m fastchat.serve.monitor.intersect_conv_file --input input.json --conv-id conv_id_file.json --out-file intersect.json
+"""
+
+import argparse
+import json
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str, required=True)
+ parser.add_argument("--conv-id", type=str, required=True)
+ parser.add_argument("--out-file", type=str, default="intersect.json")
+ args = parser.parse_args()
+
+ conv_id_objs = json.load(open(args.conv_id, "r"))
+ conv_ids = set(x["conversation_id"] for x in conv_id_objs)
+
+ objs = json.load(open(args.input, "r"))
+ after_objs = [x for x in objs if x["conversation_id"] in conv_ids]
+
+ print(f"#in: {len(objs)}, #out: {len(after_objs)}")
+ json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False)
diff --git a/fastchat/serve/monitor/leaderboard_csv_to_html.py b/fastchat/serve/monitor/leaderboard_csv_to_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad52e7b2b6e234ed33a51d516e9d682addd1e0eb
--- /dev/null
+++ b/fastchat/serve/monitor/leaderboard_csv_to_html.py
@@ -0,0 +1,51 @@
+"""
+Convert a leaderboard csv file to html table used in the blog.
+
+Usage:
+python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv
+"""
+import argparse
+
+import numpy as np
+
+from fastchat.serve.monitor.monitor import load_leaderboard_table_csv
+
+
+def model_hyperlink(model_name, link):
+ return f'<a target="_blank" href="{link}"> {model_name} </a>'
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str, required=True)
+ args = parser.parse_args()
+
+ data = load_leaderboard_table_csv(args.input, add_hyperlink=False)
+ headers = [
+ "Model",
+ "MT-bench (score)",
+ "Arena Elo rating",
+ "MMLU",
+ "License",
+ ]
+ values = []
+ for item in data:
+ row = []
+ for key in headers:
+ value = item[key]
+ row.append(value)
+ row[0] = model_hyperlink(item["Model"], item["Link"])
+ values.append(row)
+ values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
+
+ for value in values:
+ row = ""
+ for x in value:
+ try:
+ if np.isnan(x):
+ x = "-"
+ except TypeError:
+ pass
+ row += f" {x} | "
+ row += "
"
+ print(row)
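The try/except inside the row loop is there because the cells mix types: numeric columns use NaN for missing entries, while Model and License are strings, and `np.isnan` raises TypeError on a string. A quick demonstration:

```python
import numpy as np

for x in ["vicuna-13b", np.nan, 6.39]:
    try:
        if np.isnan(x):
            x = "-"
    except TypeError:
        pass                 # non-numeric cells (e.g. the model link) pass through unchanged
    print(x)                 # vicuna-13b, then -, then 6.39
```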
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..580a2c866ab77e92c1eab61c74ec8b96ce3d30ee
--- /dev/null
+++ b/fastchat/serve/monitor/monitor.py
@@ -0,0 +1,313 @@
+"""
+Live monitor of the website statistics and leaderboard.
+
+Dependency:
+sudo apt install pkg-config libicu-dev
+pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
+"""
+
+import argparse
+import ast
+import pickle
+import os
+import threading
+import time
+
+import gradio as gr
+import numpy as np
+
+from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files
+from fastchat.serve.monitor.clean_battle_data import clean_battle_data
+from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results
+from fastchat.utils import build_logger, get_window_url_params_js
+
+
+notebook_url = "https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing"
+
+
+basic_component_values = [None] * 6
+leader_component_values = [None] * 5
+
+
+def make_leaderboard_md(elo_results):
+ leaderboard_md = f"""
+# 🏆 Chatbot Arena Leaderboard
+| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
+
+This leaderboard is based on the following three benchmarks.
+- [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
+- [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
+- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
+
+💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
+"""
+ return leaderboard_md
+
+
+def make_leaderboard_md_live(elo_results):
+ leaderboard_md = f"""
+# Leaderboard
+Last updated: {elo_results["last_updated_datetime"]}
+{elo_results["leaderboard_table"]}
+"""
+ return leaderboard_md
+
+
+def update_elo_components(max_num_files, elo_results_file):
+ log_files = get_log_files(max_num_files)
+
+ # Leaderboard
+ if elo_results_file is None: # Do live update
+ battles = clean_battle_data(log_files, [])
+ elo_results = report_elo_analysis_results(battles)
+
+ leader_component_values[0] = make_leaderboard_md_live(elo_results)
+ leader_component_values[1] = elo_results["win_fraction_heatmap"]
+ leader_component_values[2] = elo_results["battle_count_heatmap"]
+ leader_component_values[3] = elo_results["bootstrap_elo_rating"]
+ leader_component_values[4] = elo_results["average_win_rate_bar"]
+
+ # Basic stats
+ basic_stats = report_basic_stats(log_files)
+ md0 = f"Last updated: {basic_stats['last_updated_datetime']}"
+
+ md1 = "### Action Histogram\n"
+ md1 += basic_stats["action_hist_md"] + "\n"
+
+ md2 = "### Anony. Vote Histogram\n"
+ md2 += basic_stats["anony_vote_hist_md"] + "\n"
+
+ md3 = "### Model Call Histogram\n"
+ md3 += basic_stats["model_hist_md"] + "\n"
+
+ md4 = "### Model Call (Last 24 Hours)\n"
+ md4 += basic_stats["num_chats_last_24_hours"] + "\n"
+
+ basic_component_values[0] = md0
+ basic_component_values[1] = basic_stats["chat_dates_bar"]
+ basic_component_values[2] = md1
+ basic_component_values[3] = md2
+ basic_component_values[4] = md3
+ basic_component_values[5] = md4
+
+
+def update_worker(max_num_files, interval, elo_results_file):
+ while True:
+ tic = time.time()
+ update_elo_components(max_num_files, elo_results_file)
+ duration = time.time() - tic
+ print(f"update duration: {duration:.2f} s")
+ time.sleep(max(interval - duration, 0))
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+ return basic_component_values + leader_component_values
+
+
+def model_hyperlink(model_name, link):
+ return f'<a target="_blank" href="{link}">{model_name}</a>'
+
+
+def load_leaderboard_table_csv(filename, add_hyperlink=True):
+ lines = open(filename).readlines()
+ heads = [v.strip() for v in lines[0].split(",")]
+ rows = []
+ for i in range(1, len(lines)):
+ row = [v.strip() for v in lines[i].split(",")]
+ for j in range(len(heads)):
+ item = {}
+ for h, v in zip(heads, row):
+ if h == "Arena Elo rating":
+ if v != "-":
+ v = int(ast.literal_eval(v))
+ else:
+ v = np.nan
+ elif h == "MMLU":
+ if v != "-":
+ v = round(ast.literal_eval(v) * 100, 1)
+ else:
+ v = np.nan
+ elif h == "MT-bench (win rate %)":
+ if v != "-":
+ v = round(ast.literal_eval(v[:-1]), 1)
+ else:
+ v = np.nan
+ elif h == "MT-bench (score)":
+ if v != "-":
+ v = round(ast.literal_eval(v), 2)
+ else:
+ v = np.nan
+ item[h] = v
+ if add_hyperlink:
+ item["Model"] = model_hyperlink(item["Model"], item["Link"])
+ rows.append(item)
+
+ return rows
+
+
+def build_basic_stats_tab():
+ empty = "Loading ..."
+ basic_component_values[:] = [empty, None, empty, empty, empty, empty]
+
+ md0 = gr.Markdown(empty)
+ gr.Markdown("#### Figure 1: Number of model calls and votes")
+ plot_1 = gr.Plot(show_label=False)
+ with gr.Row():
+ with gr.Column():
+ md1 = gr.Markdown(empty)
+ with gr.Column():
+ md2 = gr.Markdown(empty)
+ with gr.Row():
+ with gr.Column():
+ md3 = gr.Markdown(empty)
+ with gr.Column():
+ md4 = gr.Markdown(empty)
+ return [md0, plot_1, md1, md2, md3, md4]
+
+
+def build_leaderboard_tab(elo_results_file, leaderboard_table_file):
+ if elo_results_file is None: # Do live update
+ md = "Loading ..."
+ p1 = p2 = p3 = p4 = None
+ else:
+ with open(elo_results_file, "rb") as fin:
+ elo_results = pickle.load(fin)
+
+ md = make_leaderboard_md(elo_results)
+ p1 = elo_results["win_fraction_heatmap"]
+ p2 = elo_results["battle_count_heatmap"]
+ p3 = elo_results["bootstrap_elo_rating"]
+ p4 = elo_results["average_win_rate_bar"]
+
+ md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
+
+ if leaderboard_table_file:
+ data = load_leaderboard_table_csv(leaderboard_table_file)
+ headers = [
+ "Model",
+ "Arena Elo rating",
+ "MT-bench (score)",
+ "MMLU",
+ "License",
+ ]
+ values = []
+ for item in data:
+ row = []
+ for key in headers:
+ value = item[key]
+ row.append(value)
+ values.append(row)
+ values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
+
+ headers[1] = "⭐ " + headers[1]
+ headers[2] = "📈 " + headers[2]
+
+ gr.Dataframe(
+ headers=headers,
+ datatype=["markdown", "number", "number", "number", "str"],
+ value=values,
+ elem_id="leaderboard_dataframe",
+ )
+ gr.Markdown(
+ """ ## Visit our [HF space](https://huggingface.co./spaces/lmsys/chatbot-arena-leaderboard) for more analysis!
+ If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
+ """,
+ elem_id="leaderboard_markdown",
+ )
+ else:
+ pass
+
+ leader_component_values[:] = [md, p1, p2, p3, p4]
+
+ """
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
+ )
+ plot_1 = gr.Plot(p1, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
+ )
+ plot_2 = gr.Plot(p2, show_label=False)
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
+ )
+ plot_3 = gr.Plot(p3, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
+ )
+ plot_4 = gr.Plot(p4, show_label=False)
+ """
+
+ from fastchat.serve.gradio_web_server import acknowledgment_md
+
+ gr.Markdown(acknowledgment_md)
+
+ # return [md_1, plot_1, plot_2, plot_3, plot_4]
+ return [md_1]
+
+
+def build_demo(elo_results_file, leaderboard_table_file):
+ from fastchat.serve.gradio_web_server import block_css
+
+ text_size = gr.themes.sizes.text_lg
+
+ with gr.Blocks(
+ title="Monitor",
+ theme=gr.themes.Base(text_size=text_size),
+ css=block_css,
+ ) as demo:
+ with gr.Tabs() as tabs:
+ with gr.Tab("Leaderboard", id=0):
+ leader_components = build_leaderboard_tab(
+ elo_results_file, leaderboard_table_file
+ )
+
+ with gr.Tab("Basic Stats", id=1):
+ basic_components = build_basic_stats_tab()
+
+ url_params = gr.JSON(visible=False)
+ demo.load(
+ load_demo,
+ [url_params],
+ basic_components + leader_components,
+ _js=get_window_url_params_js,
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument("--update-interval", type=int, default=300)
+ parser.add_argument("--max-num-files", type=int)
+ parser.add_argument("--elo-results-file", type=str)
+ parser.add_argument("--leaderboard-table-file", type=str)
+ args = parser.parse_args()
+
+ logger = build_logger("monitor", "monitor.log")
+ logger.info(f"args: {args}")
+
+ if args.elo_results_file is None: # Do live update
+ update_thread = threading.Thread(
+ target=update_worker,
+ args=(args.max_num_files, args.update_interval, args.elo_results_file),
+ )
+ update_thread.start()
+
+ demo = build_demo(args.elo_results_file, args.leaderboard_table_file)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+ server_name=args.host, server_port=args.port, share=args.share, max_threads=200
+ )
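For reference, `load_leaderboard_table_csv` expects a plain comma-separated file whose header names match the keys used above, with `-` marking a missing value; MMLU is stored as a fraction and multiplied by 100 by the loader, and because the loader splits on bare commas, fields must not contain commas themselves. A hypothetical example file (placeholder model names, links, and numbers, for illustration only):

```python
# Placeholder CSV in the shape the loader parses; none of these values are real scores.
example_csv = (
    "Model,MT-bench (score),MT-bench (win rate %),Arena Elo rating,MMLU,License,Link\n"
    "model-a,7.5,-,1100,0.65,Apache 2.0,https://example.com/model-a\n"
    "model-b,-,-,1050,-,Non-commercial,https://example.com/model-b\n"
)
with open("leaderboard_table_example.csv", "w") as fout:
    fout.write(example_csv)
```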
diff --git a/fastchat/serve/monitor/summarize_cluster.py b/fastchat/serve/monitor/summarize_cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d5fbcddc445b2434a35cefda5780f69a6cd8bca
--- /dev/null
+++ b/fastchat/serve/monitor/summarize_cluster.py
@@ -0,0 +1,76 @@
+"""
+Usage:
+python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100
+python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200
+"""
+import argparse
+import pickle
+
+from fastchat.llm_judge.common import (
+ chat_compeletion_openai,
+ chat_compeletion_openai_azure,
+ chat_compeletion_anthropic,
+)
+from fastchat.conversation import get_conv_template
+
+
+def truncate_string(s, l):
+ half = int(l // 2)
+ return s[:half] + s[-half:] if len(s) > l else s
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input-file", type=str, required=True)
+ parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
+ parser.add_argument("--num-prompts", type=int, default=100)
+ args = parser.parse_args()
+
+ model = args.model
+
+ cluster_infos = pickle.load(open(args.input_file, "rb"))
+ num_total_prompts = sum([x[0] for x in cluster_infos])
+
+ topics = []
+ percentages = []
+ for i, info in enumerate(cluster_infos):
+ num_samples, topk_prompts, random_prompts = info
+ percentage = num_samples / num_total_prompts
+ print(
+ f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%"
+ )
+ instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific."
+ split = int(args.num_prompts * 0.8)
+ prompt = "\n".join(
+ [truncate_string(x, l=200) for x in topk_prompts[:split]]
+ + [
+ truncate_string(x, l=200)
+ for x in random_prompts[: args.num_prompts - split]
+ ]
+ )
+ prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST."
+
+ if "azure-" in model:
+ template_name = "chatgpt"
+ completion_func = chat_compeletion_openai_azure
+ elif "gpt" in model:
+ template_name = "chatgpt"
+ completion_func = chat_compeletion_openai
+ elif "claude" in model:
+ template_name = "claude"
+ completion_func = chat_compeletion_anthropic
+
+ conv = get_conv_template(template_name)
+ conv.set_system_message(instruct)
+ conv.append_message(conv.roles[0], prompt)
+ conv.append_message(conv.roles[1], None)
+
+ topic = completion_func(model, conv, temperature=0, max_tokens=256)
+ print(topic)
+
+ topics.append(topic)
+ percentages.append(round(percentage, 6))
+
+ print()
+ print(f"topics: {topics}")
+ print(f"percentages: {percentages}")
diff --git a/fastchat/serve/monitor/tag_openai_moderation.py b/fastchat/serve/monitor/tag_openai_moderation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80703388b2a47bf372a09bbed81d7bede2bd412
--- /dev/null
+++ b/fastchat/serve/monitor/tag_openai_moderation.py
@@ -0,0 +1,63 @@
+"""
+Add OpenAI moderation API results to all conversations.
+"""
+import argparse
+from concurrent.futures import ThreadPoolExecutor
+import json
+import os
+import time
+
+import openai
+import requests
+from tqdm import tqdm
+
+
+API_MAX_RETRY = 16
+API_RETRY_SLEEP = 10
+API_ERROR_OUTPUT = "$ERROR$"
+
+
+def tag_moderation(text):
+ result = API_ERROR_OUTPUT
+ for _ in range(API_MAX_RETRY):
+ try:
+ result = openai.Moderation.create(input=text)["results"][0]
+ break
+ except openai.error.OpenAIError as e:
+ print(type(e), e)
+ time.sleep(API_RETRY_SLEEP)
+
+ return result
+
+
+def tag_openai_moderation(x):
+ conv = x["conversation_a"]
+ user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"])
+ result = tag_moderation(user_prompts)
+ x["openai_moderation"] = result
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=str, required=True)
+ parser.add_argument(
+ "--parallel", type=int, default=1, help="The number of concurrent API calls."
+ )
+ parser.add_argument("--first-n", type=int)
+ args = parser.parse_args()
+
+ battles = json.load(open(args.input))
+
+ if args.first_n:
+ battles = battles[: args.first_n]
+
+ with ThreadPoolExecutor(args.parallel) as executor:
+ for line in tqdm(
+ executor.map(tag_openai_moderation, battles), total=len(battles)
+ ):
+ pass
+
+ output = args.input.replace(".json", "_tagged.json")
+ with open(output, "w") as fout:
+ json.dump(battles, fout, indent=2, ensure_ascii=False)
+ print(f"Write cleaned data to {output}")
diff --git a/fastchat/serve/monitor/topic_clustering.py b/fastchat/serve/monitor/topic_clustering.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd15c6edca6666127f5ee441c41f55bb88878249
--- /dev/null
+++ b/fastchat/serve/monitor/topic_clustering.py
@@ -0,0 +1,267 @@
+"""
+
+Usage:
+python3 topic_clustering.py --in arena.json --english-only --min-length 32
+python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536
+"""
+import argparse
+import json
+import pickle
+import string
+import time
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.util import cos_sim
+from sklearn.cluster import KMeans, AgglomerativeClustering
+import torch
+from tqdm import tqdm
+
+from fastchat.utils import detect_language
+
+
+def remove_punctuation(input_string):
+ # Make a translator object to remove all punctuation
+ translator = str.maketrans("", "", string.punctuation)
+
+ # Use the translator object to remove the punctuation
+ no_punct = input_string.translate(translator)
+ return no_punct
+
+
+def read_texts(input_file, min_length, max_length, english_only):
+ visited = set()
+ texts = []
+
+ lines = json.load(open(input_file, "r"))
+
+ for l in tqdm(lines):
+ if "text" in l:
+ line_texts = [l["text"]]
+ elif "conversation_a" in l:
+ line_texts = [
+ x["content"] for x in l["conversation_a"] if x["role"] == "user"
+ ]
+ elif "conversation" in l:
+ line_texts = [
+ x["content"] for x in l["conversation"] if x["role"] == "user"
+ ]
+
+ for text in line_texts:
+ text = text.strip()
+
+ # Filter language
+ if english_only:
+ lang = detect_language(text)
+ if lang != "English":
+ continue
+
+ # Filter short or long prompts
+ if min_length:
+ if len(text) < min_length:
+ continue
+
+ if max_length:
+ if len(text) > max_length:
+ continue
+
+ # De-duplication
+ words = sorted([x.lower() for x in remove_punctuation(text).split(" ")])
+ words = "".join(words)
+ if words in visited:
+ continue
+
+ visited.add(words)
+ texts.append(text)
+ return np.array(texts)
+
+
+def get_embeddings(texts, model_name, batch_size):
+ model = SentenceTransformer(model_name)
+ embeddings = model.encode(
+ texts,
+ batch_size=batch_size,
+ show_progress_bar=True,
+ device="cuda",
+ convert_to_tensor=True,
+ )
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+ return embeddings.cpu()
+
+
+def run_k_means(embeddings, num_clusters):
+ np.random.seed(42)
+ clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
+ clustering_model.fit(embeddings.numpy())
+ centers = torch.from_numpy(clustering_model.cluster_centers_)
+ labels = torch.from_numpy(clustering_model.labels_)
+
+ # Sort labels
+ classes, counts = np.unique(labels, return_counts=True)
+ indices = np.argsort(counts)[::-1]
+ classes = [classes[i] for i in indices]
+ new_labels = torch.empty_like(labels)
+ new_centers = torch.empty_like(centers)
+ for i, c in enumerate(classes):
+ new_labels[labels == c] = i
+ new_centers[i] = centers[c]
+ return new_centers, new_labels
+
+
+def run_agg_cluster(embeddings, num_clusters):
+ np.random.seed(42)
+ clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
+ clustering_model.fit(embeddings)
+ labels = torch.from_numpy(clustering_model.labels_)
+
+ # Sort labels
+ classes, counts = np.unique(labels, return_counts=True)
+ indices = np.argsort(counts)[::-1]
+ classes = [classes[i] for i in indices]
+ new_labels = torch.empty_like(labels)
+ for i, c in enumerate(classes):
+ new_labels[labels == c] = i
+
+ # Compute centers
+ centers = []
+ for i in range(len(classes)):
+ centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
+ centers = torch.cat(centers)
+ return centers, new_labels
+
+
+def run_hdbscan_cluster(embeddings):
+ import hdbscan
+
+ np.random.seed(42)
+ clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
+ labels = torch.from_numpy(clusterer.fit_predict(embeddings))
+
+ # Sort labels
+ classes, counts = np.unique(labels, return_counts=True)
+ indices = np.argsort(counts)[::-1]
+ classes = [classes[i] for i in indices]
+ new_labels = torch.empty_like(labels)
+ for i, c in enumerate(classes):
+ new_labels[labels == c] = i
+
+ # Compute centers
+ centers = []
+ for i in range(len(classes)):
+ centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
+ centers = torch.cat(centers)
+ return centers, new_labels
+
+
+def get_topk_indices(centers, labels, embeddings, topk):
+ indices = []
+ arange = torch.arange(len(labels))
+ counts = torch.unique(labels, return_counts=True)[1]
+ topk = min(topk, counts.min().item())
+ for i in range(len(centers)):
+ tmp_indices = labels == i
+ tmp_arange = arange[tmp_indices]
+ tmp_embeddings = embeddings[tmp_indices]
+
+ scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
+ sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
+ indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0))
+ return torch.cat(indices)
+
+
+def print_topk(texts, labels, topk_indices, show_cut_off):
+ ret = ""
+ for k in range(len(topk_indices)):
+ num_samples = torch.sum(labels == k).item()
+
+ ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n"
+ for idx in topk_indices[k]:
+ ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n"
+ ret += "=" * 40 + "\n\n"
+
+ return ret
+
+
+def get_cluster_info(texts, labels, topk_indices):
+ np.random.seed(42)
+
+ cluster_info = []
+ for k in range(len(topk_indices)):
+ num_samples = torch.sum(labels == k).item()
+ topk_prompts = []
+ for idx in topk_indices[k]:
+ topk_prompts.append(texts[idx])
+ random_prompts = []
+ for idx in range(len(topk_indices)):
+ random_prompts.append(np.random.choice(texts))
+ cluster_info.append((num_samples, topk_prompts, random_prompts))
+
+ return cluster_info
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input-file", type=str, required=True)
+ parser.add_argument("--model", type=str, default="all-mpnet-base-v2")
+ # default="all-MiniLM-L12-v2")
+ # default="multi-qa-distilbert-cos-v1")
+ parser.add_argument("--batch-size", type=int, default=256)
+ parser.add_argument("--min-length", type=int)
+ parser.add_argument("--max-length", type=int)
+ parser.add_argument("--english-only", action="store_true")
+ parser.add_argument("--num-clusters", type=int, default=20)
+ parser.add_argument(
+ "--cluster-alg",
+ type=str,
+ choices=["kmeans", "aggcls", "HDBSCAN"],
+ default="kmeans",
+ )
+ parser.add_argument("--show-top-k", type=int, default=200)
+ parser.add_argument("--show-cut-off", type=int, default=512)
+ args = parser.parse_args()
+
+ num_clusters = args.num_clusters
+ show_top_k = args.show_top_k
+ show_cut_off = args.show_cut_off
+
+ texts = read_texts(
+ args.input_file, args.min_length, args.max_length, args.english_only
+ )
+ print(f"#text: {len(texts)}")
+
+ embeddings = get_embeddings(texts, args.model, args.batch_size)
+ if args.cluster_alg == "kmeans":
+ centers, labels = run_k_means(embeddings, num_clusters)
+ elif args.cluster_alg == "aggcls":
+ centers, labels = run_agg_cluster(embeddings, num_clusters)
+ elif args.cluster_alg == "HDBSCAN":
+ centers, labels = run_hdbscan_cluster(embeddings)
+ else:
+ raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}")
+
+ topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k)
+ topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
+ num_clusters = len(centers)
+
+ # Dump results
+ filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
+ print(topk_str)
+ with open(filename_prefix + "_topk.txt", "w") as fout:
+ fout.write(topk_str)
+
+ with open(filename_prefix + "_all.txt", "w") as fout:
+ for i in range(len(centers)):
+ tmp_indices = labels == i
+ tmp_embeddings = embeddings[tmp_indices]
+ tmp_texts = texts[tmp_indices]
+
+ scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
+ sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
+
+ for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]):
+ obj = {"cluster": i, "text": text, "sim": score.item()}
+ fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
+
+ cluster_info = get_cluster_info(texts, labels, topk_indices)
+ with open(filename_prefix + "_cluster.pkl", "wb") as fout:
+ pickle.dump(cluster_info, fout)
diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77ff444790e51c7b4933765d6079fa41fe55ab1
--- /dev/null
+++ b/fastchat/serve/multi_model_worker.py
@@ -0,0 +1,282 @@
+"""
+A multi-model worker that contains multiple sub-workers, one for each model. This
+supports running a list of models on the same machine so that they can
+(potentially) share the same background weights.
+
+Each model can have one or more model names.
+
+This multi-model worker assumes the models share some underlying weights and
+thus reports the combined queue lengths for health checks.
+
+We recommend using this with multiple Peft models (with `peft` in the name)
+where all Peft models are trained on the exact same base model.
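+
+Example launch (model paths and names below are illustrative):
+python3 -m fastchat.serve.multi_model_worker \
+    --model-path lmsys/vicuna-7b-v1.5 --model-names vicuna-7b-v1.5 \
+    --model-path lmsys/longchat-7b-16k --model-names longchat-7b-16k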
+"""
+import argparse
+import asyncio
+import dataclasses
+import logging
+import json
+import os
+import time
+from typing import List, Union
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import requests
+
+try:
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LlamaTokenizer,
+ AutoModel,
+ )
+except ImportError:
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LLaMATokenizer,
+ AutoModel,
+ )
+import torch
+import torch.nn.functional as F
+import uvicorn
+
+from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
+from fastchat.model.model_adapter import (
+ load_model,
+ add_model_args,
+ get_conversation_template,
+)
+from fastchat.model.model_chatglm import generate_stream_chatglm
+from fastchat.model.model_falcon import generate_stream_falcon
+from fastchat.model.model_codet5p import generate_stream_codet5p
+from fastchat.modules.gptq import GptqConfig
+from fastchat.modules.exllama import ExllamaConfig
+from fastchat.modules.xfastertransformer import XftConfig
+from fastchat.serve.inference import generate_stream
+from fastchat.serve.model_worker import ModelWorker, worker_id, logger
+from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
+
+
+# We store both the underlying workers and a mapping from their model names to
+# the worker instance. This makes it easy to fetch the appropriate worker for
+# each API call.
+workers = []
+worker_map = {}
+app = FastAPI()
+
+
+def release_worker_semaphore():
+ workers[0].semaphore.release()
+
+
+def acquire_worker_semaphore():
+ if workers[0].semaphore is None:
+ # Share the same semaphore for all workers because
+ # all workers share the same GPU.
+ semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency)
+ for w in workers:
+ w.semaphore = semaphore
+ return workers[0].semaphore.acquire()
+
+
+def create_background_tasks():
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(release_worker_semaphore)
+ return background_tasks
+
+
+# Note: for all the calls below, we make a hard assumption that the caller
+# includes the model name in the payload, otherwise we can't figure out which
+# underlying sub-worker to call.
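+# For example, a payload sent to /worker_generate_stream is expected to look
+# roughly like the following (model name and values are illustrative):
+#   {"model": "vicuna-7b-v1.5", "prompt": "...", "temperature": 0.7, "max_new_tokens": 256}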
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ worker = worker_map[params["model"]]
+ generator = worker.generate_stream_gate(params)
+ background_tasks = create_background_tasks()
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ worker = worker_map[params["model"]]
+ output = worker.generate_gate(params)
+ release_worker_semaphore()
+ return JSONResponse(output)
+
+
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ worker = worker_map[params["model"]]
+ embedding = worker.get_embeddings(params)
+ background_tasks = create_background_tasks()
+ return JSONResponse(content=embedding, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+ return {
+ "model_names": [m for w in workers for m in w.model_names],
+ "speed": 1,
+ "queue_length": sum([w.get_queue_length() for w in workers]),
+ }
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+ params = await request.json()
+ worker = worker_map[params["model"]]
+ return {"context_length": worker.context_len}
+
+
+def create_multi_model_worker():
+ # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST
+ # of the model args but we'll override one to have an append action that
+ # supports multiple values.
+ parser = argparse.ArgumentParser(conflict_handler="resolve")
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ add_model_args(parser)
+ # Override the model path to be repeated and align it with model names.
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default=[],
+ action="append",
+ help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.",
+ )
+ parser.add_argument(
+ "--model-names",
+ type=lambda s: s.split(","),
+ action="append",
+ help="One or more model names. Values must be aligned with `--model-path` values.",
+ )
+ parser.add_argument(
+ "--conv-template",
+ type=str,
+ default=None,
+ action="append",
+ help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.",
+ )
+ parser.add_argument("--limit-worker-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=2)
+ parser.add_argument("--no-register", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+ gptq_config = GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ )
+ if args.enable_exllama:
+ exllama_config = ExllamaConfig(
+ max_seq_len=args.exllama_max_seq_len,
+ gpu_split=args.exllama_gpu_split,
+ )
+ else:
+ exllama_config = None
+ if args.enable_xft:
+ xft_config = XftConfig(
+ max_seq_len=args.xft_max_seq_len,
+ data_type=args.xft_dtype,
+ )
+ if args.device != "cpu":
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
+ args.device = "cpu"
+ else:
+ xft_config = None
+
+ if args.model_names is None:
+ args.model_names = [[x.split("/")[-1]] for x in args.model_path]
+
+ if args.conv_template is None:
+ args.conv_template = [None] * len(args.model_path)
+ elif len(args.conv_template) == 1: # Repeat the same template
+ args.conv_template = args.conv_template * len(args.model_path)
+
+ # Launch all workers
+ workers = []
+ for conv_template, model_path, model_names in zip(
+ args.conv_template, args.model_path, args.model_names
+ ):
+ w = ModelWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ model_path,
+ model_names,
+ args.limit_worker_concurrency,
+ args.no_register,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ gptq_config=gptq_config,
+ exllama_config=exllama_config,
+ xft_config=xft_config,
+ stream_interval=args.stream_interval,
+ conv_template=conv_template,
+ )
+ workers.append(w)
+ for model_name in model_names:
+ worker_map[model_name] = w
+
+ # Register all models
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": workers[0].worker_addr,
+ "check_heart_beat": not args.no_register,
+ "worker_status": {
+ "model_names": [m for w in workers for m in w.model_names],
+ "speed": 1,
+ "queue_length": sum([w.get_queue_length() for w in workers]),
+ },
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ return args, workers
+
+
+if __name__ == "__main__":
+ args, workers = create_multi_model_worker()
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..c15527f4c12ee1c1261ae3ba005f6e1e530b483f
--- /dev/null
+++ b/fastchat/serve/openai_api_server.py
@@ -0,0 +1,879 @@
+"""A server that provides OpenAI-compatible RESTful APIs. It supports:
+
+- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
+- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
+- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
+
+Usage:
+python3 -m fastchat.serve.openai_api_server
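+
+Example request against a running server (port is the default; model name is illustrative):
+curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{"model": "vicuna-7b-v1.5", "messages": [{"role": "user", "content": "Hello"}]}'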
+"""
+import asyncio
+import argparse
+import json
+import logging
+import os
+from typing import Generator, Optional, Union, Dict, List, Any
+
+import aiohttp
+import fastapi
+from fastapi import Depends, HTTPException
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+import httpx
+from pydantic import BaseSettings
+import shortuuid
+import tiktoken
+import uvicorn
+
+from fastchat.constants import (
+ WORKER_API_TIMEOUT,
+ WORKER_API_EMBEDDING_BATCH_SIZE,
+ ErrorCode,
+)
+from fastchat.conversation import Conversation, SeparatorStyle
+from fastchat.protocol.openai_api_protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatMessage,
+ ChatCompletionResponseChoice,
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ DeltaMessage,
+ CompletionResponseStreamChoice,
+ CompletionStreamResponse,
+ EmbeddingsRequest,
+ EmbeddingsResponse,
+ ErrorResponse,
+ LogProbs,
+ ModelCard,
+ ModelList,
+ ModelPermission,
+ UsageInfo,
+)
+from fastchat.protocol.api_protocol import (
+ APIChatCompletionRequest,
+ APITokenCheckRequest,
+ APITokenCheckResponse,
+ APITokenCheckResponseItem,
+)
+
+logger = logging.getLogger(__name__)
+
+conv_template_map = {}
+
+fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)
+
+
+async def fetch_remote(url, pload=None, name=None):
+ async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
+ async with session.post(url, json=pload) as response:
+ chunks = []
+ if response.status != 200:
+ ret = {
+ "text": f"{response.reason}",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ return json.dumps(ret)
+
+ async for chunk, _ in response.content.iter_chunks():
+ chunks.append(chunk)
+ output = b"".join(chunks)
+
+ if name is not None:
+ res = json.loads(output)
+ if name != "":
+ res = res[name]
+ return res
+
+ return output
+
+
+class AppSettings(BaseSettings):
+ # The address of the model controller.
+ controller_address: str = "http://localhost:21001"
+ api_keys: Optional[List[str]] = None
+
+
+app_settings = AppSettings()
+app = fastapi.FastAPI()
+headers = {"User-Agent": "FastChat API Server"}
+get_bearer_token = HTTPBearer(auto_error=False)
+
+
+async def check_api_key(
+ auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+) -> str:
+ if app_settings.api_keys:
+ if auth is None or (token := auth.credentials) not in app_settings.api_keys:
+ raise HTTPException(
+ status_code=401,
+ detail={
+ "error": {
+ "message": "",
+ "type": "invalid_request_error",
+ "param": None,
+ "code": "invalid_api_key",
+ }
+ },
+ )
+ return token
+ else:
+ # api_keys not set; allow all
+ return None
+
+
+def create_error_response(code: int, message: str) -> JSONResponse:
+ return JSONResponse(
+ ErrorResponse(message=message, code=code).dict(), status_code=400
+ )
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc):
+ return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
+
+
+async def check_model(request) -> Optional[JSONResponse]:
+ controller_address = app_settings.controller_address
+ ret = None
+
+ models = await fetch_remote(controller_address + "/list_models", None, "models")
+ if request.model not in models:
+ ret = create_error_response(
+ ErrorCode.INVALID_MODEL,
+ f"Only {'&&'.join(models)} allowed now, your model {request.model}",
+ )
+ return ret
+
+
+async def check_length(request, prompt, max_tokens, worker_addr):
+ if (
+ not isinstance(max_tokens, int) or max_tokens <= 0
+ ): # the model worker does not support max_tokens=None
+ max_tokens = 1024 * 1024
+
+ context_len = await fetch_remote(
+ worker_addr + "/model_details", {"model": request.model}, "context_length"
+ )
+ token_num = await fetch_remote(
+ worker_addr + "/count_token",
+ {"model": request.model, "prompt": prompt},
+ "count",
+ )
+ return min(max_tokens, context_len - token_num)
+
+
+def check_requests(request) -> Optional[JSONResponse]:
+ # Check all params
+ if request.max_tokens is not None and request.max_tokens <= 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
+ )
+ if request.n is not None and request.n <= 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.n} is less than the minimum of 1 - 'n'",
+ )
+ if request.temperature is not None and request.temperature < 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.temperature} is less than the minimum of 0 - 'temperature'",
+ )
+ if request.temperature is not None and request.temperature > 2:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
+ )
+ if request.top_p is not None and request.top_p < 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.top_p} is less than the minimum of 0 - 'top_p'",
+ )
+ if request.top_p is not None and request.top_p > 1:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
+ )
+ if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
+ )
+ if request.stop is not None and (
+ not isinstance(request.stop, str) and not isinstance(request.stop, list)
+ ):
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.stop} is not valid under any of the given schemas - 'stop'",
+ )
+
+ return None
+
+
+def process_input(model_name, inp):
+ if isinstance(inp, str):
+ inp = [inp]
+ elif isinstance(inp, list):
+ if isinstance(inp[0], int):
+ decoding = tiktoken.model.encoding_for_model(model_name)
+ inp = [decoding.decode(inp)]
+ elif isinstance(inp[0], list):
+ decoding = tiktoken.model.encoding_for_model(model_name)
+ inp = [decoding.decode(text) for text in inp]
+
+ return inp
+
+
+def create_openai_logprobs(logprob_dict):
+ """Create OpenAI-style logprobs."""
+ return LogProbs(**logprob_dict) if logprob_dict is not None else None
+
+
+def _add_to_set(s, new_stop):
+ if not s:
+ return
+ if isinstance(s, str):
+ new_stop.add(s)
+ else:
+ new_stop.update(s)
+
+
+async def get_gen_params(
+ model_name: str,
+ worker_addr: str,
+ messages: Union[str, List[Dict[str, str]]],
+ *,
+ temperature: float,
+ top_p: float,
+ top_k: Optional[int],
+ presence_penalty: Optional[float],
+ frequency_penalty: Optional[float],
+ max_tokens: Optional[int],
+ echo: Optional[bool],
+ logprobs: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]],
+ best_of: Optional[int] = None,
+ use_beam_search: Optional[bool] = None,
+) -> Dict[str, Any]:
+ conv = await get_conv(model_name, worker_addr)
+ conv = Conversation(
+ name=conv["name"],
+ system_template=conv["system_template"],
+ system_message=conv["system_message"],
+ roles=conv["roles"],
+ messages=list(conv["messages"]), # prevent in-place modification
+ offset=conv["offset"],
+ sep_style=SeparatorStyle(conv["sep_style"]),
+ sep=conv["sep"],
+ sep2=conv["sep2"],
+ stop_str=conv["stop_str"],
+ stop_token_ids=conv["stop_token_ids"],
+ )
+
+ if isinstance(messages, str):
+ prompt = messages
+ else:
+ for message in messages:
+ msg_role = message["role"]
+ if msg_role == "system":
+ conv.set_system_message(message["content"])
+ elif msg_role == "user":
+ conv.append_message(conv.roles[0], message["content"])
+ elif msg_role == "assistant":
+ conv.append_message(conv.roles[1], message["content"])
+ else:
+ raise ValueError(f"Unknown role: {msg_role}")
+
+ # Add a blank message for the assistant.
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": temperature,
+ "logprobs": logprobs,
+ "top_p": top_p,
+ "top_k": top_k,
+ "presence_penalty": presence_penalty,
+ "frequency_penalty": frequency_penalty,
+ "max_new_tokens": max_tokens,
+ "echo": echo,
+ "stop_token_ids": conv.stop_token_ids,
+ }
+
+ if best_of is not None:
+ gen_params.update({"best_of": best_of})
+ if use_beam_search is not None:
+ gen_params.update({"use_beam_search": use_beam_search})
+
+ new_stop = set()
+ _add_to_set(stop, new_stop)
+ _add_to_set(conv.stop_str, new_stop)
+
+ gen_params["stop"] = list(new_stop)
+
+ logger.debug(f"==== request ====\n{gen_params}")
+ return gen_params
+
+
+async def get_worker_address(model_name: str) -> str:
+ """
+ Get worker address based on the requested model
+
+ :param model_name: The worker's model name
+ :return: Worker address from the controller
+ :raises: :class:`ValueError`: No available worker for requested model
+ """
+ controller_address = app_settings.controller_address
+ worker_addr = await fetch_remote(
+ controller_address + "/get_worker_address", {"model": model_name}, "address"
+ )
+
+ # No available worker
+ if worker_addr == "":
+ raise ValueError(f"No available worker for {model_name}")
+ logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
+ return worker_addr
+
+
+async def get_conv(model_name: str, worker_addr: str):
+ conv_template = conv_template_map.get((worker_addr, model_name))
+ if conv_template is None:
+ conv_template = await fetch_remote(
+ worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
+ )
+ conv_template_map[(worker_addr, model_name)] = conv_template
+ return conv_template
+
+
+@app.get("/v1/models", dependencies=[Depends(check_api_key)])
+async def show_available_models():
+ controller_address = app_settings.controller_address
+ ret = await fetch_remote(controller_address + "/refresh_all_workers")
+ models = await fetch_remote(controller_address + "/list_models", None, "models")
+
+ models.sort()
+ # TODO: return real model permission details
+ model_cards = []
+ for m in models:
+ model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
+ return ModelList(data=model_cards)
+
+
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def create_chat_completion(request: ChatCompletionRequest):
+ """Creates a completion for the chat message"""
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ worker_addr = await get_worker_address(request.model)
+
+ gen_params = await get_gen_params(
+ request.model,
+ worker_addr,
+ request.messages,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ top_k=request.top_k,
+ presence_penalty=request.presence_penalty,
+ frequency_penalty=request.frequency_penalty,
+ max_tokens=request.max_tokens,
+ echo=False,
+ stop=request.stop,
+ )
+ gen_params["max_new_tokens"] = await check_length(
+ request,
+ gen_params["prompt"],
+ gen_params["max_new_tokens"],
+ worker_addr,
+ )
+
+ if request.stream:
+ generator = chat_completion_stream_generator(
+ request.model, gen_params, request.n, worker_addr
+ )
+ return StreamingResponse(generator, media_type="text/event-stream")
+
+ choices = []
+ chat_completions = []
+ for i in range(request.n):
+ content = asyncio.create_task(generate_completion(gen_params, worker_addr))
+ chat_completions.append(content)
+ try:
+ all_tasks = await asyncio.gather(*chat_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+ usage = UsageInfo()
+ for i, content in enumerate(all_tasks):
+ if content["error_code"] != 0:
+ return create_error_response(content["error_code"], content["text"])
+ choices.append(
+ ChatCompletionResponseChoice(
+ index=i,
+ message=ChatMessage(role="assistant", content=content["text"]),
+ finish_reason=content.get("finish_reason", "stop"),
+ )
+ )
+ if "usage" in content:
+ task_usage = UsageInfo.parse_obj(content["usage"])
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+
+
+async def chat_completion_stream_generator(
+ model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
+) -> Generator[str, Any, None]:
+ """
+ Event stream format:
+ https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
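+
+ Each yielded item is a single SSE "data:" line, for example (fields abbreviated,
+ values illustrative):
+ data: {"id": "chatcmpl-...", "choices": [{"index": 0, "delta": {"content": "Hi"}}], "model": "..."}
+ and the stream is terminated by a final "data: [DONE]" line.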
+ """
+ id = f"chatcmpl-{shortuuid.random()}"
+ finish_stream_events = []
+ for i in range(n):
+ # First chunk with role
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(role="assistant"),
+ finish_reason=None,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+ previous_text = ""
+ async for content in generate_completion_stream(gen_params, worker_addr):
+ if content["error_code"] != 0:
+ yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ decoded_unicode = content["text"].replace("\ufffd", "")
+ delta_text = decoded_unicode[len(previous_text) :]
+ previous_text = (
+ decoded_unicode
+ if len(decoded_unicode) > len(previous_text)
+ else previous_text
+ )
+
+ if len(delta_text) == 0:
+ delta_text = None
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=delta_text),
+ finish_reason=content.get("finish_reason", None),
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ if delta_text is None:
+ if content.get("finish_reason", None) is not None:
+ finish_stream_events.append(chunk)
+ continue
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
+ for finish_chunk in finish_stream_events:
+ yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+
+@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
+async def create_completion(request: CompletionRequest):
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ request.prompt = process_input(request.model, request.prompt)
+
+ worker_addr = await get_worker_address(request.model)
+ for text in request.prompt:
+ max_tokens = await check_length(request, text, request.max_tokens, worker_addr)
+ if isinstance(max_tokens, int) and max_tokens < request.max_tokens:
+ request.max_tokens = max_tokens
+
+ if request.stream:
+ generator = generate_completion_stream_generator(
+ request, request.n, worker_addr
+ )
+ return StreamingResponse(generator, media_type="text/event-stream")
+ else:
+ text_completions = []
+ for text in request.prompt:
+ gen_params = await get_gen_params(
+ request.model,
+ worker_addr,
+ text,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ top_k=request.top_k,
+ frequency_penalty=request.frequency_penalty,
+ presence_penalty=request.presence_penalty,
+ max_tokens=request.max_tokens,
+ logprobs=request.logprobs,
+ echo=request.echo,
+ stop=request.stop,
+ best_of=request.best_of,
+ use_beam_search=request.use_beam_search,
+ )
+ for i in range(request.n):
+ content = asyncio.create_task(
+ generate_completion(gen_params, worker_addr)
+ )
+ text_completions.append(content)
+
+ try:
+ all_tasks = await asyncio.gather(*text_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+
+ choices = []
+ usage = UsageInfo()
+ for i, content in enumerate(all_tasks):
+ if content["error_code"] != 0:
+ return create_error_response(content["error_code"], content["text"])
+ choices.append(
+ CompletionResponseChoice(
+ index=i,
+ text=content["text"],
+ logprobs=create_openai_logprobs(content.get("logprobs", None)),
+ finish_reason=content.get("finish_reason", "stop"),
+ )
+ )
+ task_usage = UsageInfo.parse_obj(content["usage"])
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return CompletionResponse(
+ model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
+ )
+
+
+async def generate_completion_stream_generator(
+ request: CompletionRequest, n: int, worker_addr: str
+):
+ model_name = request.model
+ id = f"cmpl-{shortuuid.random()}"
+ finish_stream_events = []
+ for text in request.prompt:
+ for i in range(n):
+ previous_text = ""
+ gen_params = await get_gen_params(
+ request.model,
+ worker_addr,
+ text,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ top_k=request.top_k,
+ presence_penalty=request.presence_penalty,
+ frequency_penalty=request.frequency_penalty,
+ max_tokens=request.max_tokens,
+ logprobs=request.logprobs,
+ echo=request.echo,
+ stop=request.stop,
+ )
+ async for content in generate_completion_stream(gen_params, worker_addr):
+ if content["error_code"] != 0:
+ yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ decoded_unicode = content["text"].replace("\ufffd", "")
+ delta_text = decoded_unicode[len(previous_text) :]
+ previous_text = (
+ decoded_unicode
+ if len(decoded_unicode) > len(previous_text)
+ else previous_text
+ )
+ # todo: index is not apparent
+ choice_data = CompletionResponseStreamChoice(
+ index=i,
+ text=delta_text,
+ logprobs=create_openai_logprobs(content.get("logprobs", None)),
+ finish_reason=content.get("finish_reason", None),
+ )
+ chunk = CompletionStreamResponse(
+ id=id,
+ object="text_completion",
+ choices=[choice_data],
+ model=model_name,
+ )
+ if len(delta_text) == 0:
+ if content.get("finish_reason", None) is not None:
+ finish_stream_events.append(chunk)
+ continue
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
+ for finish_chunk in finish_stream_events:
+ yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+
+async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
+ controller_address = app_settings.controller_address
+ async with httpx.AsyncClient() as client:
+ delimiter = b"\0"
+ async with client.stream(
+ "POST",
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=payload,
+ timeout=WORKER_API_TIMEOUT,
+ ) as response:
+ # content = await response.aread()
+ buffer = b""
+ async for raw_chunk in response.aiter_raw():
+ buffer += raw_chunk
+ while (chunk_end := buffer.find(delimiter)) >= 0:
+ chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :]
+ if not chunk:
+ continue
+ yield json.loads(chunk.decode())
+
+
+async def generate_completion(payload: Dict[str, Any], worker_addr: str):
+ return await fetch_remote(worker_addr + "/worker_generate", payload, "")
+
+
+@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
+ """Creates embeddings for the text"""
+ if request.model is None:
+ request.model = model_name
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ request.input = process_input(request.model, request.input)
+
+ data = []
+ token_num = 0
+ batch_size = WORKER_API_EMBEDDING_BATCH_SIZE
+ batches = [
+ request.input[i : min(i + batch_size, len(request.input))]
+ for i in range(0, len(request.input), batch_size)
+ ]
+ for num_batch, batch in enumerate(batches):
+ payload = {
+ "model": request.model,
+ "input": batch,
+ "encoding_format": request.encoding_format,
+ }
+ embedding = await get_embedding(payload)
+ if "error_code" in embedding and embedding["error_code"] != 0:
+ return create_error_response(embedding["error_code"], embedding["text"])
+ data += [
+ {
+ "object": "embedding",
+ "embedding": emb,
+ "index": num_batch * batch_size + i,
+ }
+ for i, emb in enumerate(embedding["embedding"])
+ ]
+ token_num += embedding["token_num"]
+ return EmbeddingsResponse(
+ data=data,
+ model=request.model,
+ usage=UsageInfo(
+ prompt_tokens=token_num,
+ total_tokens=token_num,
+ completion_tokens=None,
+ ),
+ ).dict(exclude_none=True)
+
+
+async def get_embedding(payload: Dict[str, Any]):
+ controller_address = app_settings.controller_address
+ model_name = payload["model"]
+ worker_addr = await get_worker_address(model_name)
+
+ embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
+ return json.loads(embedding)
+
+
+### GENERAL API - NOT OPENAI COMPATIBLE ###
+
+
+@app.post("/api/v1/token_check")
+async def count_tokens(request: APITokenCheckRequest):
+ """
+ Checks the token count for each prompt in the request list.
+ This is not part of the OpenAI API spec.
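+
+ An illustrative request body (model name is a placeholder):
+ {"prompts": [{"model": "vicuna-7b-v1.5", "prompt": "Hello", "max_tokens": 128}]}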
+ """
+ checkedList = []
+ for item in request.prompts:
+ worker_addr = await get_worker_address(item.model)
+
+ context_len = await fetch_remote(
+ worker_addr + "/model_details",
+ {"prompt": item.prompt, "model": item.model},
+ "context_length",
+ )
+
+ token_num = await fetch_remote(
+ worker_addr + "/count_token",
+ {"prompt": item.prompt, "model": item.model},
+ "count",
+ )
+
+ can_fit = True
+ if token_num + item.max_tokens > context_len:
+ can_fit = False
+
+ checkedList.append(
+ APITokenCheckResponseItem(
+ fits=can_fit, contextLength=context_len, tokenCount=token_num
+ )
+ )
+
+ return APITokenCheckResponse(prompts=checkedList)
+
+
+@app.post("/api/v1/chat/completions")
+async def create_chat_completion(request: APIChatCompletionRequest):
+ """Creates a completion for the chat message"""
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ worker_addr = await get_worker_address(request.model)
+
+ gen_params = await get_gen_params(
+ request.model,
+ worker_addr,
+ request.messages,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ top_k=request.top_k,
+ presence_penalty=request.presence_penalty,
+ frequency_penalty=request.frequency_penalty,
+ max_tokens=request.max_tokens,
+ echo=False,
+ stop=request.stop,
+ )
+
+ if request.repetition_penalty is not None:
+ gen_params["repetition_penalty"] = request.repetition_penalty
+
+ gen_params["max_new_tokens"] = await check_length(
+ request,
+ gen_params["prompt"],
+ gen_params["max_new_tokens"],
+ worker_addr,
+ )
+
+ if request.stream:
+ generator = chat_completion_stream_generator(
+ request.model, gen_params, request.n, worker_addr
+ )
+ return StreamingResponse(generator, media_type="text/event-stream")
+
+ choices = []
+ chat_completions = []
+ for i in range(request.n):
+ content = asyncio.create_task(generate_completion(gen_params, worker_addr))
+ chat_completions.append(content)
+ try:
+ all_tasks = await asyncio.gather(*chat_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+ usage = UsageInfo()
+ for i, content in enumerate(all_tasks):
+ if content["error_code"] != 0:
+ return create_error_response(content["error_code"], content["text"])
+ choices.append(
+ ChatCompletionResponseChoice(
+ index=i,
+ message=ChatMessage(role="assistant", content=content["text"]),
+ finish_reason=content.get("finish_reason", "stop"),
+ )
+ )
+ task_usage = UsageInfo.parse_obj(content["usage"])
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+
+
+### END GENERAL API - NOT OPENAI COMPATIBLE ###
+
+
+def create_openai_api_server():
+ parser = argparse.ArgumentParser(
+ description="FastChat ChatGPT-Compatible RESTful API server."
+ )
+ parser.add_argument("--host", type=str, default="localhost", help="host name")
+ parser.add_argument("--port", type=int, default=8000, help="port number")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument(
+ "--allow-credentials", action="store_true", help="allow credentials"
+ )
+ parser.add_argument(
+ "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
+ )
+ parser.add_argument(
+ "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
+ )
+ parser.add_argument(
+ "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
+ )
+ parser.add_argument(
+ "--api-keys",
+ type=lambda s: s.split(","),
+ help="Optional list of comma separated API keys",
+ )
+ parser.add_argument(
+ "--ssl",
+ action="store_true",
+ required=False,
+ default=False,
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
+ )
+ args = parser.parse_args()
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=args.allowed_origins,
+ allow_credentials=args.allow_credentials,
+ allow_methods=args.allowed_methods,
+ allow_headers=args.allowed_headers,
+ )
+ app_settings.controller_address = args.controller_address
+ app_settings.api_keys = args.api_keys
+
+ logger.info(f"args: {args}")
+ return args
+
+
+if __name__ == "__main__":
+ args = create_openai_api_server()
+ if args.ssl:
+ uvicorn.run(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level="info",
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
+ ssl_certfile=os.environ["SSL_CERTFILE"],
+ )
+ else:
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/register_worker.py b/fastchat/serve/register_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2
--- /dev/null
+++ b/fastchat/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/fastchat/serve/shutdown_serve.py b/fastchat/serve/shutdown_serve.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e2b704f0b65584c5be15ce14b40bc150bd6009
--- /dev/null
+++ b/fastchat/serve/shutdown_serve.py
@@ -0,0 +1,24 @@
+"""
+Usage:
+python shutdown_serve.py --down all
+options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers
+"""
+
+import argparse
+import os
+import subprocess
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--down", choices=["all", "controller", "model_worker", "openai_api_server"]
+)
+args = parser.parse_args()
+base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9"
+if args.down == "all":
+ shell_script = base_shell.format("")
+else:
+ serve = f".{args.down}"
+ shell_script = base_shell.format(serve)
+print(f"execute shell cmd: {shell_script}")
+subprocess.run(shell_script, shell=True, check=True)
+print(f"{args.down} has been shutdown!")
diff --git a/fastchat/serve/test_message.py b/fastchat/serve/test_message.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d83f0fe7a6c06b285dc517d7b32e236d7867d88
--- /dev/null
+++ b/fastchat/serve/test_message.py
@@ -0,0 +1,82 @@
+"""Send a test message."""
+import argparse
+import json
+
+import requests
+
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.conversation import get_conv_template
+
+def main():
+ model_name = args.model_name
+ conv_template = args.conv_template
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": model_name}
+ )
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ print(f"No available workers for {model_name}")
+ return
+
+ # conv = get_conversation_template(model_name)
+ conv = get_conv_template(conv_template)
+ conv.append_message(conv.roles[0], args.message)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ headers = {"User-Agent": "FastChat Client"}
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": args.temperature,
+ "max_new_tokens": args.max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ stream=True,
+ )
+
+ print(f"{conv.roles[0]}: {args.message}")
+ print(f"{conv.roles[1]}: ", end="")
+ prev = 0
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ output = data["text"].strip()
+ print(output[prev:], end="", flush=True)
+ prev = len(output)
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, required=True)
+ parser.add_argument("--conv-template", type=str, required=True)
+ parser.add_argument("--temperature", type=float, default=0.0)
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument(
+ "--message", type=str, default="Tell me a story with more than 1000 words."
+ )
+ args = parser.parse_args()
+
+ main()
diff --git a/fastchat/serve/test_throughput.py b/fastchat/serve/test_throughput.py
new file mode 100644
index 0000000000000000000000000000000000000000..3796a6e2a7cb53dc6921674fc4c488246e0b93c7
--- /dev/null
+++ b/fastchat/serve/test_throughput.py
@@ -0,0 +1,115 @@
+"""Benchmarking script to test the throughput of serving workers."""
+import argparse
+import json
+
+import requests
+import threading
+import time
+
+from fastchat.conversation import get_conv_template
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": args.model_name}
+ )
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = get_conv_template("vicuna_v1.1")
+ conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
+ prompt_template = conv.get_prompt()
+ prompts = [prompt_template for _ in range(args.n_thread)]
+
+ headers = {"User-Agent": "fastchat Client"}
+ ploads = [
+ {
+ "model": args.model_name,
+ "prompt": prompts[i],
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.0,
+ # "stop": conv.sep,
+ }
+ for i in range(len(prompts))
+ ]
+
+ def send_request(results, i):
+ if args.test_dispatch:
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": args.model_name}
+ )
+ thread_worker_addr = ret.json()["address"]
+ else:
+ thread_worker_addr = worker_addr
+ print(f"thread {i} goes to {thread_worker_addr}")
+ response = requests.post(
+ thread_worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=ploads[i],
+ stream=False,
+ )
+ k = list(
+ response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
+ )
+ # print(k)
+ response_new_words = json.loads(k[-2].decode("utf-8"))["text"]
+ error_code = json.loads(k[-2].decode("utf-8"))["error_code"]
+ # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}")
+ results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+
+ # use N threads to prompt the backend
+ tik = time.time()
+ threads = []
+ results = [None] * args.n_thread
+ for i in range(args.n_thread):
+ t = threading.Thread(target=send_request, args=(results, i))
+ t.start()
+ # time.sleep(0.5)
+ threads.append(t)
+
+ for t in threads:
+ t.join()
+
+ print(f"Time (POST): {time.time() - tik} s")
+ # n_words = 0
+ # for i, response in enumerate(results):
+ # # print(prompt[i].replace(conv.sep, "\n"), end="")
+ # # make sure the streaming finishes at EOS or stopping criteria
+ # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
+ # response_new_words = json.loads(k[-2].decode("utf-8"))["text"]
+ # # print(response_new_words)
+ # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+ n_words = sum(results)
+ time_seconds = time.time() - tik
+ print(
+ f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
+ f"throughput: {n_words / time_seconds} words/s."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="vicuna")
+ parser.add_argument("--max-new-tokens", type=int, default=2048)
+ parser.add_argument("--n-thread", type=int, default=8)
+ parser.add_argument("--test-dispatch", action="store_true")
+ args = parser.parse_args()
+
+ main()
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..6428d8b442f188add25b5db9ff04c86bc9bfc7fd
--- /dev/null
+++ b/fastchat/serve/vllm_worker.py
@@ -0,0 +1,271 @@
+"""
+A model worker that executes the model based on vLLM.
+
+See the documentation at docs/vllm_integration.md
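+
+Example launch (the model path shown is the parser default and is illustrative):
+python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5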
+"""
+
+import argparse
+import asyncio
+import json
+from typing import List
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import uvicorn
+from vllm import AsyncLLMEngine
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+
+from fastchat.serve.base_model_worker import BaseModelWorker
+from fastchat.serve.model_worker import (
+ logger,
+ worker_id,
+)
+from fastchat.utils import get_context_length
+
+
+app = FastAPI()
+
+
+class VLLMWorker(BaseModelWorker):
+ def __init__(
+ self,
+ controller_addr: str,
+ worker_addr: str,
+ worker_id: str,
+ model_path: str,
+ model_names: List[str],
+ limit_worker_concurrency: int,
+ no_register: bool,
+ llm_engine: AsyncLLMEngine,
+ conv_template: str,
+ ):
+ super().__init__(
+ controller_addr,
+ worker_addr,
+ worker_id,
+ model_path,
+ model_names,
+ limit_worker_concurrency,
+ conv_template,
+ )
+
+ logger.info(
+ f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
+ )
+ self.tokenizer = llm_engine.engine.tokenizer
+ self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)
+
+ if not no_register:
+ self.init_heart_beat()
+
+ async def generate_stream(self, params):
+ self.call_ct += 1
+
+ context = params.pop("prompt")
+ request_id = params.pop("request_id")
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = params.get("top_k", -1.0)
+ presence_penalty = float(params.get("presence_penalty", 0.0))
+ frequency_penalty = float(params.get("frequency_penalty", 0.0))
+ max_new_tokens = params.get("max_new_tokens", 256)
+ stop_str = params.get("stop", None)
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ if self.tokenizer.eos_token_id is not None:
+ stop_token_ids.append(self.tokenizer.eos_token_id)
+ echo = params.get("echo", True)
+ use_beam_search = params.get("use_beam_search", False)
+ best_of = params.get("best_of", None)
+
+ # Handle stop_str
+ stop = set()
+ if isinstance(stop_str, str) and stop_str != "":
+ stop.add(stop_str)
+ elif isinstance(stop_str, list) and stop_str != []:
+ stop.update(stop_str)
+
+ for tid in stop_token_ids:
+ if tid is not None:
+ stop.add(self.tokenizer.decode(tid))
+
+ # make sampling params in vllm
+ top_p = max(top_p, 1e-5)
+ if temperature <= 1e-5:
+ top_p = 1.0
+
+ sampling_params = SamplingParams(
+ n=1,
+ temperature=temperature,
+ top_p=top_p,
+ use_beam_search=use_beam_search,
+ stop=list(stop),
+ max_tokens=max_new_tokens,
+ top_k=top_k,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ best_of=best_of,
+ )
+ results_generator = engine.generate(context, sampling_params, request_id)
+
+ async for request_output in results_generator:
+ prompt = request_output.prompt
+ if echo:
+ text_outputs = [
+ prompt + output.text for output in request_output.outputs
+ ]
+ else:
+ text_outputs = [output.text for output in request_output.outputs]
+ text_outputs = " ".join(text_outputs)
+ # Note: usage is not supported yet
+ prompt_tokens = len(request_output.prompt_token_ids)
+ completion_tokens = sum(
+ len(output.token_ids) for output in request_output.outputs
+ )
+ ret = {
+ "text": text_outputs,
+ "error_code": 0,
+ "usage": {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": prompt_tokens + completion_tokens,
+ },
+ "cumulative_logprob": [
+ output.cumulative_logprob for output in request_output.outputs
+ ],
+ "finish_reason": request_output.outputs[0].finish_reason
+ if len(request_output.outputs) == 1
+ else [output.finish_reason for output in request_output.outputs],
+ }
+ yield (json.dumps(ret) + "\0").encode()
+
+ async def generate(self, params):
+ async for x in self.generate_stream(params):
+ pass
+ return json.loads(x[:-1].decode())
+
+
+def release_worker_semaphore():
+ worker.semaphore.release()
+
+
+def acquire_worker_semaphore():
+ if worker.semaphore is None:
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
+ return worker.semaphore.acquire()
+
+
+def create_background_tasks(request_id):
+ async def abort_request() -> None:
+ await engine.abort(request_id)
+
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(release_worker_semaphore)
+ background_tasks.add_task(abort_request)
+ return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ request_id = random_uuid()
+ params["request_id"] = request_id
+ generator = worker.generate_stream(params)
+ background_tasks = create_background_tasks(request_id)
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+ params = await request.json()
+ await acquire_worker_semaphore()
+ request_id = random_uuid()
+ params["request_id"] = request_id
+ output = await worker.generate(params)
+ release_worker_semaphore()
+ await engine.abort(request_id)
+ return JSONResponse(output)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+ return worker.get_status()
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+ params = await request.json()
+ return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+ return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+ return {"context_length": worker.context_len}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
+ parser.add_argument(
+ "--model-names",
+ type=lambda s: s.split(","),
+ help="Optional display comma separated names",
+ )
+ parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--num-gpus", type=int, default=1)
+ parser.add_argument(
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
+ )
+ parser.add_argument(
+ "--trust_remote_code",
+ action="store_false",
+ default=True,
+ help="Trust remote code (e.g., from HuggingFace) when"
+ "downloading the model and tokenizer.",
+ )
+ parser.add_argument(
+ "--gpu_memory_utilization",
+ type=float,
+ default=0.9,
+ help="The ratio (between 0 and 1) of GPU memory to"
+ "reserve for the model weights, activations, and KV cache. Higher"
+ "values will increase the KV cache size and thus improve the model's"
+ "throughput. However, if the value is too high, it may cause out-of-"
+ "memory (OOM) errors.",
+ )
+
+ parser = AsyncEngineArgs.add_cli_args(parser)
+ args = parser.parse_args()
+ if args.model_path:
+ args.model = args.model_path
+ if args.num_gpus > 1:
+ args.tensor_parallel_size = args.num_gpus
+
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+ worker = VLLMWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.model_path,
+ args.model_names,
+ args.limit_worker_concurrency,
+ args.no_register,
+ engine,
+ args.conv_template,
+ )
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/train/llama2_flash_attn_monkey_patch.py b/fastchat/train/llama2_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1fe51c91bd553f8fbe0c25e9c88fe1abc3542e7
--- /dev/null
+++ b/fastchat/train/llama2_flash_attn_monkey_patch.py
@@ -0,0 +1,238 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from flash_attn import __version__ as flash_attn_version
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (
+ flash_attn_func,
+ flash_attn_varlen_kvpacked_func,
+)
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaModel,
+ rotate_half,
+)
+
+
+def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
+ gather_indices = gather_indices.repeat(
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
+ )
+ bsz = gather_indices.shape[0]
+ cos, sin = (
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
+ for x in cos_sin
+ )
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
+ return q, k
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+ kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
+
+ q, k, v = (
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
+ for op, nh in (
+ (self.q_proj, self.num_heads),
+ (self.k_proj, kv_heads),
+ (self.v_proj, kv_heads),
+ )
+ )
+ # shape: (b, s, num_heads, head_dim)
+
+ kv_seq_len = k.shape[1]
+ past_kv_len = 0
+ if past_key_value is not None:
+ past_kv_len = past_key_value[0].shape[2]
+ kv_seq_len += past_kv_len
+
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
+ q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
+
+ if past_key_value is not None:
+ assert (
+ flash_attn_version >= "2.1.0"
+ ), "past_key_value support requires flash-attn >= 2.1.0"
+ # reuse k, v
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
+
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
+
+ if attention_mask is None:
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
+ bsz, q_len, -1
+ )
+ else:
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
+ # We could skip the concat and call unpad twice, but it seems better to call unpad only once.
+ kv, _, cu_k_lens, max_k = unpad_input(
+ torch.stack((k, v), dim=2), attention_mask
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q,
+ kv,
+ cu_q_lens,
+ cu_k_lens,
+ max_s,
+ max_k,
+ 0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ (
+ torch.full(
+ (input_shape[0], past_key_values_length),
+ True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
+ ),
+ dim=-1,
+ )
+
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # This uses the faster call when training with full samples
+
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+
+ LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ LlamaAttention.forward = forward
+
+
+def test():
+ from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
+ from transformers.models.llama.configuration_llama import LlamaConfig
+
+ config = LlamaConfig(
+ hidden_size=1024,
+ intermediate_size=128,
+ num_hidden_layers=1,
+ num_attention_heads=8,
+ max_position_embeddings=16,
+ )
+ device = torch.device("cuda")
+ model = LlamaModel(config)
+ attn = LlamaAttention(config).to(device).half()
+ bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
+ -1, seqlen
+ )
+
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
+ for i in range(4):
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
+ if i:
+ mask[0, -i:] = False
+ mask[1, :i] = False
+
+ lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
+ ref, _, _ = attn.forward(
+ hidden, attention_mask=lmask, position_ids=position_ids
+ )
+
+ fast, _, _ = fastchat_forward(
+ attn, hidden, attention_mask=mask, position_ids=position_ids
+ )
+
+ lmask = _prepare_decoder_attention_mask(
+ model, mask, hidden.shape[:2], hidden, 0
+ )
+ test, _, _ = forward(
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
+ )
+
+ print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
+ print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
+ print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
+ print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
+ print(f"allclose(fast, test) = {torch.allclose(fast, test)}")
+
+ with torch.no_grad():
+ # Also check that past_kv is handled properly
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
+ part_len = seqlen // 4
+ assert part_len * 4 == seqlen
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
+ mask[0, -2:] = False
+ lmask = _prepare_decoder_attention_mask(
+ model, mask, hidden.shape[:2], hidden, 0
+ )
+ oneshot, _, _ = forward(
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
+ )
+ parts = []
+ past_kv, past_kv_len = None, 0
+ for i in range(4):
+ start = part_len * i
+ end = start + part_len
+ hidden_part = hidden[:, start:end, ...]
+ lmask = _prepare_decoder_attention_mask(
+ model,
+ mask[:, start:end],
+ hidden_part.shape[:2],
+ hidden_part,
+ past_kv_len,
+ )
+ part, _, past_kv = forward(
+ attn,
+ hidden_part.clone(),
+ attention_mask=lmask,
+ position_ids=position_ids[:, start:end],
+ past_key_value=past_kv,
+ use_cache=True,
+ )
+ parts.append(part)
+ past_kv_len = past_kv[0].shape[2]
+
+ print(
+ f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
+ )
+ print(
+ f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
+ )
+
+
+if __name__ == "__main__":
+ test()
diff --git a/fastchat/train/llama_flash_attn_monkey_patch.py b/fastchat/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64aa8181726c26c9b3da355e17a6afb163f7796
--- /dev/null
+++ b/fastchat/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,107 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+from torch import nn
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/fastchat/train/llama_xformers_attn_monkey_patch.py b/fastchat/train/llama_xformers_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8351e41ccd4a64dca237bd8f8be0702b23989dc
--- /dev/null
+++ b/fastchat/train/llama_xformers_attn_monkey_patch.py
@@ -0,0 +1,129 @@
+"""
+Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
+"""
+
+import logging
+import math
+from typing import Optional, Tuple
+
+import torch
+import transformers.models.llama.modeling_llama
+from torch import nn
+
+try:
+ import xformers.ops
+except ImportError:
+ logging.error("xformers not found! Please install it before trying to use it.")
+
+
+def replace_llama_attn_with_xformers_attn():
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+
+
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # pylint: disable=duplicate-code
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ (
+ query_states,
+ key_states,
+ ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states, key_states, value_states, attn_bias=None
+ )
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xformers.ops.LowerTriangularMask(),
+ )
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
diff --git a/fastchat/train/train.py b/fastchat/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..95b5354a94eb369a9e2b2d433627096d6a009350
--- /dev/null
+++ b/fastchat/train/train.py
@@ -0,0 +1,301 @@
+# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
+#
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+import json
+import math
+import pathlib
+from typing import Dict, Optional, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import transformers
+from transformers import Trainer
+from transformers.trainer_pt_utils import LabelSmoother
+
+from fastchat.conversation import SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
+ eval_data_path: str = field(
+ default=None, metadata={"help": "Path to the evaluation data."}
+ )
+ lazy_preprocess: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+def trainer_save_model_safe(trainer: transformers.Trainer):
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import StateDictType, FullStateDictConfig
+
+ save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FSDP.state_dict_type(
+ trainer.model, StateDictType.FULL_STATE_DICT, save_policy
+ ):
+ trainer.save_model()
+
+
+def preprocess(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = get_conversation_template("vicuna")
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] + ": "
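+ # For the default Vicuna template this is typically " ASSISTANT: ", and turns are delimited by conv.sep2 ("</s>").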
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(turns):
+ if turn == "":
+ break
+ turn_len = len(tokenizer(turn).input_ids)
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ instruction_len -= 1
+
+ # Ignore the user instructions
+ target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
+ cur_len += turn_len
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ cur_len -= 1
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ rank0_print(tokenizer.decode(z))
+ exit()
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ rank0_print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" #turn = {len(turns) - 1}. (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
+ super(SupervisedDataset, self).__init__()
+
+ rank0_print("Formatting inputs...")
+ sources = [example["conversations"] for example in raw_data]
+ data_dict = preprocess(sources, tokenizer)
+
+ self.input_ids = data_dict["input_ids"]
+ self.labels = data_dict["labels"]
+ self.attention_mask = data_dict["attention_mask"]
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(
+ input_ids=self.input_ids[i],
+ labels=self.labels[i],
+ attention_mask=self.attention_mask[i],
+ )
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
+ super(LazySupervisedDataset, self).__init__()
+ self.tokenizer = tokenizer
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.raw_data = raw_data
+ self.cached_data_dict = {}
+
+ def __len__(self):
+ return len(self.raw_data)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ if i in self.cached_data_dict:
+ return self.cached_data_dict[i]
+
+ ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
+ ret = dict(
+ input_ids=ret["input_ids"][0],
+ labels=ret["labels"][0],
+ attention_mask=ret["attention_mask"][0],
+ )
+ self.cached_data_dict[i] = ret
+
+ return ret
+
+
+def make_supervised_data_module(
+ tokenizer: transformers.PreTrainedTokenizer, data_args
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ dataset_cls = (
+ LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
+ )
+ rank0_print("Loading data...")
+
+ train_json = json.load(open(data_args.data_path, "r"))
+ train_dataset = dataset_cls(train_json, tokenizer=tokenizer)
+
+ if data_args.eval_data_path:
+ eval_json = json.load(open(data_args.eval_data_path, "r"))
+ eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
+ else:
+ eval_dataset = None
+
+ return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+
+ # Set RoPE scaling factor
+ config = transformers.AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
+ if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
+ scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
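+ # e.g., model_max_length=8192 with a pretrained context of 4096 gives a linear RoPE scaling factor of 2.0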
+ config.use_cache = False
+
+ # Load model and tokenizer
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ cache_dir=training_args.cache_dir,
+ )
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ # Load data
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+
+ # Start trainer
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+
+ # Save model
+ model.config.use_cache = True
+ trainer.save_state()
+ trainer_save_model_safe(trainer)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/fastchat/train/train_baichuan.py b/fastchat/train/train_baichuan.py
new file mode 100644
index 0000000000000000000000000000000000000000..70c6488b5deabc8b13059258855c27e0e7c267ab
--- /dev/null
+++ b/fastchat/train/train_baichuan.py
@@ -0,0 +1,333 @@
+# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
+#
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+import json
+import math
+import jsonlines
+import pathlib
+from multiprocessing import Pool
+from typing import Dict, Optional, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import transformers
+from transformers import Trainer
+from transformers.trainer_pt_utils import LabelSmoother
+
+from fastchat.conversation import SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
+ lazy_preprocess: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def apply_prompt_template(sources, systems=None):
+ conv = get_conversation_template("vicuna")
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ if systems and systems[i]:
+ conv.set_system_message(systems[i])
+ prompt = conv.get_prompt()
+ conversations.append(prompt)
+ return conversations, conv
+
+
+def tokenize_conversations(conversations, tokenizer):
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+ return input_ids, targets
+
+
+def mask_targets(conversations, targets, tokenizer, conv):
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep2)
+ cur_len = 0
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(turns):
+ if turn == "":
+ break
+ turn_len = len(tokenizer(turn + conv.sep2).input_ids)
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
+ cur_len += turn_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ rank0_print(tokenizer.decode(z))
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ rank0_print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+ return targets
+
+
+def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, **kwargs) -> Dict:
+ systems = None if not kwargs else kwargs.get("systems", None)
+
+ # If the data volume is small, process it directly in the main process
+ if len(sources) <= 1000:
+ conversations, conv = apply_prompt_template(sources, systems)
+ input_ids, targets = tokenize_conversations(conversations, tokenizer)
+ targets = mask_targets(conversations, targets, tokenizer, conv)
+ else:  # If the data volume is large, offload the heavy steps to multiprocessing workers
+ with Pool() as p:
+ conversations, conv = p.apply_async(
+ apply_prompt_template, (sources, systems)
+ ).get()
+ input_ids, targets = p.apply_async(
+ tokenize_conversations, (conversations, tokenizer)
+ ).get()
+ targets = p.apply_async(
+ mask_targets, (conversations, targets, tokenizer, conv)
+ ).get()
+ p.close()
+ p.join()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
+ super(SupervisedDataset, self).__init__()
+
+ rank0_print("Formatting inputs...")
+ systems = [example.get("system", "") for example in raw_data]
+ sources = [example["conversations"] for example in raw_data]
+
+ data_dict = preprocess(sources, tokenizer, systems=systems)
+
+ self.input_ids = data_dict["input_ids"]
+ self.labels = data_dict["labels"]
+ self.attention_mask = data_dict["attention_mask"]
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(
+ input_ids=self.input_ids[i],
+ labels=self.labels[i],
+ attention_mask=self.attention_mask[i],
+ )
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
+ super(LazySupervisedDataset, self).__init__()
+ self.tokenizer = tokenizer
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.raw_data = raw_data
+ self.cached_data_dict = {}
+
+ def __len__(self):
+ return len(self.raw_data)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ if i in self.cached_data_dict:
+ return self.cached_data_dict[i]
+
+ ret = preprocess(
+ [self.raw_data[i]["conversations"]],
+ self.tokenizer,
+ systems=[self.raw_data[i].get("system", "")],
+ )
+ ret = dict(
+ input_ids=ret["input_ids"][0],
+ labels=ret["labels"][0],
+ attention_mask=ret["attention_mask"][0],
+ )
+ self.cached_data_dict[i] = ret
+
+ return ret
+
+
+def make_supervised_data_module(
+ tokenizer: transformers.PreTrainedTokenizer, data_args, train_ratio=0.98
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_ratio = min(train_ratio, 1.0)
+ dataset_cls = (
+ LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
+ )
+ rank0_print("Loading data...")
+ data_path = data_args.data_path
+ if data_path.endswith(".json"):
+ raw_data = json.load(open(data_path, "r"))
+ elif data_path.endswith(".jsonl"):
+ with jsonlines.open(data_path, mode="r") as reader:
+ raw_data = [item for item in reader]
+
+ # Split train/test
+ np.random.seed(0)
+ perm = np.random.permutation(len(raw_data))
+ split = int(len(perm) * train_ratio)
+ train_indices = perm[:split]
+ if train_ratio < 1:
+ eval_indices = perm[split:]
+ else:
+ # If train_ratio == 1, reuse the last 5% of the data as eval data so the trainer does not fail on an empty eval set
+ eval_indices = perm[-int(len(perm) * 0.05) :]
+ train_raw_data = [raw_data[i] for i in train_indices]
+ eval_raw_data = [raw_data[i] for i in eval_indices]
+ rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")
+
+ train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer)
+ eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ config = transformers.AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ trust_remote_code=True,
+ cache_dir=training_args.cache_dir,
+ )
+ # Set RoPE scaling factor
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
+ if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
+ scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+ config.use_cache = False
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ trust_remote_code=True,
+ cache_dir=training_args.cache_dir,
+ )
+ # Tie the weights
+ model.tie_weights()
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ trust_remote_code=True,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ # NOTE: if a token id exceeds the vocab size, training will fail; add the special tokens to the config and resize the embeddings accordingly.
+ tokenizer.pad_token = tokenizer.unk_token
+ print(f"tokens len: {len(tokenizer)}")
+ model.resize_token_embeddings(len(tokenizer))
+
+ data_module = make_supervised_data_module(
+ tokenizer=tokenizer, train_ratio=0.98, data_args=data_args
+ )
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py
new file mode 100644
index 0000000000000000000000000000000000000000..688c2f4fa33ec50b5daab43b62e984b2aced1c68
--- /dev/null
+++ b/fastchat/train/train_flant5.py
@@ -0,0 +1,436 @@
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import copy
+import os
+from dataclasses import dataclass, field
+import random
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence
+
+import torch
+import torch.distributed as dist
+
+import transformers
+from torch.utils.data import Dataset
+from transformers import Trainer, AddedToken
+
+from fastchat.model.model_adapter import get_conversation_template
+
+default_conversation = get_conversation_template("t5")
+
+# TODO: import and use code from ../data/dataset.py
+
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+DEFAULT_UNK_TOKEN = ""
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
+ lazy_preprocess: bool = False
+ num_data: int = -1
+ preprocessed_path: str = field(
+ default=None, metadata={"help": "Path to the preprocessed training data."}
+ )
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=2048,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ other_tokens,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ for new_token in other_tokens:
+ num_new_tokens += tokenizer.add_tokens(AddedToken(new_token, normalized=False))
+
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True
+ )
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True
+ )
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(
+ strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
+) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _form_qa(
+ q_list,
+ a_list,
+ tokenized_conversation,
+ tokenized_lens,
+ speakers,
+ header_len,
+ max_length,
+ eos_id,
+):
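+ # Build (context, answer) pairs: for every assistant ("gpt") turn, the answer is that turn's tokens (truncated to max_length, EOS-terminated) and the context is everything preceding it (also truncated to max_length, EOS-terminated).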
+ cur_idx = header_len
+ conv_len = len(tokenized_conversation)
+
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if cur_idx >= conv_len:
+ break
+ if speaker == "gpt":
+ # truncate answer if it is too long
+ content_a = None
+ if tokenized_len > max_length:
+ content_a = tokenized_conversation[cur_idx : cur_idx + max_length]
+ else:
+ content_a = tokenized_conversation[cur_idx : cur_idx + tokenized_len]
+ content_a.append(eos_id)
+ a_list.append(content_a)
+ content_q = None
+ if cur_idx >= max_length:
+ content_q = tokenized_conversation[cur_idx - max_length : cur_idx]
+ else:
+ content_q = tokenized_conversation[:cur_idx]
+ content_q.append(eos_id)
+ q_list.append(content_q)
+ # assert that the last token is actually an EOS for the answer
+ assert a_list[-1][-1] == eos_id, "Last Token is not EOS!"
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+
+ unknown_role = "unknown" # use default unknown role
+ roles = {
+ "human": default_conversation.roles[0], # human role
+ "gpt": default_conversation.roles[1], # gpt role
+ }
+
+ for i in range(len(source)):
+ sentence = source[i]
+ sentence_from = sentence["from"].lower()
+
+ # TODO(Dacheng): verify this is a good way to split sentences
+ if sentence_from == "human":
+ # if this is not the last sentence
+ if i != len(source) - 1:
+ next_sentence = source[i + 1]
+ sentence["value"] = (
+ BEGIN_SIGNAL
+ + roles.get(sentence_from, unknown_role)
+ + ": "
+ + sentence["value"]
+ + END_SIGNAL
+ + BEGIN_SIGNAL
+ + roles.get(next_sentence["from"].lower(), unknown_role)
+ + ": "
+ )
+ else:
+ # if human is the last speaker, it does not contribute to an answer
+ pass
+ else:
+ sentence["value"] = sentence["value"] + END_SIGNAL
+ if get_conversation:
+ conversation += sentence["value"]
+
+ return conversation
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add the signal '### ' at the beginning of each sentence, with the end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversations;
+ 4. Split each tokenized conversation into (context, answer) pairs, one per assistant turn (see _form_qa).
+ """
+ # add end signal and concatenate together
+ conversations = []
+ header = f"{default_conversation.system_message}\n\n"
+ for source in sources:
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # TODO(Dacheng): This is related to whether the dataset has been truncated..
+ # Assume we get long conversations, don't pad, don't return tensor
+ tokenized_conversations = tokenizer(conversations, max_length=None)["input_ids"]
+ q_list = []
+ a_list = []
+ # count for EOS length
+ header_len = _tokenize_fn([header], tokenizer)["input_ids_lens"][0] - 1
+ from tqdm import tqdm
+
+ for tokenized_conversation, source in tqdm(zip(tokenized_conversations, sources)):
+ tokenized_sentence = _tokenize_fn([s["value"] for s in source], tokenizer)
+ tokenized_lens = tokenized_sentence["input_ids_lens"]
+ tokenized_lens = [l - 1 for l in tokenized_lens]
+ speakers = [sentence["from"] for sentence in source]
+ ids = tokenized_sentence["input_ids"]
+ _form_qa(
+ q_list,
+ a_list,
+ tokenized_conversation,
+ tokenized_lens,
+ speakers,
+ header_len,
+ tokenizer.model_max_length,
+ tokenizer.eos_token_id,
+ )
+ return dict(input_ids=q_list, labels=a_list)
+
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ preprocessed_path,
+ num_data,
+ ):
+ super(SupervisedDataset, self).__init__()
+
+ # save to file
+ # Make sure only the first process is processing the dataset
+ if dist.get_rank() != 0:
+ dist.barrier()
+ self.preprocessed_path = preprocessed_path
+ if os.path.exists(self.preprocessed_path):
+ logging.warning("loading from preprocessed data")
+ with open(self.preprocessed_path, "r") as f:
+ data_dict = json.load(f)
+ if dist.get_rank() == 0:
+ dist.barrier()
+ else:
+ if not os.path.exists("preprocessed_data"):
+ os.mkdir("preprocessed_data")
+ assert dist.get_rank() == 0, "Only the first process should preprocess the dataset"
+ logging.warning("Loading data...")
+ list_data_dict = json.load(open(data_path, "r"))
+
+ logging.warning("Formatting inputs...")
+ sources = []
+
+ sources = [example["conversations"] for example in list_data_dict]
+
+ data_dict = preprocess(sources, tokenizer)
+ json_data_dict = json.dumps(data_dict)
+
+ # Remember to close file to avoid concurrent r/w
+ with open(self.preprocessed_path, "w") as f:
+ f.write(json_data_dict)
+
+ # Release barrier
+ dist.barrier()
+
+ if num_data != -1:
+ data_dict["input_ids"] = data_dict["input_ids"][:num_data]
+ data_dict["labels"] = data_dict["labels"][:num_data]
+
+ # Shuffle the data so that training on a partial subset still sees a diverse set of conversations
+ temp = list(zip(data_dict["input_ids"], data_dict["labels"]))
+ random.shuffle(temp)
+ res1, res2 = zip(*temp)
+ data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2)
+
+ # Dacheng: Get rid of short QA pairs
+ self.input_ids = copy.deepcopy(data_dict["input_ids"])
+ self.labels = copy.deepcopy(data_dict["labels"])
+ length_arr = defaultdict(int)
+ for idx, (input, label) in enumerate(
+ zip(data_dict["input_ids"], data_dict["labels"])
+ ):
+ length_arr[str(len(label) // 100)] += 1
+ if len(input) <= 5:
+ del_idx = self.input_ids.index(input)
+ self.input_ids.pop(del_idx)
+ self.labels.pop(del_idx)
+ if len(label) <= 5:
+ del_idx = self.labels.index(label)
+ self.input_ids.pop(del_idx)
+ self.labels.pop(del_idx)
+
+ for input, label in zip(self.input_ids, self.labels):
+ assert len(input) >= 5
+ assert len(label) >= 5
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple(
+ [
+ torch.as_tensor(instance[key], dtype=torch.int64)
+ for instance in instances
+ ]
+ for key in ("input_ids", "labels")
+ )
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(
+ labels, batch_first=True, padding_value=IGNORE_INDEX
+ )
+ ret = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+ torch.set_printoptions(profile="full")
+ return ret
+
+
+def make_supervised_data_module(
+ tokenizer: transformers.PreTrainedTokenizer, data_args
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ dataset_cls = SupervisedDataset
+ train_dataset = dataset_cls(
+ tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ preprocessed_path=data_args.preprocessed_path,
+ num_data=data_args.num_data,
+ )
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
+ )
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+ # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend
+ # a space before special tokens.
+ tokenizer = transformers.T5Tokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+ other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
+ tokenizer=tokenizer,
+ model=model,
+ )
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/fastchat/train/train_lora.py b/fastchat/train/train_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ecb47c29fbb21f6e57d9de1cba70002a886d152
--- /dev/null
+++ b/fastchat/train/train_lora.py
@@ -0,0 +1,222 @@
+# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
+
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+import logging
+import pathlib
+import typing
+import os
+
+from deepspeed import zero
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+import transformers
+from transformers import Trainer, BitsAndBytesConfig, deepspeed
+import torch
+
+from fastchat.train.train import (
+ DataArguments,
+ ModelArguments,
+ make_supervised_data_module,
+)
+
+from fastchat.train.llama_flash_attn_monkey_patch import (
+ replace_llama_attn_with_flash_attn,
+)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: typing.Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ flash_attn: bool = False
+
+
+@dataclass
+class LoraArguments:
+ lora_r: int = 8
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_target_modules: typing.List[str] = field(
+ default_factory=lambda: ["q_proj", "v_proj"]
+ )
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ q_lora: bool = False
+
+
+def maybe_zero_3(param):
+ if hasattr(param, "ds_id"):
+ assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
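+ # Keep all LoRA weights, plus only those bias terms that belong to LoRA-adapted modules.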
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias.items():
+ if k in lora_bias_names:
+ to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
+ return to_return
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
+ )
+ (
+ model_args,
+ data_args,
+ training_args,
+ lora_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ if training_args.flash_attn:
+ replace_llama_attn_with_flash_attn()
+
+ device_map = None
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ ddp = world_size != 1
+ if lora_args.q_lora:
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
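+ # Under DDP, pin the whole quantized model to this rank's local GPU; otherwise leave placement to the default.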
+ if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
+ logging.warning(
+ "FSDP and ZeRO3 are both currently incompatible with QLoRA."
+ )
+
+ compute_dtype = (
+ torch.float16
+ if training_args.fp16
+ else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ )
+
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ device_map=device_map,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=compute_dtype,
+ )
+ if lora_args.q_lora
+ else None,
+ )
+ lora_config = LoraConfig(
+ r=lora_args.lora_r,
+ lora_alpha=lora_args.lora_alpha,
+ target_modules=lora_args.lora_target_modules,
+ lora_dropout=lora_args.lora_dropout,
+ bias=lora_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+
+ if lora_args.q_lora:
+ model = prepare_model_for_kbit_training(
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+ if not ddp and torch.cuda.device_count() > 1:
+ # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+ model.is_parallelizable = True
+ model.model_parallel = True
+
+ model = get_peft_model(model, lora_config)
+ if training_args.flash_attn:
+ for name, module in model.named_modules():
+ if "norm" in name:
+ module = module.to(compute_dtype)
+ if "lm_head" in name or "embed_tokens" in name:
+ if hasattr(module, "weight"):
+ module = module.to(compute_dtype)
+ if training_args.deepspeed is not None and training_args.local_rank == 0:
+ model.print_trainable_parameters()
+
+ if training_args.gradient_checkpointing:
+ model.enable_input_require_grads()
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+
+ model.config.use_cache = False
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ # check if zero3 mode enabled
+ if deepspeed.is_deepspeed_zero3_enabled():
+ # use deepspeed engine internal function to gather state dict
+ # state_dict_zero3 contains whole parameters of base and lora adapters
+ # we will not extract lora parameters since peft save_pretrained will do that
+ # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
+ # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
+ state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
+ if training_args.local_rank == 0:
+ state_dict = state_dict_zero3
+ else:
+ # in other modes, use the original FastChat code to keep our changes minimal
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), lora_args.lora_bias
+ )
+
+ if training_args.local_rank == 0:
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/fastchat/train/train_lora_t5.py b/fastchat/train/train_lora_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..21abc92cb9e64482d4c3375e6321bd00641ac4f9
--- /dev/null
+++ b/fastchat/train/train_lora_t5.py
@@ -0,0 +1,226 @@
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import copy
+import os
+from dataclasses import dataclass, field
+import random
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+import torch
+import torch.distributed as dist
+
+
+from deepspeed import zero
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
+
+import transformers
+from torch.utils.data import Dataset
+from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed
+
+from fastchat.train.train_flant5 import (
+ smart_tokenizer_and_embedding_resize,
+ make_supervised_data_module,
+)
+
+from fastchat.train.train_lora import get_peft_state_maybe_zero_3
+
+from fastchat.model.model_adapter import get_conversation_template
+
+default_conversation = get_conversation_template("t5")
+
+# TODO: import and use code from ../data/dataset.py
+
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+DEFAULT_UNK_TOKEN = ""
+
+
+@dataclass
+class LoraArguments:
+ lora_r: int = 8
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ q_lora: bool = False
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
+ lazy_preprocess: bool = False
+ num_data: int = -1
+ preprocessed_path: str = field(
+ default=None, metadata={"help": "Path to the preprocessed training data."}
+ )
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=2048,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+
+
+def safe_save_model_for_hf_trainer(
+ trainer: transformers.Trainer, output_dir: str, state_dict: dict
+):
+ """Collects the state dict and dump to disk."""
+
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
+ )
+ (
+ model_args,
+ data_args,
+ training_args,
+ lora_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ device_map = None
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ ddp = world_size != 1
+ if lora_args.q_lora:
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+ if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
+ logging.warning(
+ "FSDP and ZeRO3 are both currently incompatible with QLoRA."
+ )
+
+ compute_dtype = (
+ torch.float16
+ if training_args.fp16
+ else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ )
+
+ model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ device_map=device_map,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=compute_dtype,
+ )
+ if lora_args.q_lora
+ else None,
+ )
+
+ lora_config = LoraConfig(
+ r=lora_args.lora_r,
+ lora_alpha=lora_args.lora_alpha,
+ target_modules=lora_args.lora_target_modules,
+ lora_dropout=lora_args.lora_dropout,
+ bias=lora_args.lora_bias,
+ task_type=TaskType.SEQ_2_SEQ_LM,
+ )
+
+ if lora_args.q_lora:
+ model = prepare_model_for_kbit_training(
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+ if not ddp and torch.cuda.device_count() > 1:
+ # Keep the Trainer from using its own DataParallel when more than one GPU is available.
+ model.is_parallelizable = True
+ model.model_parallel = True
+
+ model = get_peft_model(model, lora_config)
+ if training_args.deepspeed is not None and training_args.local_rank == 0:
+ model.print_trainable_parameters()
+
+ if training_args.gradient_checkpointing:
+ model.enable_input_require_grads()
+
+ # Dacheng: Note that we can only use T5Tokenizer here; other tokenizers
+ # prepend a space before special tokens.
+ tokenizer = transformers.T5Tokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+ other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
+ tokenizer=tokenizer,
+ model=model,
+ )
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+ # Check whether DeepSpeed ZeRO-3 is enabled.
+ if deepspeed.is_deepspeed_zero3_enabled():
+ # Use the DeepSpeed engine's internal function to gather the state dict.
+ # state_dict_zero3 contains the full parameters of both the base model and the LoRA adapters.
+ # We do not extract the LoRA parameters here since peft's save_pretrained will do that:
+ # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
+ # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
+ state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
+ if training_args.local_rank == 0:
+ state_dict = state_dict_zero3
+ else:
+ # In the other modes we use the original code from the FastChat team to keep our changes minimal.
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), lora_args.lora_bias
+ )
+
+ if training_args.local_rank == 0:
+ safe_save_model_for_hf_trainer(
+ trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
+ )
+
+
+if __name__ == "__main__":
+ train()
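
The adapter setup in `train()` can be exercised in isolation. The sketch below applies the same `LoraConfig` to a small public T5 checkpoint (`"t5-small"` is only an illustration; the script takes the real path from `--model_name_or_path`) and reports how few parameters remain trainable:

```python
# Minimal reproduction of the LoRA setup in train_lora_t5.py on a small public
# checkpoint; "t5-small" is used purely for illustration.
import transformers
from peft import LoraConfig, TaskType, get_peft_model

model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],  # T5 attention query/value projections
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA matrices are trainable
```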
diff --git a/fastchat/train/train_mem.py b/fastchat/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ce4913aae3ef2080470161724a4f7127abb11f0
--- /dev/null
+++ b/fastchat/train/train_mem.py
@@ -0,0 +1,13 @@
+# Make training more memory-efficient by monkey-patching the LLaMA model with FlashAttention.
+
+# Need to call this before importing transformers.
+from fastchat.train.llama2_flash_attn_monkey_patch import (
+ replace_llama_attn_with_flash_attn,
+)
+
+replace_llama_attn_with_flash_attn()
+
+from fastchat.train.train import train
+
+if __name__ == "__main__":
+ train()
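
The patch has to run before the training code (and with it the LLaMA modeling code in `transformers`) is imported, because it swaps the attention implementation on the class object from which every later model instance is built. A rough sketch of the mechanism, where `flash_forward` is a placeholder for the real replacement defined in `llama2_flash_attn_monkey_patch.py`:

```python
# Conceptual sketch of the monkey patch; flash_forward is a placeholder for the
# FlashAttention-backed implementation in the monkey-patch module.
import transformers.models.llama.modeling_llama as modeling_llama

def flash_forward(self, hidden_states, *args, **kwargs):
    ...  # attention computed with FlashAttention kernels

# Every LlamaAttention layer constructed after this assignment picks up the new
# forward method, which is why the patch must precede model creation.
modeling_llama.LlamaAttention.forward = flash_forward
```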
diff --git a/fastchat/train/train_xformers.py b/fastchat/train/train_xformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb2badd59140d72ff995ad4419fde2a2a697955
--- /dev/null
+++ b/fastchat/train/train_xformers.py
@@ -0,0 +1,13 @@
+# Make training more memory-efficient by monkey-patching the LLaMA model with xFormers attention.
+
+# Need to call this before importing transformers.
+from fastchat.train.llama_xformers_attn_monkey_patch import (
+ replace_llama_attn_with_xformers_attn,
+)
+
+replace_llama_attn_with_xformers_attn()
+
+from fastchat.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/fastchat/utils.py b/fastchat/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e3ba543a0a6b9d6e0ca0fd853bd30db9d205d3
--- /dev/null
+++ b/fastchat/utils.py
@@ -0,0 +1,349 @@
+"""
+Common utilities.
+"""
+from asyncio import AbstractEventLoop
+import json
+import logging
+import logging.handlers
+import os
+import platform
+import sys
+from typing import AsyncGenerator, Generator
+import warnings
+
+import requests
+
+from fastchat.constants import LOGDIR
+
+
+handler = None
+visited_loggers = set()
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ if sys.version_info >= (3, 9):
+ # Pass an explicit UTF-8 encoding (mainly needed on Windows);
+ # the encoding argument requires Python >= 3.9.
+ logging.basicConfig(level=logging.INFO, encoding="utf-8")
+ else:
+ if platform.system() == "Windows":
+ warnings.warn(
+ "If you are running on Windows, "
+ "we recommend you use Python >= 3.9 for UTF-8 encoding."
+ )
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # If LOGDIR is empty, do not write the log to a local file.
+ if LOGDIR != "":
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when="D", utc=True, encoding="utf-8"
+ )
+ handler.setFormatter(formatter)
+
+ for l in [stdout_logger, stderr_logger, logger]:
+ if l in visited_loggers:
+ continue
+ visited_loggers.add(l)
+ l.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ""
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ""
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == "\n":
+ encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
+ self.logger.log(self.log_level, encoded_message.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != "":
+ encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
+ self.logger.log(self.log_level, encoded_message.rstrip())
+ self.linebuf = ""
+
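+# Usage sketch (illustrative; build_logger above performs exactly this
+# redirection for its "stdout" and "stderr" loggers):
+#
+#   stdout_logger = logging.getLogger("stdout")
+#   stdout_logger.setLevel(logging.INFO)
+#   sys.stdout = StreamToLogger(stdout_logger, logging.INFO)
+#   print("hello")  # now recorded as an INFO log line instead of raw stdout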
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def get_gpu_memory(max_gpus=None):
+ """Get available memory for each GPU."""
+ import torch
+
+ gpu_memory = []
+ num_gpus = (
+ torch.cuda.device_count()
+ if max_gpus is None
+ else min(max_gpus, torch.cuda.device_count())
+ )
+
+ for gpu_id in range(num_gpus):
+ with torch.cuda.device(gpu_id):
+ device = torch.cuda.current_device()
+ gpu_properties = torch.cuda.get_device_properties(device)
+ total_memory = gpu_properties.total_memory / (1024**3)
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
+ available_memory = total_memory - allocated_memory
+ gpu_memory.append(available_memory)
+ return gpu_memory
+
+
+def oai_moderation(text):
+ """
+ Check whether the text is flagged by the OpenAI moderation API.
+ """
+ import openai
+
+ openai.api_base = "https://api.openai.com/v1"
+ openai.api_key = os.environ["OPENAI_API_KEY"]
+
+ MAX_RETRY = 3
+ for i in range(MAX_RETRY):
+ try:
+ res = openai.Moderation.create(input=text)
+ flagged = res["results"][0]["flagged"]
+ break
+ except (openai.error.OpenAIError, KeyError, IndexError) as e:
+ # Flag as True (conservative default) when the API call fails.
+ flagged = True
+ print(f"MODERATION ERROR: {e}\nInput: {text}")
+ return flagged
+
+
+def moderation_filter(text, model_list):
+ MODEL_KEYWORDS = ["claude"]
+
+ for keyword in MODEL_KEYWORDS:
+ for model in model_list:
+ if keyword in model and oai_moderation(text):
+ return True
+ return False
+
+
+def clean_flant5_ckpt(ckpt_path):
+ """
+ Flan-T5 trained with HF + FSDP saves corrupted weights for the shared embeddings.
+ Use this function to make sure the checkpoint can be loaded correctly.
+ """
+ import torch
+
+ index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+ index_json = json.load(open(index_file, "r"))
+
+ weightmap = index_json["weight_map"]
+
+ share_weight_file = weightmap["shared.weight"]
+ share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
+ "shared.weight"
+ ]
+
+ for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
+ weight_file = weightmap[weight_name]
+ weight = torch.load(os.path.join(ckpt_path, weight_file))
+ weight[weight_name] = share_weight
+ torch.save(weight, os.path.join(ckpt_path, weight_file))
+
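+# Usage sketch (illustrative; the checkpoint directory name is hypothetical).
+# The sharded weight files are patched in place, so the checkpoint can then be
+# loaded with a regular from_pretrained call:
+#
+#   clean_flant5_ckpt("checkpoints/flant5-xl-fsdp")
+#   model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
+#       "checkpoints/flant5-xl-fsdp"
+#   )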
+
+def pretty_print_semaphore(semaphore):
+ """Print a semaphore in better format."""
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+"""A javascript function to get url parameters for the gradio web server."""
+get_window_url_params_js = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log("url_params", url_params);
+ return url_params;
+ }
+"""
+
+
+get_window_url_params_with_tos_js = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log("url_params", url_params);
+
+ msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nThe service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license."
+ alert(msg);
+
+ return url_params;
+ }
+"""
+
+
+def iter_over_async(
+ async_gen: AsyncGenerator, event_loop: AbstractEventLoop
+) -> Generator:
+ """
+ Convert async generator to sync generator
+
+ :param async_gen: the AsyncGenerator to convert
+ :param event_loop: the event loop to run on
+ :returns: Sync generator
+ """
+ ait = async_gen.__aiter__()
+
+ async def get_next():
+ try:
+ obj = await ait.__anext__()
+ return False, obj
+ except StopAsyncIteration:
+ return True, None
+
+ while True:
+ done, obj = event_loop.run_until_complete(get_next())
+ if done:
+ break
+ yield obj
+
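+# Usage sketch (illustrative; `agen` stands for any async generator and
+# `handle` is a placeholder):
+#
+#   loop = asyncio.new_event_loop()
+#   for item in iter_over_async(agen, loop):
+#       handle(item)  # items are consumed synchronously, one at a time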
+
+def detect_language(text: str) -> str:
+ """Detect the langauge of a string."""
+ import polyglot # pip3 install polyglot pyicu pycld2
+ from polyglot.detect import Detector
+ from polyglot.detect.base import logger as polyglot_logger
+ import pycld2
+
+ polyglot_logger.setLevel("ERROR")
+
+ try:
+ lang_code = Detector(text).language.name
+ except (pycld2.error, polyglot.detect.base.UnknownLanguage):
+ lang_code = "unknown"
+ return lang_code
+
+
+def parse_gradio_auth_creds(filename: str):
+ """Parse a username:password file for gradio authorization."""
+ gradio_auth_creds = []
+ with open(filename, "r", encoding="utf8") as file:
+ for line in file.readlines():
+ gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()]
+ if gradio_auth_creds:
+ auth = [tuple(cred.split(":")) for cred in gradio_auth_creds]
+ else:
+ auth = None
+ return auth
+
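+# The credentials file holds comma- or newline-separated "user:password"
+# entries. For example (hypothetical values), a file containing
+#
+#   alice:wonderland,bob:builder
+#
+# parses to [("alice", "wonderland"), ("bob", "builder")].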
+
+def is_partial_stop(output: str, stop_str: str):
+ """Check whether the output contains a partial stop str."""
+ for i in range(0, min(len(output), len(stop_str))):
+ if stop_str.startswith(output[-i:]):
+ return True
+ return False
+
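+# Streaming example (illustrative): with stop_str "</s>", the partial output
+# "Hello </" must be held back because it could still grow into the stop string.
+#
+#   is_partial_stop("Hello </", "</s>")     # True  -> wait for more tokens
+#   is_partial_stop("Hello world", "</s>")  # False -> safe to emit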
+
+def run_cmd(cmd: str):
+ """Run a bash command."""
+ print(cmd)
+ return os.system(cmd)
+
+
+def is_sentence_complete(output: str):
+ """Check whether the output is a complete sentence."""
+ end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
+ return output.endswith(end_symbols)
+
+
+# Models don't use the same configuration key for determining the maximum
+# sequence length. Store them here so we can sanely check them.
+# NOTE: The ordering here is important. Some models have two of these and we
+# have a preference for which value gets used.
+SEQUENCE_LENGTH_KEYS = [
+ "max_sequence_length",
+ "seq_length",
+ "max_position_embeddings",
+ "max_seq_len",
+ "model_max_length",
+]
+
+
+def get_context_length(config):
+ """Get the context length of a model from a huggingface model config."""
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling:
+ rope_scaling_factor = config.rope_scaling["factor"]
+ else:
+ rope_scaling_factor = 1
+
+ for key in SEQUENCE_LENGTH_KEYS:
+ val = getattr(config, key, None)
+ if val is not None:
+ return int(rope_scaling_factor * val)
+ return 2048
+
+
+def str_to_torch_dtype(dtype: str):
+ import torch
+
+ if dtype is None:
+ return None
+ elif dtype == "float32":
+ return torch.float32
+ elif dtype == "float16":
+ return torch.float16
+ elif dtype == "bfloat16":
+ return torch.bfloat16
+ else:
+ raise ValueError(f"Unrecognized dtype: {dtype}")
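
As a concrete check of `get_context_length`: a config with `max_position_embeddings = 4096` and linear RoPE scaling of factor 2.0 yields a context window of 8192, while a config exposing none of the known keys falls back to 2048. A small sketch (the dummy config classes are made up for illustration):

```python
# Illustrative check of fastchat.utils.get_context_length; the dummy classes
# stand in for a transformers PretrainedConfig and only carry the attributes
# the function actually reads.
from fastchat.utils import get_context_length

class RopeScaledConfig:
    max_position_embeddings = 4096
    rope_scaling = {"type": "linear", "factor": 2.0}

class UnknownConfig:
    rope_scaling = None

assert get_context_length(RopeScaledConfig()) == 8192  # 2.0 * 4096
assert get_context_length(UnknownConfig()) == 2048     # fallback default
```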