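"""FastAPI service that loads a GGUF model with llama-cpp-python and serves it
through a single POST /generate endpoint, fanning the prompt out to every
configured model in a thread pool and merging the de-duplicated replies."""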
from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import traceback
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# Shared state: loaded models plus token-name placeholders (the token map is
# currently unused by the endpoints below).
global_data = {
    'models': {},
    'tokens': {
        'eos': 'eos_token',
        'pad': 'pad_token',
        'padding': 'padding_token',
        'unk': 'unk_token',
        'bos': 'bos_token',
        'sep': 'sep_token',
        'cls': 'cls_token',
        'mask': 'mask_token'
    }
}
model_configs = [
    {"repo_id": "Hjgugugjhuhjggg/mergekit-ties-tzamfyy-Q2_K-GGUF", "filename": "mergekit-ties-tzamfyy-q2_k.gguf", "name": "my_model"}
]
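# Each entry needs a repo_id, the .gguf filename inside that repo, and a local
# name; additional models appended here are served alongside the first one.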
models = {}
def load_model(model_config):
    model_name = model_config['name']
    if model_name not in models:
        try:
            # Llama.from_pretrained downloads through huggingface_hub, which
            # reads its auth token from the HF_TOKEN environment variable; an
            # unrecognized `use_auth_token` kwarg would be silently ignored, so
            # export the token instead (only needed for private/gated repos).
            if HUGGINGFACE_TOKEN:
                os.environ.setdefault("HF_TOKEN", HUGGINGFACE_TOKEN)
            model = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'])
            models[model_name] = model
            global_data['models'] = models
            return model
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            traceback.print_exc()
            models[model_name] = None
            return None
    return models[model_name]
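# Load every configured model eagerly at import time; failed loads are recorded
# as None so the endpoint below can skip them instead of crashing.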
for config in model_configs:
    load_model(config)
class ChatRequest(BaseModel):
    message: str
    max_tokens_per_part: int = 256
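# Example request body for /generate, matching the ChatRequest schema above:
#   {"message": "Hello there", "max_tokens_per_part": 128}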
def normalize_input(input_text):
    return input_text.strip()
def remove_duplicates(text):
    lines = text.split('\n')
    unique_lines = []
    seen_lines = set()
    for line in lines:
        line = line.strip()
        if line and line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)
    return '\n'.join(unique_lines)
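# e.g. remove_duplicates("a\nb\n\na") == "a\nb" (blank and repeated lines dropped)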
def generate_model_response(model, inputs, max_tokens_per_part):
    try:
        if model is None:
            return []
        responses = []
        # Stop at a blank line so each call produces one self-contained part.
        response = model(inputs, max_tokens=max_tokens_per_part, stop=["\n\n"])
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            return ["Error: Invalid model response format"]
        text = response['choices'][0]['text']
        if text:
            responses.append(remove_duplicates(text))
        return responses
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        return [f"Error: {e}"]
app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/generate")
async def generate(request: ChatRequest):
    inputs = normalize_input(request.message)
    loaded_models = {name: model for name, model in models.items() if model is not None}
    if not loaded_models:
        raise HTTPException(status_code=503, detail="No models are loaded")
    with ThreadPoolExecutor() as executor:
        # Map each future back to its model name; zipping models.keys() against
        # as_completed() would pair names with whichever future finished first,
        # mislabeling the responses.
        futures = {
            executor.submit(generate_model_response, model, inputs, request.max_tokens_per_part): model_name
            for model_name, model in loaded_models.items()
        }
        responses = [{'model': futures[future], 'response': future.result()} for future in as_completed(futures)]
    unique_responses = {}
    for response_set in responses:
        model_name = response_set['model']
        if model_name not in unique_responses:
            unique_responses[model_name] = []
        unique_responses[model_name].extend(response_set['response'])
    formatted_response = ""
    for model, response_parts in unique_responses.items():
        formatted_response += f"**{model}:**\n"
        for i, part in enumerate(response_parts):
            formatted_response += f"Part {i+1}:\n{part}\n\n"
    return {"response": formatted_response}
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
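# Quick check once the server is running (default port 7860; prompt text is
# illustrative):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Say hello.", "max_tokens_per_part": 64}'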