import argparse
import os
import sys
import time
import uvicorn
import requests
import asyncio
import logging

from pathlib import Path
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from typing import Union, List, Dict, Any
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock


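# OpenAI-compatible response schemas returned by the embeddings endpoint.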
class EmbeddingResponseItem(BaseModel):
    object: str = "embedding"
    index: int
    embedding: List[List[float]]


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingResponseItem]
    model: str
    usage: Dict[str, Any]


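# FastAPI app exposing OpenAI-style chat-completion and embedding routes
# backed by the HuggingFace Inference API.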
class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

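        # Allow cross-origin requests from any origin.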
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

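    # List the available chat and embedding models in an OpenAI-style payload.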
    def get_available_models(self):
        current_time = int(time.time())
        self.available_models = {
            "object": "list",
            "data": [
                {
                    "id": "mixtral-8x7b",
                    "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co./mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "mistral-7b",
                    "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co./mistralai/Mistral-7B-Instruct-v0.2",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "nous-mixtral-8x7b",
                    "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co./NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "NousResearch",
                },
                {
                    "id": "gemma-7b",
                    "description": "[google/gemma-7b-it]: https://huggingface.co./google/gemma-7b-it",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "Google",
                },
                {
                    "id": "codellama-7b",
                    "description": "[codellama/CodeLlama-7b-hf]: https://huggingface.co./codellama/CodeLlama-7b-hf",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "codellama",
                },
                {
                    "id": "bert-base-uncased",
                    "description": "[google-bert/bert-base-uncased]: https://huggingface.co./google-bert/bert-base-uncased",
                    "object": "embedding",
                    "created": current_time,
                    "owned_by": "google",
                },
            ],
        }
        return self.available_models

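    # Resolve the HF API token from the Bearer credentials, falling back to the
    # HF_TOKEN environment variable. Defined without `self` so it can be passed
    # directly to Depends() as a FastAPI dependency.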
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        api_key = None
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if api_key:
            if api_key.startswith("hf_"):
                return api_key
            else:
                logger.warning("Invalid HF Token!")
        else:
            logger.warning("No HF Token provided!")
        return None

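    # Request body for the embeddings endpoint.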
    class QueryRequest(BaseModel):
        input: str
        model: str = Field(default="bert-base-uncased")
        encoding_format: str

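    # Request body for the chat completions endpoint.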
    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) Top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )

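    # Proxy a chat completion request to the selected model via MessageStreamer.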
    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)

        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
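        # Stream tokens back as server-sent events, or collect the whole
        # completion into a single dict response.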
        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response

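    # Call the HuggingFace feature-extraction pipeline and wrap the result in an
    # OpenAI-style embeddings response.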
    async def embedding(
        self, request: QueryRequest, api_key: str = Depends(extract_api_key)
    ):
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"
        headers = {"Authorization": f"Bearer {api_key}"}

        # requests is synchronous and cannot be awaited directly;
        # run the blocking call in a worker thread instead.
        response = await asyncio.to_thread(
            requests.post, api_url, headers=headers, json={"inputs": request.input}
        )
        result = response.json()

        if "error" in result:
            logger.error(f"Error from Hugging Face API: {result['error']}")
            error_detail = result.get("error", "No detailed error message provided.")
            raise HTTPException(
                status_code=503,
                detail=f"The model is currently loading, please re-run the query. Detail: {error_detail}",
            )

        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
            flattened_embeddings = [item for sublist in result for item in sublist]
            data = [
                {"object": "embedding", "index": i, "embedding": embedding}
                for i, embedding in enumerate(flattened_embeddings)
            ]
            return EmbeddingResponse(
                object="list",
                data=data,
                model=request.model,
                usage={
                    "prompt_tokens": len(request.input),
                    "total_tokens": len(request.input),
                },
            )
        else:
            logger.error(f"Unexpected response format: {result}")
            raise HTTPException(status_code=500, detail="Unexpected response format.")

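    # Register every route under several prefixes; only the "/api/v1" variants
    # are included in the OpenAPI schema.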
    def setup_routes(self):
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            include_in_schema = prefix == "/api/v1"

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            self.app.post(
                prefix + "/embeddings",
                summary="Generate embeddings for the given texts",
                include_in_schema=include_in_schema,
                response_model=EmbeddingResponse,
            )(self.embedding)


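# Command-line options: server host, port, and dev mode (uvicorn auto-reload).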
class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )

        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


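# Module-level ASGI app so uvicorn can import it (e.g. as "__main__:app").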
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)