from typing import Optional
import logging

from smolagents import LiteLLMModel
from tenacity import (
    retry,
    stop_after_attempt,
    before_sleep_log,
    retry_if_exception_type,
    wait_exponential,
    wait_random,
)
import litellm

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class LiteLLMModelWithBackOff(LiteLLMModel):
    """``LiteLLMModel`` whose calls are retried with exponential back-off.

    Every call forwards a ``max_tokens`` keyword to the parent ``__call__``
    and is retried when litellm raises a timeout, rate-limit, connection or
    internal-server error.
    """

    def __init__(self, max_tokens: Optional[int] = 1500, *args, **kwargs):
        """Initialise the parent model and remember the default token limit.

        Args:
            max_tokens: Value forwarded as ``max_tokens`` on every call unless
                the caller overrides it for that call. Note that this is the
                first positional parameter, so a first positional argument
                binds to ``max_tokens``; pass parent-class arguments by
                keyword to avoid surprises.
            *args: Extra positional arguments handed to
                ``LiteLLMModel.__init__``.
            **kwargs: Keyword arguments handed to ``LiteLLMModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.max_tokens = max_tokens

    @retry(
        # Give up after 450 attempts. No ``reraise`` is configured, so
        # tenacity's default end-of-retries behaviour applies.
        stop=stop_after_attempt(450),
        # Exponential back-off bounded to 1..120 seconds, plus up to 5 seconds
        # of random jitter added on top.
        wait=wait_exponential(min=1, max=120, exp_base=2, multiplier=1)
        + wait_random(0, 5),
        # Emit a WARNING through this module's logger before each sleep.
        before_sleep=before_sleep_log(logger, logging.WARNING),
        # Only these litellm exception types trigger a retry; anything else
        # propagates immediately.
        retry=retry_if_exception_type(
            (
                litellm.Timeout,
                litellm.RateLimitError,
                litellm.APIConnectionError,
                litellm.InternalServerError,
            )
        ),
    )
    def __call__(self, *args, **kwargs):
        """Invoke the parent model, retrying on the listed litellm errors.

        Args:
            *args: Positional arguments forwarded to ``LiteLLMModel.__call__``.
            **kwargs: Keyword arguments forwarded to
                ``LiteLLMModel.__call__``. A ``max_tokens`` entry, if given,
                takes precedence over ``self.max_tokens`` for this call.

        Returns:
            Whatever ``LiteLLMModel.__call__`` returns.
        """
        # Use the instance value only as a default: passing it as an explicit
        # keyword next to **kwargs would raise "got multiple values for
        # keyword argument 'max_tokens'" whenever the caller supplies one.
        kwargs.setdefault("max_tokens", self.max_tokens)
        return super().__call__(*args, **kwargs)