diff --git a/aider/models.py b/aider/models.py index 191f99873..99aacc363 100644 --- a/aider/models.py +++ b/aider/models.py @@ -632,78 +632,72 @@ MODEL_SETTINGS = [ ] -model_info_url = ( - "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" -) +class ModelInfoManager: + MODEL_INFO_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" + CACHE_TTL = 60 * 60 * 24 # 24 hours -#ai refactor get_model_flexible & get_model_info into a class! -# the class should load the cache_file once, on __init__ -def get_model_flexible(model, content): - info = content.get(model, dict()) - if info: - return info + def __init__(self): + self.cache_dir = Path.home() / ".aider" / "caches" + self.cache_file = self.cache_dir / "model_prices_and_context_window.json" + self.content = None + self._load_cache() - pieces = model.split("/") - if len(pieces) == 2: - info = content.get(pieces[1]) - if info and info.get("litellm_provider") == pieces[0]: + def _load_cache(self): + try: + self.cache_dir.mkdir(parents=True, exist_ok=True) + if self.cache_file.exists(): + cache_age = time.time() - self.cache_file.stat().st_mtime + if cache_age < self.CACHE_TTL: + self.content = json.loads(self.cache_file.read_text()) + except OSError: + pass + + def _update_cache(self): + if not litellm._lazy_module: + try: + import requests + response = requests.get(self.MODEL_INFO_URL, timeout=5) + if response.status_code == 200: + self.content = response.json() + try: + self.cache_file.write_text(json.dumps(self.content, indent=4)) + except OSError: + pass + except Exception as ex: + print(str(ex)) + + def get_model_flexible(self, model): + if not self.content: + self._update_cache() + + if not self.content: + return dict() + + info = self.content.get(model, dict()) + if info: return info - return dict() + pieces = model.split("/") + if len(pieces) == 2: + info = self.content.get(pieces[1]) + if info and info.get("litellm_provider") == pieces[0]: + return info - -def get_model_info(model): - if not litellm._lazy_module: - cache_dir = Path.home() / ".aider" / "caches" - cache_file = cache_dir / "model_prices_and_context_window.json" - - try: - cache_dir.mkdir(parents=True, exist_ok=True) - use_cache = True - except OSError: - # If we can't create the cache directory, we'll skip using the cache - use_cache = False - - if use_cache: - current_time = time.time() - cache_age = ( - current_time - cache_file.stat().st_mtime if cache_file.exists() else float("inf") - ) - - if cache_age < 60 * 60 * 24: - try: - content = json.loads(cache_file.read_text()) - res = get_model_flexible(model, content) - if res: - return res - except Exception as ex: - print(str(ex)) - - import requests - - try: - response = requests.get(model_info_url, timeout=5) - if response.status_code == 200: - content = response.json() - if use_cache: - try: - cache_file.write_text(json.dumps(content, indent=4)) - except OSError: - # If we can't write to the cache file, we'll just skip caching - pass - res = get_model_flexible(model, content) - if res: - return res - except Exception as ex: - print(str(ex)) - - # If all else fails, do it the slow way... - try: - info = litellm.get_model_info(model) - return info - except Exception: return dict() + def get_model_info(self, model): + info = self.get_model_flexible(model) + if info: + return info + + # If all else fails, do it the slow way... + try: + return litellm.get_model_info(model) + except Exception: + return dict() + +model_info_manager = ModelInfoManager() + class Model(ModelSettings): def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format=None): @@ -737,7 +731,7 @@ class Model(ModelSettings): self.get_editor_model(editor_model, editor_edit_format) def get_model_info(self, model): - return get_model_info(model) + return model_info_manager.get_model_info(model) def configure_model_settings(self, model): for ms in MODEL_SETTINGS: