mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 08:44:59 +00:00
feat: Add local cache for OpenRouter models
This commit is contained in:
parent
d8fbd9cbd3
commit
5052150e2e
2 changed files with 140 additions and 0 deletions
|
@ -14,6 +14,7 @@ from typing import Optional, Union
|
||||||
import json5
|
import json5
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from aider.openrouter import OpenRouterModelManager
|
||||||
|
|
||||||
from aider.dump import dump # noqa: F401
|
from aider.dump import dump # noqa: F401
|
||||||
from aider.llm import litellm
|
from aider.llm import litellm
|
||||||
|
@ -149,8 +150,13 @@ class ModelInfoManager:
|
||||||
self.verify_ssl = True
|
self.verify_ssl = True
|
||||||
self._cache_loaded = False
|
self._cache_loaded = False
|
||||||
|
|
||||||
|
# Manager for the cached OpenRouter model database
|
||||||
|
self.openrouter_manager = OpenRouterModelManager()
|
||||||
|
|
||||||
def set_verify_ssl(self, verify_ssl):
|
def set_verify_ssl(self, verify_ssl):
|
||||||
self.verify_ssl = verify_ssl
|
self.verify_ssl = verify_ssl
|
||||||
|
if hasattr(self, "openrouter_manager"):
|
||||||
|
self.openrouter_manager.set_verify_ssl(verify_ssl)
|
||||||
|
|
||||||
def _load_cache(self):
|
def _load_cache(self):
|
||||||
if self._cache_loaded:
|
if self._cache_loaded:
|
||||||
|
@ -232,6 +238,12 @@ class ModelInfoManager:
|
||||||
return litellm_info
|
return litellm_info
|
||||||
|
|
||||||
if not cached_info and model.startswith("openrouter/"):
|
if not cached_info and model.startswith("openrouter/"):
|
||||||
|
# First try using the locally cached OpenRouter model database
|
||||||
|
openrouter_info = self.openrouter_manager.get_model_info(model)
|
||||||
|
if openrouter_info:
|
||||||
|
return openrouter_info
|
||||||
|
|
||||||
|
# Fallback to legacy web-scraping if the API cache does not contain the model
|
||||||
openrouter_info = self.fetch_openrouter_model_info(model)
|
openrouter_info = self.fetch_openrouter_model_info(model)
|
||||||
if openrouter_info:
|
if openrouter_info:
|
||||||
return openrouter_info
|
return openrouter_info
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
"""
|
||||||
|
OpenRouter model metadata caching and lookup.
|
||||||
|
|
||||||
|
This module keeps a local cached copy of the OpenRouter model list
|
||||||
|
(downloaded from ``https://openrouter.ai/api/v1/models``) and exposes a
|
||||||
|
helper class that returns metadata for a given model in a format compatible
|
||||||
|
with litellm’s ``get_model_info``.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _cost_per_token(val: str | None) -> float | None:
|
||||||
|
"""Convert a per-million price string to a per-token float."""
|
||||||
|
if val in (None, "", "0"):
|
||||||
|
return 0.0 if val == "0" else None
|
||||||
|
try:
|
||||||
|
return float(val) / 1_000_000
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterModelManager:
|
||||||
|
MODELS_URL = "https://openrouter.ai/api/v1/models"
|
||||||
|
CACHE_TTL = 60 * 60 * 24 # 24 h
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.cache_dir = Path.home() / ".aider" / "caches"
|
||||||
|
self.cache_file = self.cache_dir / "openrouter_models.json"
|
||||||
|
self.content: Dict | None = None
|
||||||
|
self.verify_ssl: bool = True
|
||||||
|
self._cache_loaded = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Public API #
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
def set_verify_ssl(self, verify_ssl: bool) -> None:
|
||||||
|
"""Enable/disable SSL verification for API requests."""
|
||||||
|
self.verify_ssl = verify_ssl
|
||||||
|
|
||||||
|
def get_model_info(self, model: str) -> Dict:
|
||||||
|
"""
|
||||||
|
Return metadata for *model* or an empty ``dict`` when unknown.
|
||||||
|
|
||||||
|
``model`` should use the aider naming convention, e.g.
|
||||||
|
``openrouter/nousresearch/deephermes-3-mistral-24b-preview:free``.
|
||||||
|
"""
|
||||||
|
self._ensure_content()
|
||||||
|
if not self.content or "data" not in self.content:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
route = self._strip_prefix(model)
|
||||||
|
|
||||||
|
# Consider both the exact id and id without any “:suffix”.
|
||||||
|
candidates = {route}
|
||||||
|
if ":" in route:
|
||||||
|
candidates.add(route.split(":", 1)[0])
|
||||||
|
|
||||||
|
record = next((item for item in self.content["data"] if item.get("id") in candidates), None)
|
||||||
|
if not record:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
context_len = (
|
||||||
|
record.get("top_provider", {}).get("context_length")
|
||||||
|
or record.get("context_length")
|
||||||
|
or None
|
||||||
|
)
|
||||||
|
|
||||||
|
pricing = record.get("pricing", {})
|
||||||
|
return {
|
||||||
|
"max_input_tokens": context_len,
|
||||||
|
"max_tokens": context_len,
|
||||||
|
"max_output_tokens": context_len,
|
||||||
|
"input_cost_per_token": _cost_per_token(pricing.get("prompt")),
|
||||||
|
"output_cost_per_token": _cost_per_token(pricing.get("completion")),
|
||||||
|
"litellm_provider": "openrouter",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Internal helpers #
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
def _strip_prefix(self, model: str) -> str:
|
||||||
|
return model[len("openrouter/") :] if model.startswith("openrouter/") else model
|
||||||
|
|
||||||
|
def _ensure_content(self) -> None:
|
||||||
|
self._load_cache()
|
||||||
|
if not self.content:
|
||||||
|
self._update_cache()
|
||||||
|
|
||||||
|
def _load_cache(self) -> None:
|
||||||
|
if self._cache_loaded:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if self.cache_file.exists():
|
||||||
|
cache_age = time.time() - self.cache_file.stat().st_mtime
|
||||||
|
if cache_age < self.CACHE_TTL:
|
||||||
|
try:
|
||||||
|
self.content = json.loads(self.cache_file.read_text())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.content = None
|
||||||
|
except OSError:
|
||||||
|
# Cache directory might be unwritable; ignore.
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._cache_loaded = True
|
||||||
|
|
||||||
|
def _update_cache(self) -> None:
|
||||||
|
try:
|
||||||
|
response = requests.get(self.MODELS_URL, timeout=10, verify=self.verify_ssl)
|
||||||
|
if response.status_code == 200:
|
||||||
|
self.content = response.json()
|
||||||
|
try:
|
||||||
|
self.cache_file.write_text(json.dumps(self.content, indent=2))
|
||||||
|
except OSError:
|
||||||
|
pass # Non-fatal if we can’t write the cache
|
||||||
|
except Exception as ex: # noqa: BLE001
|
||||||
|
print(f"Failed to fetch OpenRouter model list: {ex}")
|
||||||
|
try:
|
||||||
|
self.cache_file.write_text("{}")
|
||||||
|
except OSError:
|
||||||
|
pass
|
Loading…
Add table
Add a link
Reference in a new issue