refactor: Extract get_model_info into a standalone function

This commit is contained in:
Paul Gauthier (aider) 2024-08-25 08:19:15 -07:00
parent 539a657624
commit 5473d99e13

View file

@ -427,6 +427,65 @@ MODEL_SETTINGS = [
]
def get_model_info(model):
if litellm._lazy_module:
# Do it the slow way...
try:
return litellm.get_model_info(model)
except Exception:
return dict()
cache_dir = Path.home() / ".aider" / "caches"
cache_file = cache_dir / "model_prices_and_context_window.json"
cache_dir.mkdir(parents=True, exist_ok=True)
current_time = time.time()
cache_age = (
current_time - cache_file.stat().st_mtime if cache_file.exists() else float("inf")
)
if cache_file.exists() and cache_age < 86400: # 86400 seconds = 1 day
content = safe_read_json(cache_file)
if content:
info = content.get(model)
if info:
return info
# If cache doesn't exist or is old, fetch from GitHub
try:
import requests
url = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/"
"model_prices_and_context_window.json"
)
response = requests.get(url, timeout=5)
if response.status_code == 200:
content = response.json()
safe_write_json(cache_file, content)
info = content.get(model)
if info:
return info
except Exception:
# If fetching from GitHub fails, fall back to local resource
try:
with importlib.resources.open_text(
"litellm", "model_prices_and_context_window_backup.json"
) as f:
content = json.load(f)
info = content.get(model)
if info:
return info
except Exception:
pass # If there's any error, fall back to the slow way
# If all else fails, do it the slow way...
try:
return litellm.get_model_info(model)
except Exception:
return dict()
class Model(ModelSettings):
def __init__(self, model, weak_model=None):
self.name = model
@ -454,62 +513,7 @@ class Model(ModelSettings):
self.get_weak_model(weak_model)
def get_model_info(self, model):
if litellm._lazy_module:
# Do it the slow way...
try:
return litellm.get_model_info(model)
except Exception:
return dict()
cache_dir = Path.home() / ".aider" / "caches"
cache_file = cache_dir / "model_prices_and_context_window.json"
cache_dir.mkdir(parents=True, exist_ok=True)
current_time = time.time()
cache_age = (
current_time - cache_file.stat().st_mtime if cache_file.exists() else float("inf")
)
if cache_file.exists() and cache_age < 86400: # 86400 seconds = 1 day
content = safe_read_json(cache_file)
if content:
info = content.get(model)
if info:
return info
# If cache doesn't exist or is old, fetch from GitHub
try:
import requests
url = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/"
"model_prices_and_context_window.json"
)
response = requests.get(url, timeout=5)
if response.status_code == 200:
content = response.json()
safe_write_json(cache_file, content)
info = content.get(model)
if info:
return info
except Exception:
# If fetching from GitHub fails, fall back to local resource
try:
with importlib.resources.open_text(
"litellm", "model_prices_and_context_window_backup.json"
) as f:
content = json.load(f)
info = content.get(model)
if info:
return info
except Exception:
pass # If there's any error, fall back to the slow way
# If all else fails, do it the slow way...
try:
return litellm.get_model_info(model)
except Exception:
return dict()
return get_model_info(model)
def configure_model_settings(self, model):
for ms in MODEL_SETTINGS: