import copy
import time
from typing import Any, Dict, List

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Device that tokenized inputs are moved to; must match ``device_map={"": 0}``
# used when loading the model below.
_DEVICE = "cuda:0"


class EndpointHandler:
    """Custom inference handler answering a question about a given context.

    Loads a 4-bit (NF4) quantized base causal LM plus a PEFT adapter on top of
    it, and exposes ``__call__`` for request handling.
    """

    def __init__(self, path="luxmorocco/qiyas-falcon-7b"):
        """Load the quantized base model, its tokenizer and the PEFT adapter.

        Args:
            path: Hub id or local directory of the PEFT adapter. Its adapter
                config (``base_model_name_or_path``) names the base model.
        """
        config = PeftConfig.from_pretrained(path)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # 4-bit loading is requested only via ``quantization_config``; passing a
        # separate ``load_in_4bit`` kwarg alongside it is deprecated and is
        # rejected with a ValueError by some transformers releases.
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            device_map={"": 0},
            trust_remote_code=True,
            quantization_config=bnb_config,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Reuse EOS as the padding token (the tokenizer call below asks for padding).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = PeftModel.from_pretrained(base_model, path)

        # Decoding settings are request-invariant: build them once on a private
        # copy rather than mutating the model's own generation_config per call.
        generation_config = copy.deepcopy(self.model.generation_config)
        # NOTE(review): transformers applies top_p/temperature only when
        # do_sample=True; unless the base model's generation config enables
        # sampling, decoding is greedy and these two values are ignored.
        # Confirm which behaviour is intended.
        generation_config.top_p = 0.7
        generation_config.temperature = 0.7
        generation_config.max_new_tokens = 256
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        self.generation_config = generation_config

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Answer the request's question based on its context.

        Args:
            data: A dict like ``{"context": "...", "question": "..."}``,
                optionally wrapped as ``{"inputs": {...}}``. The payload is
                not modified.

        Returns:
            A dict ``{"answer": str, "time": str}``: ``answer`` is the generated
            continuation (cut at the first ``>>END<<`` marker, if any, and
            stripped); ``time`` is the duration of the ``generate`` call
            formatted like ``"1.23 s"``.

        Raises:
            ValueError: if the payload is not a dict or lacks ``"context"`` or
                ``"question"``.
        """
        context, question = self._parse_inputs(data)
        prompt = self._build_prompt(context, question)

        batch = self.tokenizer(
            prompt, padding=True, truncation=True, return_tensors="pt"
        ).to(_DEVICE)

        start = time.perf_counter()
        with torch.autocast(device_type="cuda"):
            output_tokens = self.model.generate(
                input_ids=batch.input_ids,
                # pad == eos, so generate() cannot infer the mask on its own.
                attention_mask=batch.attention_mask,
                generation_config=self.generation_config,
            )
        elapsed = time.perf_counter() - start

        # generate() returns prompt + continuation. Decoding only the new tokens
        # means marker-like text inside the context cannot break extraction,
        # and the EOS token is not echoed into the answer.
        prompt_length = batch.input_ids.shape[1]
        generated_text = self.tokenizer.decode(
            output_tokens[0, prompt_length:], skip_special_tokens=True
        )
        return {
            "answer": self._extract_answer(generated_text),
            "time": f"{elapsed:.2f} s",
        }

    @staticmethod
    def _parse_inputs(data: Any):
        """Return ``(context, question)`` from the request without mutating it."""
        payload = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(payload, dict):
            raise ValueError(
                'Expected a JSON object like {"context": "...", "question": "..."}.'
            )
        missing = [key for key in ("context", "question") if key not in payload]
        if missing:
            raise ValueError(f"Missing required field(s): {', '.join(missing)}")
        return payload["context"], payload["question"]

    @staticmethod
    def _build_prompt(context: Any, question: Any) -> str:
        """Fill the question-answering prompt template; it ends with the answer marker."""
        return (
            "Answer the question based on the context below. If the question "
            "cannot be answered using the information provided answer with "
            "'No answer'. Stop response if end.\n"
            ">>TITLE<<: Flawless answer.\n"
            f">>CONTEXT<<: {context}\n"
            f">>QUESTION<<: {question}\n"
            ">>ANSWER<<:"
        )

    @staticmethod
    def _extract_answer(generated_text: str) -> str:
        """Return the generated text up to the first ``>>END<<`` marker, stripped."""
        return generated_text.split(">>END<<", 1)[0].strip()