"""Groq API integration with streaming and optimizations.""" import os import logging import asyncio from typing import Dict, Any, Optional, List, AsyncGenerator, Union import groq from datetime import datetime import json from dataclasses import dataclass from concurrent.futures import ThreadPoolExecutor from .base import ReasoningStrategy, StrategyResult logger = logging.getLogger(__name__) @dataclass class GroqConfig: """Configuration for Groq models.""" model_name: str max_tokens: int temperature: float top_p: float top_k: Optional[int] = None presence_penalty: float = 0.0 frequency_penalty: float = 0.0 stop_sequences: Optional[List[str]] = None chunk_size: int = 1024 retry_attempts: int = 3 retry_delay: float = 1.0 class GroqStrategy(ReasoningStrategy): """Enhanced reasoning strategy using Groq's API with streaming and optimizations.""" def __init__(self, api_key: Optional[str] = None): """Initialize Groq strategy.""" super().__init__() self.api_key = api_key or os.getenv("GROQ_API_KEY") if not self.api_key: raise ValueError("GROQ_API_KEY must be set") # Initialize Groq client with optimized settings self.client = groq.Groq( api_key=self.api_key, timeout=30, max_retries=3 ) # Optimized model configurations self.model_configs = { "mixtral": GroqConfig( model_name="mixtral-8x7b-32768", max_tokens=32768, temperature=0.7, top_p=0.9, top_k=40, presence_penalty=0.1, frequency_penalty=0.1, chunk_size=4096 ), "llama": GroqConfig( model_name="llama2-70b-4096", max_tokens=4096, temperature=0.8, top_p=0.9, top_k=50, presence_penalty=0.2, frequency_penalty=0.2, chunk_size=1024 ) } # Initialize thread pool for parallel processing self.executor = ThreadPoolExecutor(max_workers=4) # Response cache self.cache: Dict[str, Any] = {} self.cache_ttl = 3600 # 1 hour async def reason_stream( self, query: str, context: Dict[str, Any], model: str = "mixtral", chunk_handler: Optional[callable] = None ) -> AsyncGenerator[str, None]: """ Stream reasoning results from Groq's API. Args: query: The query to reason about context: Additional context model: Model to use ('mixtral' or 'llama') chunk_handler: Optional callback for handling chunks """ config = self.model_configs[model] messages = self._prepare_messages(query, context) try: stream = await self.client.chat.completions.create( model=config.model_name, messages=messages, temperature=config.temperature, top_p=config.top_p, top_k=config.top_k, presence_penalty=config.presence_penalty, frequency_penalty=config.frequency_penalty, max_tokens=config.max_tokens, stream=True ) collected_content = [] async for chunk in stream: if chunk.choices[0].delta.content: content = chunk.choices[0].delta.content collected_content.append(content) if chunk_handler: await chunk_handler(content) yield content # Cache the complete response cache_key = self._generate_cache_key(query, context, model) self.cache[cache_key] = { "content": "".join(collected_content), "timestamp": datetime.now() } except Exception as e: logger.error(f"Groq streaming error: {str(e)}") yield f"Error: {str(e)}" async def reason( self, query: str, context: Dict[str, Any], model: str = "mixtral" ) -> StrategyResult: """ Enhanced reasoning with Groq's API including optimizations. Args: query: The query to reason about context: Additional context model: Model to use ('mixtral' or 'llama') """ # Check cache first cache_key = self._generate_cache_key(query, context, model) cached_response = self._get_from_cache(cache_key) if cached_response: return self._create_result(cached_response, model, from_cache=True) config = self.model_configs[model] messages = self._prepare_messages(query, context) # Implement retry logic with exponential backoff for attempt in range(config.retry_attempts): try: start_time = datetime.now() # Make API call with optimized parameters response = await self.client.chat.completions.create( model=config.model_name, messages=messages, temperature=config.temperature, top_p=config.top_p, top_k=config.top_k, presence_penalty=config.presence_penalty, frequency_penalty=config.frequency_penalty, max_tokens=config.max_tokens, stream=False ) end_time = datetime.now() # Cache successful response self.cache[cache_key] = { "content": response.choices[0].message.content, "timestamp": datetime.now() } return self._create_result(response, model) except Exception as e: delay = config.retry_delay * (2 ** attempt) logger.warning(f"Groq API attempt {attempt + 1} failed: {str(e)}") if attempt < config.retry_attempts - 1: await asyncio.sleep(delay) else: logger.error(f"All Groq API attempts failed: {str(e)}") return self._create_error_result(str(e)) def _create_result( self, response: Union[Dict, Any], model: str, from_cache: bool = False ) -> StrategyResult: """Create a strategy result from response.""" if from_cache: answer = response["content"] confidence = 0.9 # Higher confidence for cached responses performance_metrics = { "from_cache": True, "cache_age": (datetime.now() - response["timestamp"]).total_seconds() } else: answer = response.choices[0].message.content confidence = self._calculate_confidence(response) performance_metrics = { "latency": response.usage.total_tokens / 1000, # tokens per second "tokens_used": response.usage.total_tokens, "prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens, "model": self.model_configs[model].model_name } return StrategyResult( strategy_type="groq", success=True, answer=answer, confidence=confidence, reasoning_trace=[{ "step": "groq_api_call", "model": self.model_configs[model].model_name, "timestamp": datetime.now().isoformat(), "metrics": performance_metrics }], metadata={ "model": self.model_configs[model].model_name, "from_cache": from_cache }, performance_metrics=performance_metrics ) def _create_error_result(self, error: str) -> StrategyResult: """Create an error result.""" return StrategyResult( strategy_type="groq", success=False, answer=None, confidence=0.0, reasoning_trace=[{ "step": "groq_api_error", "error": error, "timestamp": datetime.now().isoformat() }], metadata={"error": error}, performance_metrics={} ) def _generate_cache_key( self, query: str, context: Dict[str, Any], model: str ) -> str: """Generate a cache key.""" key_data = { "query": query, "context": context, "model": model } return json.dumps(key_data, sort_keys=True) def _get_from_cache(self, cache_key: str) -> Optional[Dict]: """Get response from cache if valid.""" if cache_key in self.cache: cached = self.cache[cache_key] age = (datetime.now() - cached["timestamp"]).total_seconds() if age < self.cache_ttl: return cached else: del self.cache[cache_key] return None def _calculate_confidence(self, response: Any) -> float: """Calculate confidence score from response.""" confidence = 0.8 # Base confidence # Adjust based on token usage and model behavior if hasattr(response, 'usage'): completion_tokens = response.usage.completion_tokens total_tokens = response.usage.total_tokens # Length-based adjustment if completion_tokens < 10: confidence *= 0.8 # Reduce confidence for very short responses elif completion_tokens > 100: confidence *= 1.1 # Increase confidence for detailed responses # Token efficiency adjustment token_efficiency = completion_tokens / total_tokens if token_efficiency > 0.5: confidence *= 1.1 # Good token efficiency # Response completeness check if hasattr(response.choices[0], 'finish_reason'): if response.choices[0].finish_reason == "stop": confidence *= 1.1 # Natural completion elif response.choices[0].finish_reason == "length": confidence *= 0.9 # Truncated response return min(1.0, max(0.0, confidence)) # Ensure between 0 and 1 def _prepare_messages( self, query: str, context: Dict[str, Any] ) -> List[Dict[str, str]]: """Prepare messages for the Groq API.""" messages = [] # Add system message if provided if "system_message" in context: messages.append({ "role": "system", "content": context["system_message"] }) # Add chat history if provided if "chat_history" in context: messages.extend(context["chat_history"]) # Add the current query messages.append({ "role": "user", "content": query }) return messages