Miaoran000 commited on
Commit
0544709
·
1 Parent(s): fd9d58a

update model_operations.py for new llms

Browse files
Files changed (1) hide show
  1. src/backend/model_operations.py +44 -16
src/backend/model_operations.py CHANGED
@@ -164,7 +164,7 @@ class SummaryGenerator:
164
  using_replicate_api = False
165
  replicate_api_models = ['snowflake', 'llama-3.1-405b']
166
  using_pipeline = False
167
- pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo']
168
 
169
  for replicate_api_model in replicate_api_models:
170
  if replicate_api_model in self.model_id.lower():
@@ -222,6 +222,7 @@ class SummaryGenerator:
222
  print(result)
223
  return result
224
 
 
225
  elif 'grok' in self.model_id.lower(): # xai
226
  XAI_API_KEY = os.getenv("XAI_API_KEY")
227
  client = OpenAI(
@@ -241,6 +242,7 @@ class SummaryGenerator:
241
  print(result)
242
  return result
243
 
 
244
  elif 'gemini' in self.model_id.lower():
245
  vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
246
  model = GenerativeModel(
@@ -249,7 +251,7 @@ class SummaryGenerator:
249
  )
250
  generation_config = {
251
  "temperature": 0,
252
- "max_output_tokens": 250
253
  }
254
  safety_settings = [
255
  SafetySetting(
@@ -277,6 +279,8 @@ class SummaryGenerator:
277
  result = response.text
278
  print(result)
279
  return result
 
 
280
  elif using_replicate_api:
281
  print("using replicate")
282
  if 'snowflake' in self.model_id.lower():
@@ -306,6 +310,7 @@ class SummaryGenerator:
306
  print(response)
307
  return response
308
 
 
309
  elif 'claude' in self.model_id.lower(): # using anthropic api
310
  print('using Anthropic API')
311
  client = anthropic.Anthropic()
@@ -331,6 +336,7 @@ class SummaryGenerator:
331
  print(result)
332
  return result
333
 
 
334
  elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
335
  co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
336
  response = co.chat(
@@ -345,6 +351,7 @@ class SummaryGenerator:
345
  print(result)
346
  return result
347
 
 
348
  elif 'mistral-large' in self.model_id.lower():
349
  api_key = os.environ["MISTRAL_API_KEY"]
350
  client = Mistral(api_key=api_key)
@@ -369,6 +376,7 @@ class SummaryGenerator:
369
  print(result)
370
  return result
371
 
 
372
  elif 'deepseek' in self.model_id.lower():
373
  client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
374
  response = client.chat.completions.create(
@@ -385,20 +393,21 @@ class SummaryGenerator:
385
  print(result)
386
  return result
387
 
388
- # Using HF API or download checkpoints
389
  elif self.local_model is None and self.local_pipeline is None:
390
  if using_pipeline:
391
  self.local_pipeline = pipeline(
392
  "text-generation",
393
  model=self.model_id,
394
  tokenizer=AutoTokenizer.from_pretrained(self.model_id),
395
- torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() else "auto",
396
  device_map="auto",
397
  trust_remote_code=True
398
  )
399
  else:
400
  if 'ragamuffin' in self.model_id.lower():
401
  self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
 
402
  else:
403
  self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
404
  print("Tokenizer loaded")
@@ -420,7 +429,12 @@ class SummaryGenerator:
420
  # self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
421
  # torch_dtype=torch.bfloat16, # forcing bfloat16 for now
422
  # attn_implementation="flash_attention_2")
423
-
 
 
 
 
 
424
  else:
425
  self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
426
  # print(self.local_model.device)
@@ -435,7 +449,7 @@ class SummaryGenerator:
435
  ]
436
  outputs = self.local_pipeline(
437
  messages,
438
- max_new_tokens=250,
439
  # return_full_text=False,
440
  do_sample=False
441
  )
@@ -445,6 +459,8 @@ class SummaryGenerator:
445
 
446
  elif self.local_model: # cannot call API. using local model / pipeline
447
  print('Using local model')
 
 
448
  if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
449
  messages=[
450
  # gemma-1.1, mistral-7b does not accept system role
@@ -478,29 +494,41 @@ class SummaryGenerator:
478
  {"role": "system", "content": system_prompt},
479
  {"role": "user", "content": user_prompt}
480
  ]
481
- prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
482
- # print(prompt)
483
- # print('-'*50)
484
- input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
 
 
 
 
 
 
 
485
  if 'granite' in self.model_id.lower():
486
  self.local_model.eval()
487
  outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
 
 
 
 
488
  else:
489
  with torch.no_grad():
490
  outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
491
  if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
492
  outputs = outputs[:, input_ids['input_ids'].shape[1]:]
493
- elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
494
  outputs = [
495
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
496
  ]
497
-
 
498
  if 'qwen2-vl' in self.model_id.lower():
499
  result = self.processor.batch_decode(
500
  outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
501
  )[0]
502
- # elif 'granite' in self.model_id.lower():
503
- # result = self.tokenizer.batch_decode(outputs)[0]
504
  else:
505
  result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
506
 
@@ -512,9 +540,9 @@ class SummaryGenerator:
512
  result = result.split(messages[-1]['content'])[1].strip()
513
  elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
514
  pass
 
 
515
  else:
516
- # print(prompt)
517
- # print('-'*50)
518
  result = result.replace(prompt.strip(), '')
519
 
520
  print(result)
 
164
  using_replicate_api = False
165
  replicate_api_models = ['snowflake', 'llama-3.1-405b']
166
  using_pipeline = False
167
+ pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3']
168
 
169
  for replicate_api_model in replicate_api_models:
170
  if replicate_api_model in self.model_id.lower():
 
222
  print(result)
223
  return result
224
 
225
+ # Using Grok API
226
  elif 'grok' in self.model_id.lower(): # xai
227
  XAI_API_KEY = os.getenv("XAI_API_KEY")
228
  client = OpenAI(
 
242
  print(result)
243
  return result
244
 
245
+ # Using Vertex AI API for Gemini models
246
  elif 'gemini' in self.model_id.lower():
247
  vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
248
  model = GenerativeModel(
 
251
  )
252
  generation_config = {
253
  "temperature": 0,
254
+ "max_output_tokens": 500
255
  }
256
  safety_settings = [
257
  SafetySetting(
 
279
  result = response.text
280
  print(result)
281
  return result
282
+
283
+ # Using Replicate API
284
  elif using_replicate_api:
285
  print("using replicate")
286
  if 'snowflake' in self.model_id.lower():
 
310
  print(response)
311
  return response
312
 
313
+ # Using Anthropic API for Claude models
314
  elif 'claude' in self.model_id.lower(): # using anthropic api
315
  print('using Anthropic API')
316
  client = anthropic.Anthropic()
 
336
  print(result)
337
  return result
338
 
339
+ # Using Cohere API
340
  elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
341
  co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
342
  response = co.chat(
 
351
  print(result)
352
  return result
353
 
354
+ # Using MistralAI API
355
  elif 'mistral-large' in self.model_id.lower():
356
  api_key = os.environ["MISTRAL_API_KEY"]
357
  client = Mistral(api_key=api_key)
 
376
  print(result)
377
  return result
378
 
379
+ # Using Deepseek API
380
  elif 'deepseek' in self.model_id.lower():
381
  client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
382
  response = client.chat.completions.create(
 
393
  print(result)
394
  return result
395
 
396
+ # Using HF pipeline or local checkpoints
397
  elif self.local_model is None and self.local_pipeline is None:
398
  if using_pipeline:
399
  self.local_pipeline = pipeline(
400
  "text-generation",
401
  model=self.model_id,
402
  tokenizer=AutoTokenizer.from_pretrained(self.model_id),
403
+ torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() or 'llama-3.3' in self.model_id.lower() else "auto",
404
  device_map="auto",
405
  trust_remote_code=True
406
  )
407
  else:
408
  if 'ragamuffin' in self.model_id.lower():
409
  self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
410
+
411
  else:
412
  self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
413
  print("Tokenizer loaded")
 
429
  # self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
430
  # torch_dtype=torch.bfloat16, # forcing bfloat16 for now
431
  # attn_implementation="flash_attention_2")
432
+ elif 'olmo' in self.model_id.lower():
433
+ self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id)#torch_dtype="auto"
434
+
435
+ elif 'qwq-' in self.model_id.lower():
436
+ self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype="auto", device_map="auto")
437
+
438
  else:
439
  self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
440
  # print(self.local_model.device)
 
449
  ]
450
  outputs = self.local_pipeline(
451
  messages,
452
+ max_new_tokens=256,
453
  # return_full_text=False,
454
  do_sample=False
455
  )
 
459
 
460
  elif self.local_model: # cannot call API. using local model / pipeline
461
  print('Using local model')
462
+
463
+ # Set appropriate prompt based on model document
464
  if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
465
  messages=[
466
  # gemma-1.1, mistral-7b does not accept system role
 
494
  {"role": "system", "content": system_prompt},
495
  {"role": "user", "content": user_prompt}
496
  ]
497
+ prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
498
+
499
+ # Tokenize inputs
500
+ if 'olmo' in self.model_id.lower():
501
+ input_ids = self.tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)#.to(self.device)
502
+ elif 'qwq' in self.model_id.lower():
503
+ input_ids = self.tokenizer([prompt], return_tensors="pt").to(self.device)
504
+ else:
505
+ input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
506
+
507
+ # Generate outputs
508
  if 'granite' in self.model_id.lower():
509
  self.local_model.eval()
510
  outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
511
+ elif 'olmo' in self.model_id.lower():
512
+ outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01)#top_k=50, top_p=0.95)
513
+ elif 'qwq' in self.model_id.lower():
514
+ outputs = self.local_model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
515
  else:
516
  with torch.no_grad():
517
  outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
518
  if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
519
  outputs = outputs[:, input_ids['input_ids'].shape[1]:]
520
+ elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower() or 'qwq-' in self.model_id.lower():
521
  outputs = [
522
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
523
  ]
524
+
525
+ # Decode outputs
526
  if 'qwen2-vl' in self.model_id.lower():
527
  result = self.processor.batch_decode(
528
  outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
529
  )[0]
530
+ elif 'olmo' in self.model_id.lower() or 'qwq' in self.model_id.lower():
531
+ result = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
532
  else:
533
  result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
534
 
 
540
  result = result.split(messages[-1]['content'])[1].strip()
541
  elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
542
  pass
543
+ elif 'olmo' in self.model_id.lower():
544
+ result = result.split("<|assistant|>\n")[-1]
545
  else:
 
 
546
  result = result.replace(prompt.strip(), '')
547
 
548
  print(result)