agentic-system / reasoning /groq_strategy.py
Cascade Bot
Added Groq streaming support and optimizations - clean version
1d75522
"""Groq API integration with streaming and optimizations."""
import os
import logging
import asyncio
from typing import Dict, Any, Optional, List, AsyncGenerator, Union
import groq
from datetime import datetime
import json
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
from .base import ReasoningStrategy, StrategyResult
logger = logging.getLogger(__name__)
@dataclass
class GroqConfig:
"""Configuration for Groq models."""
model_name: str
max_tokens: int
temperature: float
top_p: float
top_k: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: Optional[List[str]] = None
chunk_size: int = 1024
retry_attempts: int = 3
retry_delay: float = 1.0
class GroqStrategy(ReasoningStrategy):
"""Enhanced reasoning strategy using Groq's API with streaming and optimizations."""
def __init__(self, api_key: Optional[str] = None):
"""Initialize Groq strategy."""
super().__init__()
self.api_key = api_key or os.getenv("GROQ_API_KEY")
if not self.api_key:
raise ValueError("GROQ_API_KEY must be set")
# Initialize Groq client with optimized settings
self.client = groq.Groq(
api_key=self.api_key,
timeout=30,
max_retries=3
)
# Optimized model configurations
self.model_configs = {
"mixtral": GroqConfig(
model_name="mixtral-8x7b-32768",
max_tokens=32768,
temperature=0.7,
top_p=0.9,
top_k=40,
presence_penalty=0.1,
frequency_penalty=0.1,
chunk_size=4096
),
"llama": GroqConfig(
model_name="llama2-70b-4096",
max_tokens=4096,
temperature=0.8,
top_p=0.9,
top_k=50,
presence_penalty=0.2,
frequency_penalty=0.2,
chunk_size=1024
)
}
# Initialize thread pool for parallel processing
self.executor = ThreadPoolExecutor(max_workers=4)
# Response cache
self.cache: Dict[str, Any] = {}
self.cache_ttl = 3600 # 1 hour
async def reason_stream(
self,
query: str,
context: Dict[str, Any],
model: str = "mixtral",
chunk_handler: Optional[callable] = None
) -> AsyncGenerator[str, None]:
"""
Stream reasoning results from Groq's API.
Args:
query: The query to reason about
context: Additional context
model: Model to use ('mixtral' or 'llama')
chunk_handler: Optional callback for handling chunks
"""
config = self.model_configs[model]
messages = self._prepare_messages(query, context)
try:
stream = await self.client.chat.completions.create(
model=config.model_name,
messages=messages,
temperature=config.temperature,
top_p=config.top_p,
top_k=config.top_k,
presence_penalty=config.presence_penalty,
frequency_penalty=config.frequency_penalty,
max_tokens=config.max_tokens,
stream=True
)
collected_content = []
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
collected_content.append(content)
if chunk_handler:
await chunk_handler(content)
yield content
# Cache the complete response
cache_key = self._generate_cache_key(query, context, model)
self.cache[cache_key] = {
"content": "".join(collected_content),
"timestamp": datetime.now()
}
except Exception as e:
logger.error(f"Groq streaming error: {str(e)}")
yield f"Error: {str(e)}"
async def reason(
self,
query: str,
context: Dict[str, Any],
model: str = "mixtral"
) -> StrategyResult:
"""
Enhanced reasoning with Groq's API including optimizations.
Args:
query: The query to reason about
context: Additional context
model: Model to use ('mixtral' or 'llama')
"""
# Check cache first
cache_key = self._generate_cache_key(query, context, model)
cached_response = self._get_from_cache(cache_key)
if cached_response:
return self._create_result(cached_response, model, from_cache=True)
config = self.model_configs[model]
messages = self._prepare_messages(query, context)
# Implement retry logic with exponential backoff
for attempt in range(config.retry_attempts):
try:
start_time = datetime.now()
# Make API call with optimized parameters
response = await self.client.chat.completions.create(
model=config.model_name,
messages=messages,
temperature=config.temperature,
top_p=config.top_p,
top_k=config.top_k,
presence_penalty=config.presence_penalty,
frequency_penalty=config.frequency_penalty,
max_tokens=config.max_tokens,
stream=False
)
end_time = datetime.now()
# Cache successful response
self.cache[cache_key] = {
"content": response.choices[0].message.content,
"timestamp": datetime.now()
}
return self._create_result(response, model)
except Exception as e:
delay = config.retry_delay * (2 ** attempt)
logger.warning(f"Groq API attempt {attempt + 1} failed: {str(e)}")
if attempt < config.retry_attempts - 1:
await asyncio.sleep(delay)
else:
logger.error(f"All Groq API attempts failed: {str(e)}")
return self._create_error_result(str(e))
def _create_result(
self,
response: Union[Dict, Any],
model: str,
from_cache: bool = False
) -> StrategyResult:
"""Create a strategy result from response."""
if from_cache:
answer = response["content"]
confidence = 0.9 # Higher confidence for cached responses
performance_metrics = {
"from_cache": True,
"cache_age": (datetime.now() - response["timestamp"]).total_seconds()
}
else:
answer = response.choices[0].message.content
confidence = self._calculate_confidence(response)
performance_metrics = {
"latency": response.usage.total_tokens / 1000, # tokens per second
"tokens_used": response.usage.total_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"model": self.model_configs[model].model_name
}
return StrategyResult(
strategy_type="groq",
success=True,
answer=answer,
confidence=confidence,
reasoning_trace=[{
"step": "groq_api_call",
"model": self.model_configs[model].model_name,
"timestamp": datetime.now().isoformat(),
"metrics": performance_metrics
}],
metadata={
"model": self.model_configs[model].model_name,
"from_cache": from_cache
},
performance_metrics=performance_metrics
)
def _create_error_result(self, error: str) -> StrategyResult:
"""Create an error result."""
return StrategyResult(
strategy_type="groq",
success=False,
answer=None,
confidence=0.0,
reasoning_trace=[{
"step": "groq_api_error",
"error": error,
"timestamp": datetime.now().isoformat()
}],
metadata={"error": error},
performance_metrics={}
)
def _generate_cache_key(
self,
query: str,
context: Dict[str, Any],
model: str
) -> str:
"""Generate a cache key."""
key_data = {
"query": query,
"context": context,
"model": model
}
return json.dumps(key_data, sort_keys=True)
def _get_from_cache(self, cache_key: str) -> Optional[Dict]:
"""Get response from cache if valid."""
if cache_key in self.cache:
cached = self.cache[cache_key]
age = (datetime.now() - cached["timestamp"]).total_seconds()
if age < self.cache_ttl:
return cached
else:
del self.cache[cache_key]
return None
def _calculate_confidence(self, response: Any) -> float:
"""Calculate confidence score from response."""
confidence = 0.8 # Base confidence
# Adjust based on token usage and model behavior
if hasattr(response, 'usage'):
completion_tokens = response.usage.completion_tokens
total_tokens = response.usage.total_tokens
# Length-based adjustment
if completion_tokens < 10:
confidence *= 0.8 # Reduce confidence for very short responses
elif completion_tokens > 100:
confidence *= 1.1 # Increase confidence for detailed responses
# Token efficiency adjustment
token_efficiency = completion_tokens / total_tokens
if token_efficiency > 0.5:
confidence *= 1.1 # Good token efficiency
# Response completeness check
if hasattr(response.choices[0], 'finish_reason'):
if response.choices[0].finish_reason == "stop":
confidence *= 1.1 # Natural completion
elif response.choices[0].finish_reason == "length":
confidence *= 0.9 # Truncated response
return min(1.0, max(0.0, confidence)) # Ensure between 0 and 1
def _prepare_messages(
self,
query: str,
context: Dict[str, Any]
) -> List[Dict[str, str]]:
"""Prepare messages for the Groq API."""
messages = []
# Add system message if provided
if "system_message" in context:
messages.append({
"role": "system",
"content": context["system_message"]
})
# Add chat history if provided
if "chat_history" in context:
messages.extend(context["chat_history"])
# Add the current query
messages.append({
"role": "user",
"content": query
})
return messages