JosephusCheung committed on
Commit
e64ca07
1 Parent(s): 5348a89

Upload evaluate_chatml_mmlu.py

Files changed (1)
eval/evaluate_chatml_mmlu.py +391 -0
eval/evaluate_chatml_mmlu.py ADDED
@@ -0,0 +1,391 @@
import argparse
import os
import re
from typing import List, Optional, Tuple, Union

import pandas as pd
import torch
from thefuzz import process
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from transformers.trainer_utils import set_seed

'''
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
mkdir -p data/mmlu
mv data.tar data/mmlu
cd data/mmlu; tar xf data.tar
cd ../../

pip install thefuzz
python eval/evaluate_chatml_mmlu.py -d data/mmlu/data/
'''

HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[List[Tuple[str, str]]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    # Build a ChatML prompt: system block first, then as many past turns as fit
    # within max_window_size tokens, then the current query and an opened
    # assistant block for the model to complete.
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        return f"{role}\n{content}", tokenizer.encode(
            role
        ) + nl_tokens + tokenizer.encode(content)

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

    raw_text = ""
    context_tokens = []

    # Walk the history from the most recent turn backwards, prepending whole
    # turns until the token budget would be exceeded.
    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str(
            "assistant", turn_response
        )
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
    # Append the current user query and open the assistant block.
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    return raw_text, context_tokens
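
# Illustrative: with system="You are a helpful assistant." (as chat() below
# passes) and empty history, make_context(tokenizer, "Hi") returns a raw_text of
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
#
# together with context_tokens, the token ids encoding the same prompt.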


def chat(
    model,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    append_history: bool = True,
) -> Tuple[str, HistoryType]:
    if history is None:
        history = []

    raw_text, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=6144,
        chat_format="chatml",
    )

    # Stop as soon as the model emits a ChatML control token.
    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    input_ids = torch.tensor([context_tokens]).to(model.device)
    outputs = model.generate(
        input_ids,
        stop_words_ids=stop_words_ids,
        return_dict_in_generate=False,
    )

    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format="chatml",
        verbose=False,
    )

    if append_history:
        history.append((query, response))

    return response, history
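
# Illustrative usage (model/tokenizer from load_models_tokenizer below):
#   response, history = chat(model, tokenizer, "What is 2+2?", history=None)
#   response, history = chat(model, tokenizer, "And times 3?", history=history)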


def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str = "chatml",
    verbose: bool = False,
    return_end_reason: bool = False,
) -> str:
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    return _decode_chatml(
        tokens,
        stop_words=[],
        eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
        tokenizer=tokenizer,
        raw_text_len=raw_text_len,
        context_length=context_length,
        verbose=verbose,
        return_end_reason=return_end_reason,
    )


def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    chat_format: str = "chatml",
):
    # Scan past the prompt for the first end-of-dialog token; the reply is
    # everything between the prompt and that token.
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    return trim_decode_tokens


def load_models_tokenizer(args):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    model.generation_config.do_sample = False  # use greedy decoding
    return model, tokenizer


def format_example(line):
    # Render one MMLU row as a zero-shot multiple-choice prompt.
    example = (
        'The following is a multiple-choice question. Please choose the most '
        'suitable one among A, B, C and D as the answer to this question.\n\n'
        + line['question'] + "\n"
    )
    for choice in choices:
        example += f'{choice}. {line[choice]}\n'
    return example
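
# Illustrative: for a row with question "What is the capital of France?" and
# choices A=London, B=Paris, C=Rome, D=Madrid (hypothetical values), the
# prompt reads:
#
#   The following is a multiple-choice question. Please choose the most
#   suitable one among A, B, C and D as the answer to this question.
#
#   What is the capital of France?
#   A. London
#   B. Paris
#   C. Rome
#   D. Madrid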


def process_before_extraction(gen, choice_dict):
    # Replace each choice's text with its letter in the generated sentence,
    # longest choice first so shorter choices cannot clobber longer ones.
    for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
        pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE)
        gen = pattern.sub(key, gen)
    return gen
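
# e.g. with choice_dict = {"A": "London", "B": "Paris"} (hypothetical values),
# "The capital of France is Paris." becomes "The capital of France is B."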


def extract_choice(gen, choice_list):
    # 1. "answer is A" | "choice is A" | "choose A"
    res = re.search(
        r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))"
        r"[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b",
        gen,
    )

    # 2. "A is correct" | "A is right"
    if res is None:
        res = re.search(
            r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?"
            r"(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b",
            gen,
        )

    # 3. A straight answer at the start of the reply: "A." | "A," | "A:" | "A"
    if res is None:
        res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)

    # 4. Otherwise, take the first standalone letter that appears.
    if res is None:
        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)

    # 5. Last resort: fuzzy-match the generation against the choice texts.
    if res is None:
        return choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
    return res.group(1)
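
# Illustrative: "The answer is B" hits pattern 1, "B is correct" pattern 2,
# "B. Paris" pattern 3, and "I'd go with B here" falls through to pattern 4.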


def extract_answer(response, row):
    gen = process_before_extraction(
        response, {choice: row[choice] for choice in choices}
    )
    pred = extract_choice(gen, [row[choice] for choice in choices])
    return pred


@torch.no_grad()
def eval_subject(
    model,
    tokenizer,
    subject_name,
    test_df,
    save_result_dir=None,
    overwrite=False,
    **kwargs,
):
    result_path = os.path.join(save_result_dir, f'{subject_name}_result.csv')
    if not overwrite and os.path.exists(result_path):
        print(f"{result_path} already exists, skipping!")
        # Re-score the cached predictions instead of re-querying the model.
        score = []
        for (_, datarow), (_, resultrow) in zip(
            test_df.iterrows(), pd.read_csv(result_path).iterrows()
        ):
            # pred = extract_answer(resultrow['model_response'], datarow)
            pred = resultrow['model_output']
            correct = 1 if pred == datarow['answer'] else 0
            score.append(correct)
        return score

    result = []
    responses = []
    score = []

    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        question = format_example(row)

        response, _ = chat(
            model,
            tokenizer,
            question,
            history=None,
        )
        print(question)
        print(response)
        pred = extract_answer(response, row)
        print(pred)
        print("======================")

        if 'answer' in row:
            correct = 1 if pred == row['answer'] else 0
            score.append(correct)
            if args.debug:
                print(f'{question} pred: {pred} ref: {row["answer"]}')
        result.append(pred)
        responses.append(response)

    if save_result_dir:
        test_df['model_output'] = result
        test_df['model_response'] = responses  # one raw response per row
        if score:
            test_df["correctness"] = score
        os.makedirs(save_result_dir, exist_ok=True)
        test_df.to_csv(
            os.path.join(save_result_dir, f'{subject_name}_result.csv'),
            encoding="utf-8",
            index=False,
        )

    return score
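
# eval_subject returns one 0/1 entry per question, e.g. [1, 0, 1, ...];
# cal_mmlu below aggregates these lists into per-category accuracies.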


def cal_mmlu(res):
    # Aggregate the per-subject 0/1 score lists into per-category and overall
    # accuracy.
    acc_sum_dict = dict()
    cnt_dict = dict()
    acc_sum = 0.
    cnt = 0

    for class_ in TASK_NAME_MAPPING.keys():
        acc_sum_dict[class_] = 0.
        cnt_dict[class_] = 0

        for tt in TASK_NAME_MAPPING[class_]:
            acc_sum += sum(res[tt])
            cnt += len(res[tt])

            acc_sum_dict[class_] += sum(res[tt])
            cnt_dict[class_] += len(res[tt])

    print('\n\n\n')
    for k in TASK_NAME_MAPPING.keys():
        if k in cnt_dict:
            print('%s ACC: %.2f' % (k, acc_sum_dict[k] * 100 / cnt_dict[k]))
    print('AVERAGE ACC: %.2f' % (acc_sum * 100 / cnt))
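
# Note: AVERAGE ACC is micro-averaged over all questions, not the mean of the
# per-category accuracies, so larger subjects weigh more heavily.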


def main(args):
    print("loading model weights")
    if args.checkpoint_path is not None:
        model, tokenizer = load_models_tokenizer(args)
    else:
        model, tokenizer = None, None
    print("model loaded")

    dev_result = {}
    for subject_name in tqdm(SUBJECTS):
        # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
        # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
        test_file_path = os.path.join(
            args.eval_data_path, 'test', f'{subject_name}_test.csv'
        )
        # val_df = pd.read_csv(val_file_path, names=['question', 'A', 'B', 'C', 'D', 'answer'])
        # dev_df = pd.read_csv(dev_file_path, names=['question', 'A', 'B', 'C', 'D', 'answer'])
        test_df = pd.read_csv(
            test_file_path, names=['question', 'A', 'B', 'C', 'D', 'answer']
        )

        score = eval_subject(
            model,
            tokenizer,
            subject_name,
            test_df,
            save_result_dir="outs_chat/mmlu_eval_result",
            overwrite=args.overwrite,
        )
        dev_result[subject_name] = score
    cal_mmlu(dev_result)


# The 57 MMLU subjects, grouped into four categories for reporting.
TASK_NAME_MAPPING = {
    'stem': [
        'abstract_algebra', 'anatomy', 'astronomy', 'college_biology',
        'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_physics', 'computer_security', 'conceptual_physics',
        'electrical_engineering', 'elementary_mathematics', 'high_school_biology',
        'high_school_chemistry', 'high_school_computer_science',
        'high_school_mathematics', 'high_school_physics', 'high_school_statistics',
        'machine_learning'],
    'Humanities': [
        'formal_logic', 'high_school_european_history', 'high_school_us_history',
        'high_school_world_history', 'international_law', 'jurisprudence',
        'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy',
        'prehistory', 'professional_law', 'world_religions'],
    'other': [
        'business_ethics', 'college_medicine', 'human_aging', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'nutrition',
        'professional_accounting', 'professional_medicine', 'virology',
        'global_facts', 'clinical_knowledge'],
    'social': [
        'econometrics', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_microeconomics', 'high_school_psychology', 'human_sexuality',
        'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy'],
}
SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
choices = ["A", "B", "C", "D"]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument('-c', '--checkpoint-path', type=str,
                        help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
    parser.add_argument('-s', '--seed', type=int, default=1234,
                        help='Random seed')

    # Extra arguments required for the evaluation task.
    group = parser.add_argument_group(title='Evaluation options')
    group.add_argument('-d', '--eval_data_path', type=str,
                       help='Path to eval data')
    group.add_argument('--debug', action='store_true', default=False,
                       help='Print debug info.')
    group.add_argument('--overwrite', action='store_true', default=False,
                       help='Overwrite existing results')

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)