File size: 5,822 Bytes
1d75522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""OpenAI-compatible API endpoints."""

import asyncio
import json
import time
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Union

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field

class ChatMessage(BaseModel):
    """OpenAI-compatible chat message.

    Mirrors one entry of the ``messages`` array in the OpenAI
    chat-completions API: an author role, the text content, and an
    optional author name. Field descriptions below are surfaced in the
    generated OpenAPI schema, so they are part of the public contract.
    """
    role: str = Field(..., description="The role of the message author (system/user/assistant)")
    content: str = Field(..., description="The content of the message")
    name: Optional[str] = Field(None, description="The name of the author")

class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request body.

    Accepts the standard OpenAI request fields so existing client
    libraries can be pointed at this server unchanged.

    NOTE(review): the handler in this file forwards only ``temperature``,
    ``top_p``, ``max_tokens`` and ``stream`` to the reasoning engine;
    ``n``, ``stop`` and the penalty fields are accepted for wire
    compatibility but do not visibly affect generation here — confirm
    against the engine before documenting them as honored.
    """
    model: str = Field(..., description="Model to use")
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
    n: Optional[int] = Field(1, description="Number of completions")
    stream: Optional[bool] = Field(False, description="Whether to stream responses")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
    user: Optional[str] = Field(None, description="User identifier")

class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response body.

    Matches the shape of an OpenAI ``chat.completion`` object so client
    SDKs can parse it directly. ``choices`` and ``usage`` are kept as
    loosely-typed dicts rather than nested models; their layout is fixed
    by the handler that builds them in this file.
    """
    id: str = Field(..., description="Unique identifier for the completion")
    object: str = Field("chat.completion", description="Object type")
    created: int = Field(..., description="Unix timestamp of creation")
    model: str = Field(..., description="Model used")
    choices: List[Dict] = Field(..., description="Completion choices")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")

class OpenAICompatibleAPI:
    """OpenAI-compatible API implementation.

    Exposes a reasoning engine behind the standard OpenAI REST surface
    (``POST /v1/chat/completions`` and ``GET /v1/models``) so that
    off-the-shelf OpenAI client libraries can talk to this server.
    """

    def __init__(self, reasoning_engine):
        """
        Args:
            reasoning_engine: Engine exposing an async ``reason(query, context)``
                method whose result carries the reply in an ``answer`` attribute.
        """
        self.reasoning_engine = reasoning_engine
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        """Register the chat-completion and model-listing routes on ``self.router``."""

        @self.router.post("/v1/chat/completions")
        async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
            """Run the reasoning engine on the latest user message and wrap the reply."""
            try:
                # Convert the full chat history to the engine's context format.
                context = self._prepare_context(request.messages)

                # The query is the most recent user-authored message.
                user_message = next(
                    (msg.content for msg in reversed(request.messages)
                     if msg.role == "user"),
                    None
                )

                if not user_message:
                    raise HTTPException(status_code=400, detail="No user message found")

                # NOTE(review): request.stream is forwarded to the engine, but
                # this endpoint always returns a single JSON body — SSE
                # streaming is not implemented here.
                result = await self.reasoning_engine.reason(
                    query=user_message,
                    context={
                        "chat_history": context,
                        "temperature": request.temperature,
                        "top_p": request.top_p,
                        "max_tokens": request.max_tokens,
                        "stream": request.stream
                    }
                )

                # Count every message in the prompt. (Previously only the last
                # user message was counted, under-reporting prompt_tokens, and
                # each estimate was computed twice.)
                prompt_tokens = sum(
                    self._estimate_tokens(msg.content) for msg in request.messages
                )
                completion_tokens = self._estimate_tokens(result.answer)

                response = {
                    # uuid4 avoids ID collisions when two requests land within
                    # the same millisecond (the old time-based ID was not unique).
                    "id": f"chatcmpl-{uuid.uuid4().hex}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": result.answer
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens
                    }
                }

                return ChatCompletionResponse(**response)

            except HTTPException:
                # Bug fix: the broad handler below used to catch the deliberate
                # 400 raised above (HTTPException subclasses Exception) and
                # re-wrap it as a 500. Re-raise HTTP errors unchanged.
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

        @self.router.get("/v1/models")
        async def list_models():
            """List available models (OpenAI ``GET /v1/models`` response shape)."""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "venture-gpt-1",
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "venture-ai",
                        "permission": [],
                        "root": "venture-gpt-1",
                        "parent": None
                    }
                ]
            }

    def _prepare_context(self, messages: List[ChatMessage]) -> List[Dict]:
        """Convert pydantic messages to plain dicts for the reasoning engine.

        Each entry keeps role/content/name and is stamped with the current
        local time at conversion (naive datetime — the engine receives no
        timezone information).
        """
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "name": msg.name,
                "timestamp": datetime.now().isoformat()
            }
            for msg in messages
        ]

    def _estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of *text* (~4 characters per token)."""
        return len(text) // 4