DABstep / baseline /custom_litellm.py
martinigoyanes's picture
initial commit
034ac91
from typing import Optional
from smolagents import LiteLLMModel
from tenacity import retry, stop_after_attempt, before_sleep_log, retry_if_exception_type, wait_exponential, wait_random
import litellm
import logging
# Configure the root logger at import time so WARNING-level records are
# emitted; the retry decorator below logs through `logger` at WARNING level
# before each back-off sleep.
logging.basicConfig(level=logging.WARNING)
# Module-scoped logger handed to tenacity's `before_sleep_log`.
logger = logging.getLogger(__name__)
class LiteLLMModelWithBackOff(LiteLLMModel):
    """A ``LiteLLMModel`` whose calls are retried on transient LiteLLM errors.

    Every invocation of the model is wrapped in a tenacity retry policy:

    * retried exceptions: ``litellm.Timeout``, ``litellm.RateLimitError``,
      ``litellm.APIConnectionError`` and ``litellm.InternalServerError``;
      any other exception propagates immediately;
    * wait between attempts: exponential back-off clamped to 1-120 seconds,
      plus 0-5 seconds of random jitter;
    * at most 450 attempts;
    * a WARNING-level log record is written before each sleep.

    NOTE(review): ``reraise=True`` is not set, so once the attempts are
    exhausted tenacity is expected to raise its own ``RetryError`` wrapping
    the last failure rather than the original LiteLLM exception -- confirm
    against the tenacity documentation before relying on either type.
    """

    def __init__(self, max_tokens: Optional[int] = 1500, *args, **kwargs):
        """Create the model and remember the default completion-token cap.

        Args:
            max_tokens: Default ``max_tokens`` value sent with every call
                that does not specify its own. Note that this is the FIRST
                positional parameter, so arguments meant for
                ``LiteLLMModel`` (such as the model id) should be passed by
                keyword to avoid being captured here.
            *args: Forwarded unchanged to ``LiteLLMModel.__init__``.
            **kwargs: Forwarded unchanged to ``LiteLLMModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.max_tokens = max_tokens

    @retry(
        stop=stop_after_attempt(450),
        wait=wait_exponential(min=1, max=120, exp_base=2, multiplier=1) + wait_random(0, 5),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_exception_type((
            litellm.Timeout,
            litellm.RateLimitError,
            litellm.APIConnectionError,
            litellm.InternalServerError,
        )),
    )
    def __call__(self, *args, **kwargs):
        """Invoke the underlying model, retrying on transient failures.

        Args:
            *args: Forwarded unchanged to ``LiteLLMModel.__call__``.
            **kwargs: Forwarded to ``LiteLLMModel.__call__``. If
                ``max_tokens`` is absent, the instance default
                (``self.max_tokens``) is supplied; an explicit per-call
                ``max_tokens`` takes precedence.

        Returns:
            Whatever ``LiteLLMModel.__call__`` returns.
        """
        # Use the instance value only as a default. Passing it as an explicit
        # keyword next to **kwargs would raise "got multiple values for
        # keyword argument 'max_tokens'" whenever the caller overrides it.
        kwargs.setdefault("max_tokens", self.max_tokens)
        return super().__call__(*args, **kwargs)