from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import requests
import traceback

load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
global_data = {
    'models': {},
    'tokens': {
        'eos': 'eos_token',
        'pad': 'pad_token',
        'padding': 'padding_token',
        'unk': 'unk_token',
        'bos': 'bos_token',
        'sep': 'sep_token',
        'cls': 'cls_token',
        'mask': 'mask_token'
    }
}
model_configs = [
    {"repo_id": "Hjgugugjhuhjggg/mergekit-ties-tzamfyy-Q2_K-GGUF", "filename": "mergekit-ties-tzamfyy-q2_k.gguf", "name": "my_model"}
]
models = {}
def load_model(model_config):
    # Download and cache each GGUF model once; store None on failure so the
    # server can still start without that model.
    model_name = model_config['name']
    if model_name not in models:
        try:
            model = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
            models[model_name] = model
            global_data['models'] = models
            return model
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            traceback.print_exc()
            models[model_name] = None
            return None
    return models[model_name]

for config in model_configs:
    load_model(config)
class ChatRequest(BaseModel):
    message: str
    max_tokens_per_part: int = 256

def normalize_input(input_text):
    return input_text.strip()
def remove_duplicates(text):
    lines = text.split('\n')
    unique_lines = []
    seen_lines = set()
    for line in lines:
        line = line.strip()
        if line and line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)
    return '\n'.join(unique_lines)
def generate_model_response(model, inputs, max_tokens_per_part):
    # Run one completion against a single model and return the cleaned text as a
    # list of response parts (empty if the model failed to load).
    try:
        if model is None:
            return []
        responses = []
        response = model(inputs, max_tokens=max_tokens_per_part, stop=["\n\n"])
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            return ["Error: Invalid model response format"]
        text = response['choices'][0]['text']
        if text:
            responses.append(remove_duplicates(text))
        return responses
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        return [f"Error: {e}"]
app = FastAPI()

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Expose the chat handler; the "/generate" route path is an assumption.
@app.post("/generate")
async def generate(request: ChatRequest):
    inputs = normalize_input(request.message)
    with ThreadPoolExecutor() as executor:
        # Query every loaded model in parallel, keeping each future paired with
        # its model name so results cannot be misattributed as they complete.
        future_to_name = {
            executor.submit(generate_model_response, model, inputs, request.max_tokens_per_part): name
            for name, model in models.items()
        }
        responses = [
            {'model': future_to_name[future], 'response': future.result()}
            for future in as_completed(future_to_name)
        ]
    unique_responses = {}
    for response_set in responses:
        model_name = response_set['model']
        if model_name not in unique_responses:
            unique_responses[model_name] = []
        unique_responses[model_name].extend(response_set['response'])
    formatted_response = ""
    for model, response_parts in unique_responses.items():
        formatted_response += f"**{model}:**\n"
        for i, part in enumerate(response_parts):
            formatted_response += f"Part {i+1}:\n{part}\n\n"
    return {"response": formatted_response}
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
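# Example client call: a minimal sketch assuming the server is running locally on
# port 7860 and the handler above is exposed at POST /generate.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"message": "Hello!", "max_tokens_per_part": 128},
#   )
#   resp.raise_for_status()
#   print(resp.json()["response"])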