Husnain committed on
Commit
61ebb7f
1 Parent(s): 03fe2ac

⚡ [Enhance] Use nous-mixtral-8x7b as default model

Browse files
Files changed (1) hide show
  1. messagers/token_checker.py +46 -0
messagers/token_checker.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tclogger import logger
2
+ from transformers import AutoTokenizer
3
+
4
+ from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED
5
+
6
+
7
class TokenChecker:
    """Count prompt tokens for a model and validate them against its context limit.

    Unknown model names silently fall back to "nous-mixtral-8x7b"
    (the intended default per this commit).
    """

    # Some Hub repos are gated; fetch tokenizers from open mirrors instead.
    # Class-level constant so the dict is not rebuilt on every instantiation.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Resolve the model name and load its tokenizer.

        Args:
            input_str: The prompt text whose tokens will be counted.
            model: Short model key; must be a key of MODEL_MAP, else the
                default "nous-mixtral-8x7b" is used.
        """
        self.input_str = input_str

        if model in MODEL_MAP:
            self.model = model
        else:
            self.model = "nous-mixtral-8x7b"

        self.model_fullname = MODEL_MAP[self.model]

        # Prefer the ungated mirror when the canonical repo is gated.
        tokenizer_name = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def count_tokens(self) -> int:
        """Return the token count of the prompt, tokenizing only once.

        The result is cached: previously every call re-encoded the input,
        so check_token_limit() tokenized the same string up to three times.
        """
        if not hasattr(self, "_token_count"):
            self._token_count = len(self.tokenizer.encode(self.input_str))
            logger.note(f"Prompt Token Count: {self._token_count}")
        return self._token_count

    def get_token_limit(self) -> int:
        """Return the model's total context-window size in tokens."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Tokens still available after reserving TOKEN_RESERVED for the response."""
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Raise ValueError when the prompt leaves no room for the reserved output.

        Returns:
            True when the prompt fits within (limit - TOKEN_RESERVED).

        Raises:
            ValueError: if the prompt plus the reserved budget exceeds the limit.
        """
        if self.get_token_redundancy() <= 0:
            # Include the reserve in the message: the old text claimed
            # "count > limit" even when only count + reserve exceeded it.
            raise ValueError(
                f"Prompt exceeded token limit: "
                f"{self.count_tokens()} + {TOKEN_RESERVED} (reserved) "
                f"> {self.get_token_limit()}"
            )
        return True