import os
import time
from datetime import datetime
import logging
from pathlib import Path
import requests
import json
import numpy as np
import pandas as pd
import spacy
import litellm
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForTokenClassification,
    AutoConfig,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from peft import PeftModel
import torch
import cohere
from openai import OpenAI
from together import Together
import anthropic
import replicate
# import google.generativeai as genai
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, FinishReason
from mistralai import Mistral
from qwen_vl_utils import process_vision_info

import src.backend.util as util
import src.envs as envs

litellm.set_verbose = True

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load spacy model for word tokenization (used by SummaryGenerator._compute_avg_length)
nlp = spacy.load("en_core_web_sm")

os.environ["HUGGINGFACE_API_KEY"] = envs.TOKEN


class ModelLoadingException(Exception):
    """Exception raised for errors in loading a model.

    Attributes:
        model_id (str): The model identifier.
        revision (str): The model revision.
    """

    def __init__(self, model_id, revision, messages="Error initializing model"):
        self.model_id = model_id
        self.revision = revision
        super().__init__(f"{messages} id={model_id} revision={revision}")


class SummaryGenerator:
    """Generate summaries with a remote API model or a local Hugging Face model.

    Attributes:
        model_id (str): Model identifier; also used to pick the backend.
        model (str): huggingface/{model_id}
        api_base (str): https://api-inference.huggingface.co/models/{model_id}
        summaries_df (DataFrame): DataFrame to store generated summaries.
        revision (str): Model revision.
        device: Device that tokenized inputs are moved to for local models.
        avg_length (float): Average length (in alphabetic words) of valid summaries.
        answer_rate (float): Fraction of valid (non-empty) summaries.
        exceptions (list): Row indices whose generation failed with a
            non-retryable error.
        local_model: Lazily loaded local checkpoint, or None.
        local_pipeline: Lazily loaded transformers pipeline, or None.
    """

    def __init__(self, model_id, revision, device):
        """Initialize the SummaryGenerator with a model.

        Args:
            model_id (str): Identifier for the model.
            revision (str): Revision of the model.
            device: Device that tokenized inputs are moved to for local models.
        """
        self.model_id = model_id
        self.model = f"huggingface/{model_id}"
        self.api_base = f"https://api-inference.huggingface.co/models/{model_id}"
        self.summaries_df = pd.DataFrame()
        self.revision = revision
        self.device = device
        self.avg_length = None
        self.answer_rate = None
        self.exceptions = None
        self.local_model = None
        self.local_pipeline = None

    def generate_summaries(self, df, save_path=None):
        """Generate summaries for a given DataFrame of source docs.

        If ``save_path`` points to an existing file, summaries are loaded from
        it instead of being generated. Otherwise one summary is generated per
        row and, when ``save_path`` is given, written to that CSV.

        Args:
            df (DataFrame): Source docs; must have 'text' and 'dataset' columns.
            save_path (str, optional): CSV cache path for the summaries.

        Returns:
            summaries_df (DataFrame): Columns 'source', 'summary', 'dataset'.
        """
        exceptions = []
        if (save_path is not None) and os.path.exists(save_path):
            self.summaries_df = pd.read_csv(save_path)

            print(f'Loaded generated summaries from {save_path}')
        else:
            source, summary, dataset = [], [], []
            print(f"Total: {df.shape[0]}")
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                _source = row['text']
                _dataset = row['dataset']

                system_prompt = envs.SYSTEM_PROMPT
                user_prompt = f"{envs.USER_PROMPT}\nPassage:\n{_source}"

                _summary = None
                # Retry only on rate-limit / model-loading / quota errors;
                # any other error records an empty summary for this row.
                while not _summary:
                    try:
                        _summary = self.generate_summary(system_prompt, user_prompt)
                        break
                    except Exception as e:
                        err_msg = str(e)
                        if 'Rate limit reached' in err_msg:
                            wait_time = 300
                            current_time = datetime.now().strftime('%H:%M:%S')
                            print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
                            time.sleep(wait_time)
                        elif 'is currently loading' in err_msg:
                            wait_time = 200
                            print(f"Model is loading, wait for {wait_time}")
                            time.sleep(wait_time)
                        elif '429' in err_msg:  # for gemini models
                            wait_time = 60
                            print(f"Quota has reached, wait for {wait_time}")
                            time.sleep(wait_time)
                        else:
                            print(f"Error at index {index}: {e}")
                            _summary = ""
                            exceptions.append(index)
                            break

                summary.append(_summary)
                source.append(_source)
                dataset.append(_dataset)

                # Sleep to prevent hitting rate limits too frequently
                time.sleep(1)

            self.summaries_df = pd.DataFrame(list(zip(source, summary, dataset)),
                                             columns=["source", "summary", "dataset"])

            if save_path is not None:
                print(f'Save summaries to {save_path}')
                fpath = Path(save_path)
                fpath.parent.mkdir(parents=True, exist_ok=True)
                self.summaries_df.to_csv(fpath)

        self.exceptions = exceptions
        self._compute_avg_length()
        self._compute_answer_rate()

        return self.summaries_df

    def generate_summary(self, system_prompt: str, user_prompt: str):
        """Generate one summary with the configured model.

        The backend is chosen by case-insensitive substring matches on
        ``self.model_id``. The first matching branch wins, in this order:
        Together AI, OpenAI, xAI (Grok), Vertex AI (Gemini), Replicate,
        Anthropic (Claude), Cohere, MistralAI, DeepSeek, then a local Hugging
        Face pipeline or checkpoint. Local models are loaded on the first
        call and kept on ``self`` for later calls.

        Args:
            system_prompt (str): System instruction for the model.
            user_prompt (str): User message containing the passage to summarize.

        Returns:
            str: The generated text (printed as well).
        """
        model_name = self.model_id.lower()

        # Backend routing. Replicate takes precedence over Together AI, and
        # both take precedence over the local transformers pipeline.
        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen2-72b-instruct', 'zero-one-ai', 'llama-3.2-']  # , 'mistralai'
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3']

        using_replicate_api = any(name in model_name for name in replicate_api_models)
        using_together_api = (not using_replicate_api
                              and any(name in model_name for name in together_ai_api_models))
        using_pipeline = (not using_replicate_api and not using_together_api
                          and any(name in model_name for name in pipeline_models))

        # Using Together AI API
        if using_together_api:
            print('using together api')
            client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
            if 'llama-3.2-90b-vision' in model_name or 'llama-3.2-11b-vision' in model_name:
                # Vision models take list-style content.
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": [{"type": "text", "text": user_prompt}]}
                ]
            else:
                messages = [{"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt}]
            response = client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=250,
                temperature=0,
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using OpenAI API
        elif 'openai' in model_name:
            is_gpt = 'gpt' in model_name
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model_id.replace('openai/', ''),
                # Non-GPT (o1-series) models get the system prompt folded into the user turn.
                messages=[{"role": "system", "content": system_prompt},
                          {"role": "user", "content": user_prompt}]
                if is_gpt else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
                temperature=0.0 if is_gpt else 1.0,  # fixed at 1 for o1 models
                # max_completion_tokens=250 if is_gpt else None,  # not compatible with o1 series models
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using Grok API (xai)
        elif 'grok' in model_name:
            XAI_API_KEY = os.getenv("XAI_API_KEY")
            client = OpenAI(
                api_key=XAI_API_KEY,
                base_url="https://api.x.ai/v1",
            )

            completion = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0
            )
            result = completion.choices[0].message.content
            print(result)
            return result

        # Using Vertex AI API for Gemini models
        elif 'gemini' in model_name:
            vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
            model = GenerativeModel(
                model_name.split('google/')[-1],
                system_instruction=[system_prompt]
            )
            generation_config = {
                "temperature": 0,
                "max_output_tokens": 500
            }

            # Disable blocking for all four harm categories.
            safety_settings = [
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
            ]
            response = model.generate_content(
                user_prompt,
                safety_settings=safety_settings,
                generation_config=generation_config
            )
            result = response.text
            print(result)
            return result

        # Using Replicate API
        elif using_replicate_api:
            print("using replicate")
            if 'snowflake' in model_name:
                # "{prompt}" in the second part is a Replicate template
                # placeholder, not an f-string field.
                replicate_input = {
                    "prompt": user_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                    "stop_sequences": "<|im_end|>",
                    "prompt_template": f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + "<|im_start|>user\n{prompt}<|im_end|>\n\n<|im_start|>assistant\n",
                }
            else:
                replicate_input = {
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250
                }
            response = replicate.run(
                self.model_id,
                input=replicate_input
            )
            # Output may arrive as a list of text chunks.
            if isinstance(response, list):
                response = ''.join(response)
            print(response)
            return response

        # Using Anthropic API for Claude models
        elif 'claude' in model_name:
            print('using Anthropic API')
            client = anthropic.Anthropic()
            message = client.messages.create(
                model=self.model_id.split('/')[-1],
                max_tokens=1024,
                temperature=0,
                system=system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": user_prompt
                    }
                ]
            )
            result = message.content[0].text
            print(result)
            return result

        # Using Cohere API
        elif 'command-r' in model_name or 'aya-expanse' in model_name:
            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
            response = co.chat(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0,
            )
            result = response.message.content[0].text
            print(result)
            return result

        # Using MistralAI API
        elif 'mistral-large' in model_name:
            api_key = os.environ["MISTRAL_API_KEY"]
            client = Mistral(api_key=api_key)

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            # No streaming
            chat_response = client.chat.complete(
                model=self.model_id,
                messages=messages,
            )
            result = chat_response.choices[0].message.content
            print(result)
            return result

        # Using Deepseek API
        elif 'deepseek' in model_name:
            client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
            response = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=250,
                temperature=0,
                stream=False
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using HF pipeline or local checkpoints: load once, on the first call.
        elif self.local_model is None and self.local_pipeline is None:
            if using_pipeline:
                self.local_pipeline = pipeline(
                    "text-generation",
                    model=self.model_id,
                    tokenizer=AutoTokenizer.from_pretrained(self.model_id),
                    torch_dtype=torch.bfloat16 if 'llama-3.2' in model_name or 'llama-3.3' in model_name else "auto",
                    device_map="auto",
                    trust_remote_code=True
                )
            else:
                if 'ragamuffin' in model_name:
                    self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
                else:
                    # OpenELM ships without a tokenizer; it uses the Llama-2 one.
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        "meta-llama/Llama-2-7b-hf" if 'openelm' in model_name else self.model_id,
                        trust_remote_code=True)
                print("Tokenizer loaded")
                if 'jamba' in model_name:
                    self.local_model = AutoModelForCausalLM.from_pretrained(
                        self.model_id,
                        torch_dtype=torch.bfloat16,
                        attn_implementation="flash_attention_2",
                        device_map="auto",
                        use_mamba_kernels=False)

                elif 'qwen2-vl' in model_name:
                    self.local_model = Qwen2VLForConditionalGeneration.from_pretrained(
                        self.model_id, torch_dtype="auto", device_map="auto"
                    )
                    self.processor = AutoProcessor.from_pretrained(self.model_id)

                # elif 'ragamuffin' in model_name:
                #     print('Using ragamuffin')
                #     self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
                #                                                             torch_dtype=torch.bfloat16,  # forcing bfloat16 for now
                #                                                             attn_implementation="flash_attention_2")
                elif 'olmo' in model_name:
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id)  # torch_dtype="auto"

                elif 'qwq-' in model_name:
                    self.local_model = AutoModelForCausalLM.from_pretrained(
                        self.model_id,
                        torch_dtype="auto",
                        device_map="auto"
                    )

                else:
                    self.local_model = AutoModelForCausalLM.from_pretrained(
                        self.model_id, trust_remote_code=True, device_map="auto")  # torch_dtype="auto"

                print("Local model loaded")

        # Using local model/pipeline
        if self.local_pipeline is not None:
            print('Using Transformers pipeline')
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=256,
                # return_full_text=False,
                do_sample=False
            )
            # The pipeline returns the whole chat; the last turn is the reply.
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
            return result

        elif self.local_model is not None:  # cannot call API. using local model / pipeline
            print('Using local model')
            # Set appropriate prompt based on model document
            if 'gemma' in model_name or 'mistral-7b' in model_name:
                messages = [
                    # gemma-1.1, mistral-7b does not accept system role
                    {"role": "user", "content": system_prompt + '\n' + user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

            elif 'phi-2' in model_name:
                prompt = system_prompt + '\n' + user_prompt

            elif 'intel' in model_name:
                prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"

            elif 'qwen2-vl' in model_name:
                messages = [
                    {
                        "role": "system",
                        "content": [
                            {"type": "text", "text": system_prompt}
                        ]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                        ],
                    }
                ]
                # Build the text prompt with the Qwen2-VL processor's chat
                # template; otherwise `prompt` is unbound when tokenizing below.
                prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            else:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

            # Tokenize inputs
            if 'olmo' in model_name:
                # OLMo is loaded without device_map, so inputs stay where the tokenizer puts them.
                input_ids = self.tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)  # .to(self.device)
            elif 'qwq' in model_name:
                input_ids = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            else:
                input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate outputs
            if 'granite' in model_name:
                self.local_model.eval()
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
            elif 'olmo' in model_name:
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01)  # top_k=50, top_p=0.95)
            elif 'qwq' in model_name:
                outputs = self.local_model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
            else:
                with torch.no_grad():
                    outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)  # , pad_token_id=self.tokenizer.eos_token_id

            # Drop the echoed prompt tokens for models whose output includes them.
            if 'glm' in model_name or 'ragamuffin' in model_name or 'granite' in model_name:
                outputs = outputs[:, input_ids['input_ids'].shape[1]:]
            elif 'qwen2-vl' in model_name or 'qwen2.5' in model_name or 'qwq-' in model_name:
                outputs = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                ]

            # Decode outputs
            if 'qwen2-vl' in model_name:
                result = self.processor.batch_decode(
                    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )[0]
            elif 'olmo' in model_name or 'qwq' in model_name:
                result = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            else:
                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip any remaining prompt text from the decoded output.
            if 'gemma-2' in model_name:
                result = result.split(user_prompt + '\nmodel')[-1].strip()

            elif 'intel' in model_name:
                result = result.split("### Assistant:\n")[-1]

            elif 'jamba' in model_name:
                result = result.split(messages[-1]['content'])[1].strip()
            elif 'qwen2-vl' in model_name or 'qwen2.5' in model_name:
                pass  # prompt tokens were already removed above
            elif 'olmo' in model_name:
                result = result.split("<|assistant|>\n")[-1]
            else:
                result = result.replace(prompt.strip(), '')

            print(result)
            return result

    def _compute_avg_length(self):
        """Compute the average length of valid summaries using SpaCy.

        Length is the number of alphabetic tokens (``token.is_alpha``); it is
        stored in ``self.avg_length`` (0 when there are no valid summaries).
        """
        total_word_count = 0
        total_count = 0

        for summary in self.summaries_df['summary']:
            if util.is_summary_valid(summary):
                doc = nlp(summary)
                words = [token.text for token in doc if token.is_alpha]
                total_word_count += len(words)
                total_count += 1

        self.avg_length = 0 if total_count == 0 else total_word_count / total_count

    def _compute_answer_rate(self):
        """Compute the fraction of valid summaries into ``self.answer_rate``.

        The rate is 0 when ``summaries_df`` is empty.
        """
        valid_count = sum(1 for summary in self.summaries_df['summary']
                          if util.is_summary_valid(summary))

        total_count = len(self.summaries_df)

        self.answer_rate = 0 if total_count == 0 else valid_count / total_count


