"""OpenAI-compatible API endpoints.""" | |
from typing import Dict, List, Optional, Union | |
from pydantic import BaseModel, Field | |
from fastapi import APIRouter, HTTPException, Depends | |
import time | |
import json | |
import asyncio | |
from datetime import datetime | |


class ChatMessage(BaseModel):
    """OpenAI-compatible chat message."""

    role: str = Field(..., description="The role of the message author (system/user/assistant)")
    content: str = Field(..., description="The content of the message")
    name: Optional[str] = Field(None, description="The name of the author")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request."""

    model: str = Field(..., description="Model to use")
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
    n: Optional[int] = Field(1, description="Number of completions")
    stream: Optional[bool] = Field(False, description="Whether to stream responses")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
    user: Optional[str] = Field(None, description="User identifier")


class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response."""

    id: str = Field(..., description="Unique identifier for the completion")
    object: str = Field("chat.completion", description="Object type")
    created: int = Field(..., description="Unix timestamp of creation")
    model: str = Field(..., description="Model used")
    choices: List[Dict] = Field(..., description="Completion choices")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")


class OpenAICompatibleAPI:
    """OpenAI-compatible API implementation."""

    def __init__(self, reasoning_engine):
        self.reasoning_engine = reasoning_engine
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        """Set up API routes on the router."""

        # OpenAI-compatible chat completions endpoint
        @self.router.post("/v1/chat/completions", response_model=ChatCompletionResponse)
        async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
            try:
                # Convert chat history to context
                context = self._prepare_context(request.messages)

                # Get the last user message
                user_message = next(
                    (msg.content for msg in reversed(request.messages)
                     if msg.role == "user"),
                    None
                )
                if not user_message:
                    raise HTTPException(status_code=400, detail="No user message found")

                # Process with reasoning engine
                result = await self.reasoning_engine.reason(
                    query=user_message,
                    context={
                        "chat_history": context,
                        "temperature": request.temperature,
                        "top_p": request.top_p,
                        "max_tokens": request.max_tokens,
                        "stream": request.stream
                    }
                )

                # Format response
                response = {
                    "id": f"chatcmpl-{int(time.time() * 1000)}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": result.answer
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": self._estimate_tokens(user_message),
                        "completion_tokens": self._estimate_tokens(result.answer),
                        "total_tokens": self._estimate_tokens(user_message)
                                        + self._estimate_tokens(result.answer)
                    }
                }
                return ChatCompletionResponse(**response)
            except HTTPException:
                # Re-raise deliberate HTTP errors (e.g. the 400 above) instead of
                # collapsing them into a 500
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

        # OpenAI-compatible model listing endpoint
        @self.router.get("/v1/models")
        async def list_models():
            """List available models."""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "venture-gpt-1",
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "venture-ai",
                        "permission": [],
                        "root": "venture-gpt-1",
                        "parent": None
                    }
                ]
            }

    def _prepare_context(self, messages: List[ChatMessage]) -> List[Dict]:
        """Convert messages to context format."""
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "name": msg.name,
                "timestamp": datetime.now().isoformat()
            }
            for msg in messages
        ]

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count for a text."""
        # Simple estimation: ~4 characters per token
        return len(text) // 4
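

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of mounting the router on a FastAPI application.
# `DummyReasoningEngine`, `_DummyResult`, and the host/port are assumptions made
# only so the sketch runs end to end; the surrounding project presumably supplies
# its own reasoning engine and server entry point. Requires `uvicorn`.
if __name__ == "__main__":
    from dataclasses import dataclass

    import uvicorn
    from fastapi import FastAPI

    @dataclass
    class _DummyResult:
        """Hypothetical result object exposing the `answer` attribute used above."""
        answer: str

    class DummyReasoningEngine:
        """Hypothetical engine that simply echoes the query back."""

        async def reason(self, query: str, context: Dict) -> _DummyResult:
            return _DummyResult(answer=f"Echo: {query}")

    app = FastAPI()
    app.include_router(OpenAICompatibleAPI(DummyReasoningEngine()).router)

    # Example request once the server is running:
    #   curl http://localhost:8000/v1/chat/completions \
    #        -H "Content-Type: application/json" \
    #        -d '{"model": "venture-gpt-1", "messages": [{"role": "user", "content": "Hi"}]}'
    uvicorn.run(app, host="0.0.0.0", port=8000)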