Spaces:
Runtime error
Runtime error
"""Groq API integration with streaming and optimizations.""" | |
import os | |
import logging | |
import asyncio | |
from typing import Dict, Any, Optional, List, AsyncGenerator, Union | |
import groq | |
from datetime import datetime | |
import json | |
from dataclasses import dataclass | |
from concurrent.futures import ThreadPoolExecutor | |
from .base import ReasoningStrategy, StrategyResult | |
logger = logging.getLogger(__name__) | |
class GroqConfig: | |
"""Configuration for Groq models.""" | |
model_name: str | |
max_tokens: int | |
temperature: float | |
top_p: float | |
top_k: Optional[int] = None | |
presence_penalty: float = 0.0 | |
frequency_penalty: float = 0.0 | |
stop_sequences: Optional[List[str]] = None | |
chunk_size: int = 1024 | |
retry_attempts: int = 3 | |
retry_delay: float = 1.0 | |
class GroqStrategy(ReasoningStrategy): | |
"""Enhanced reasoning strategy using Groq's API with streaming and optimizations.""" | |
def __init__(self, api_key: Optional[str] = None): | |
"""Initialize Groq strategy.""" | |
super().__init__() | |
self.api_key = api_key or os.getenv("GROQ_API_KEY") | |
if not self.api_key: | |
raise ValueError("GROQ_API_KEY must be set") | |
# Initialize Groq client with optimized settings | |
self.client = groq.Groq( | |
api_key=self.api_key, | |
timeout=30, | |
max_retries=3 | |
) | |
# Optimized model configurations | |
self.model_configs = { | |
"mixtral": GroqConfig( | |
model_name="mixtral-8x7b-32768", | |
max_tokens=32768, | |
temperature=0.7, | |
top_p=0.9, | |
top_k=40, | |
presence_penalty=0.1, | |
frequency_penalty=0.1, | |
chunk_size=4096 | |
), | |
"llama": GroqConfig( | |
model_name="llama2-70b-4096", | |
max_tokens=4096, | |
temperature=0.8, | |
top_p=0.9, | |
top_k=50, | |
presence_penalty=0.2, | |
frequency_penalty=0.2, | |
chunk_size=1024 | |
) | |
} | |
# Initialize thread pool for parallel processing | |
self.executor = ThreadPoolExecutor(max_workers=4) | |
# Response cache | |
self.cache: Dict[str, Any] = {} | |
self.cache_ttl = 3600 # 1 hour | |
async def reason_stream( | |
self, | |
query: str, | |
context: Dict[str, Any], | |
model: str = "mixtral", | |
chunk_handler: Optional[callable] = None | |
) -> AsyncGenerator[str, None]: | |
""" | |
Stream reasoning results from Groq's API. | |
Args: | |
query: The query to reason about | |
context: Additional context | |
model: Model to use ('mixtral' or 'llama') | |
chunk_handler: Optional callback for handling chunks | |
""" | |
config = self.model_configs[model] | |
messages = self._prepare_messages(query, context) | |
try: | |
stream = await self.client.chat.completions.create( | |
model=config.model_name, | |
messages=messages, | |
temperature=config.temperature, | |
top_p=config.top_p, | |
top_k=config.top_k, | |
presence_penalty=config.presence_penalty, | |
frequency_penalty=config.frequency_penalty, | |
max_tokens=config.max_tokens, | |
stream=True | |
) | |
collected_content = [] | |
async for chunk in stream: | |
if chunk.choices[0].delta.content: | |
content = chunk.choices[0].delta.content | |
collected_content.append(content) | |
if chunk_handler: | |
await chunk_handler(content) | |
yield content | |
# Cache the complete response | |
cache_key = self._generate_cache_key(query, context, model) | |
self.cache[cache_key] = { | |
"content": "".join(collected_content), | |
"timestamp": datetime.now() | |
} | |
except Exception as e: | |
logger.error(f"Groq streaming error: {str(e)}") | |
yield f"Error: {str(e)}" | |
async def reason( | |
self, | |
query: str, | |
context: Dict[str, Any], | |
model: str = "mixtral" | |
) -> StrategyResult: | |
""" | |
Enhanced reasoning with Groq's API including optimizations. | |
Args: | |
query: The query to reason about | |
context: Additional context | |
model: Model to use ('mixtral' or 'llama') | |
""" | |
# Check cache first | |
cache_key = self._generate_cache_key(query, context, model) | |
cached_response = self._get_from_cache(cache_key) | |
if cached_response: | |
return self._create_result(cached_response, model, from_cache=True) | |
config = self.model_configs[model] | |
messages = self._prepare_messages(query, context) | |
# Implement retry logic with exponential backoff | |
for attempt in range(config.retry_attempts): | |
try: | |
start_time = datetime.now() | |
# Make API call with optimized parameters | |
response = await self.client.chat.completions.create( | |
model=config.model_name, | |
messages=messages, | |
temperature=config.temperature, | |
top_p=config.top_p, | |
top_k=config.top_k, | |
presence_penalty=config.presence_penalty, | |
frequency_penalty=config.frequency_penalty, | |
max_tokens=config.max_tokens, | |
stream=False | |
) | |
end_time = datetime.now() | |
# Cache successful response | |
self.cache[cache_key] = { | |
"content": response.choices[0].message.content, | |
"timestamp": datetime.now() | |
} | |
return self._create_result(response, model) | |
except Exception as e: | |
delay = config.retry_delay * (2 ** attempt) | |
logger.warning(f"Groq API attempt {attempt + 1} failed: {str(e)}") | |
if attempt < config.retry_attempts - 1: | |
await asyncio.sleep(delay) | |
else: | |
logger.error(f"All Groq API attempts failed: {str(e)}") | |
return self._create_error_result(str(e)) | |
def _create_result( | |
self, | |
response: Union[Dict, Any], | |
model: str, | |
from_cache: bool = False | |
) -> StrategyResult: | |
"""Create a strategy result from response.""" | |
if from_cache: | |
answer = response["content"] | |
confidence = 0.9 # Higher confidence for cached responses | |
performance_metrics = { | |
"from_cache": True, | |
"cache_age": (datetime.now() - response["timestamp"]).total_seconds() | |
} | |
else: | |
answer = response.choices[0].message.content | |
confidence = self._calculate_confidence(response) | |
performance_metrics = { | |
"latency": response.usage.total_tokens / 1000, # tokens per second | |
"tokens_used": response.usage.total_tokens, | |
"prompt_tokens": response.usage.prompt_tokens, | |
"completion_tokens": response.usage.completion_tokens, | |
"model": self.model_configs[model].model_name | |
} | |
return StrategyResult( | |
strategy_type="groq", | |
success=True, | |
answer=answer, | |
confidence=confidence, | |
reasoning_trace=[{ | |
"step": "groq_api_call", | |
"model": self.model_configs[model].model_name, | |
"timestamp": datetime.now().isoformat(), | |
"metrics": performance_metrics | |
}], | |
metadata={ | |
"model": self.model_configs[model].model_name, | |
"from_cache": from_cache | |
}, | |
performance_metrics=performance_metrics | |
) | |
def _create_error_result(self, error: str) -> StrategyResult: | |
"""Create an error result.""" | |
return StrategyResult( | |
strategy_type="groq", | |
success=False, | |
answer=None, | |
confidence=0.0, | |
reasoning_trace=[{ | |
"step": "groq_api_error", | |
"error": error, | |
"timestamp": datetime.now().isoformat() | |
}], | |
metadata={"error": error}, | |
performance_metrics={} | |
) | |
def _generate_cache_key( | |
self, | |
query: str, | |
context: Dict[str, Any], | |
model: str | |
) -> str: | |
"""Generate a cache key.""" | |
key_data = { | |
"query": query, | |
"context": context, | |
"model": model | |
} | |
return json.dumps(key_data, sort_keys=True) | |
def _get_from_cache(self, cache_key: str) -> Optional[Dict]: | |
"""Get response from cache if valid.""" | |
if cache_key in self.cache: | |
cached = self.cache[cache_key] | |
age = (datetime.now() - cached["timestamp"]).total_seconds() | |
if age < self.cache_ttl: | |
return cached | |
else: | |
del self.cache[cache_key] | |
return None | |
def _calculate_confidence(self, response: Any) -> float: | |
"""Calculate confidence score from response.""" | |
confidence = 0.8 # Base confidence | |
# Adjust based on token usage and model behavior | |
if hasattr(response, 'usage'): | |
completion_tokens = response.usage.completion_tokens | |
total_tokens = response.usage.total_tokens | |
# Length-based adjustment | |
if completion_tokens < 10: | |
confidence *= 0.8 # Reduce confidence for very short responses | |
elif completion_tokens > 100: | |
confidence *= 1.1 # Increase confidence for detailed responses | |
# Token efficiency adjustment | |
token_efficiency = completion_tokens / total_tokens | |
if token_efficiency > 0.5: | |
confidence *= 1.1 # Good token efficiency | |
# Response completeness check | |
if hasattr(response.choices[0], 'finish_reason'): | |
if response.choices[0].finish_reason == "stop": | |
confidence *= 1.1 # Natural completion | |
elif response.choices[0].finish_reason == "length": | |
confidence *= 0.9 # Truncated response | |
return min(1.0, max(0.0, confidence)) # Ensure between 0 and 1 | |
def _prepare_messages( | |
self, | |
query: str, | |
context: Dict[str, Any] | |
) -> List[Dict[str, str]]: | |
"""Prepare messages for the Groq API.""" | |
messages = [] | |
# Add system message if provided | |
if "system_message" in context: | |
messages.append({ | |
"role": "system", | |
"content": context["system_message"] | |
}) | |
# Add chat history if provided | |
if "chat_history" in context: | |
messages.extend(context["chat_history"]) | |
# Add the current query | |
messages.append({ | |
"role": "user", | |
"content": query | |
}) | |
return messages | |