class EvaluationModel:
    """Evaluate generated summaries with an HHEM 2.1 checkpoint.

    Attributes:
        model (AutoModelForTokenClassification): The HHEM evaluation model.
        device: Device the model and its inputs are placed on.
        scores (list): Evaluation scores from the last evaluate_hallucination call.
        factual_consistency_rate (float): Percentage of scores at/above threshold.
        hallucination_rate (float): 100 - factual_consistency_rate.
    """

    def __init__(self, model_path, device):
        """Load the HHEM model and move it to ``device``.

        Args:
            model_path (str): Path or Hugging Face ID of the HHEM checkpoint;
                loaded with the 'google/flan-t5-large' config.
            device: Target device for the model and its inputs.
        """
        config = AutoConfig.from_pretrained('google/flan-t5-large')
        self.model = AutoModelForTokenClassification.from_pretrained(model_path, config=config)
        self.device = device
        self.model.to(self.device)
        self.scores = []
        self.factual_consistency_rate = None
        self.hallucination_rate = None
        self._tokenizer = None  # loaded lazily by predict()

    def predict(self, text_pairs):
        """Score (premise, hypothesis) pairs with HHEM 2.1.

        All HHEM 2.1 settings, e.g., the prompt template and the 't5-base'
        tokenizer, are hardcoded in this function.

        Args:
            text_pairs: list of tuples, each tuple contains two strings
                (premise, hypothesis).

        Returns:
            list[float]: One score per pair, the softmax probability of label
            index 1 on the first token, rounded to 5 decimals.
        """
        # HHEM 2.1 template; the leading "<pad>" token is part of it.
        prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
        # Load the tokenizer once; evaluate_hallucination calls predict per summary.
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained('t5-base')
            self._tokenizer = tokenizer
        inputs = tokenizer(
            [prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs],
            return_tensors='pt', padding='longest').to(self.device)
        self.model.eval()
        with torch.no_grad():
            output = self.model(**inputs)
        logits = output.logits
        logits = logits[:, 0, :]  # get the logits on the first token
        logits = torch.softmax(logits, dim=-1)
        scores = [round(x, 5) for x in logits[:, 1].tolist()]  # list of float
        return scores

    def evaluate_hallucination(self, summaries_df):
        """Score every valid summary against its source document.

        Invalid summaries (per ``util.is_summary_valid``) are skipped. Scores
        below 0.5 are printed together with their source and summary.

        Args:
            summaries_df (DataFrame): DataFrame containing source docs and summaries.

        Returns:
            tuple: ``(hem_scores, eval_results)``, where ``hem_scores`` is the
            list of scores and ``eval_results`` is a dict with keys 'source',
            'summary' and 'HEM scores'. Also sets ``self.scores``.

        Raises:
            Exception: Any error from normalization or prediction is logged
                and re-raised.
        """
        hem_scores = []
        sources = []
        summaries = []
        source_summary_pairs = util.create_pairs(summaries_df)

        for doc, summary in source_summary_pairs:
            if util.is_summary_valid(summary):
                try:
                    summary = util.normalize_summary(summary)
                    score = self.predict([(doc, summary)])[0]
                    hem_scores.append(score)
                    sources.append(doc)
                    summaries.append(summary)
                    if score < 0.5:
                        print(score)
                        print(doc)
                        print('-' * 20)
                        print(summary)
                        print('=' * 50)
                except Exception as e:
                    logging.error(f"Error while running HEM: {e}")
                    raise

        self.scores = hem_scores
        eval_results = {'source': sources, 'summary': summaries, 'HEM scores': hem_scores}
        return hem_scores, eval_results

    def compute_factual_consistency_rate(self, threshold=0.5):
        """Compute the factual consistency rate from ``self.scores``.

        This relies on ``self.scores`` being populated, typically via
        ``evaluate_hallucination``.

        Args:
            threshold (float): Minimum score for a summary to count as
                factually consistent. Defaults to 0.5.

        Returns:
            float: Factual consistency rate as a percentage. Also updates
            ``factual_consistency_rate`` and ``hallucination_rate``.

        Raises:
            ValueError: If scores have not been calculated prior to calling this method.
        """
        if not self.scores:
            error_msg = "Scores not calculated. Call evaluate_hallucination() first."
            logging.error(error_msg)
            raise ValueError(error_msg)

        # self.scores is non-empty here, so num_total > 0.
        num_above_threshold = sum(score >= threshold for score in self.scores)
        num_total = len(self.scores)

        self.factual_consistency_rate = (num_above_threshold / num_total) * 100
        self.hallucination_rate = 100 - self.factual_consistency_rate

        return self.factual_consistency_rate