From 2fc358a02f067e202691f4b5435f41479926b429 Mon Sep 17 00:00:00 2001
From: Paul Gauthier
Date: Wed, 3 Jul 2024 13:04:13 -0300
Subject: [PATCH] Defer litellm import until first chat message; only import
 streamlit if gui is activated

---
 aider/main.py   |  3 ++-
 aider/models.py | 44 +++++++++++++++++++++++---------------------
 2 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/aider/main.py b/aider/main.py
index a7a6ce67a..6a099f0b4 100644
--- a/aider/main.py
+++ b/aider/main.py
@@ -8,7 +8,6 @@ import git
 import httpx
 from dotenv import load_dotenv
 from prompt_toolkit.enums import EditingMode
-from streamlit.web import cli
 
 from aider import __version__, models, utils
 from aider.args import get_parser
@@ -150,6 +149,8 @@ def scrub_sensitive_info(args, text):
 
 
 def launch_gui(args):
+    from streamlit.web import cli
+
     from aider import gui
 
     print()
diff --git a/aider/models.py b/aider/models.py
index 41dd3e6b7..e9c1fb625 100644
--- a/aider/models.py
+++ b/aider/models.py
@@ -1,9 +1,11 @@
 import difflib
+import importlib.util
 import json
 import math
 import os
 import sys
 from dataclasses import dataclass, fields
+from pathlib import Path
 from typing import Optional
 
 import yaml
@@ -40,7 +42,7 @@ gpt-3.5-turbo-16k
 gpt-3.5-turbo-16k-0613
 """
 
-OPENAI_MODELS = [ln.strip for ln in OPENAI_MODELS.splitlines() if ln.strip()]
+OPENAI_MODELS = [ln.strip() for ln in OPENAI_MODELS.splitlines() if ln.strip()]
 
 ANTHROPIC_MODELS = """
 claude-2
@@ -51,7 +53,7 @@ claude-3-sonnet-20240229
 claude-3-5-sonnet-20240620
 """
 
-ANTHROPIC_MODELS = [ln.strip for ln in ANTHROPIC_MODELS.splitlines() if ln.strip()]
+ANTHROPIC_MODELS = [ln.strip() for ln in ANTHROPIC_MODELS.splitlines() if ln.strip()]
 
 
 @dataclass
@@ -365,25 +367,7 @@ class Model:
     def __init__(self, model, weak_model=None):
         self.name = model
 
-        # Do we have the model_info?
-        try:
-            self.info = litellm.get_model_info(model)
-        except Exception:
-            self.info = dict()
-
-        if not self.info and "gpt-4o" in self.name:
-            self.info = {
-                "max_tokens": 4096,
-                "max_input_tokens": 128000,
-                "max_output_tokens": 4096,
-                "input_cost_per_token": 5e-06,
-                "output_cost_per_token": 1.5e-5,
-                "litellm_provider": "openai",
-                "mode": "chat",
-                "supports_function_calling": True,
-                "supports_parallel_function_calling": True,
-                "supports_vision": True,
-            }
+        self.info = self.get_model_info(model)
 
         # Are all needed keys/params available?
         res = self.validate_environment()
@@ -404,6 +388,24 @@ class Model:
         else:
             self.get_weak_model(weak_model)
 
+    def get_model_info(self, model):
+        # Try and do this quickly, without triggering the litellm import
+        spec = importlib.util.find_spec("litellm")
+        if spec:
+            origin = Path(spec.origin)
+            fname = origin.parent / "model_prices_and_context_window_backup.json"
+            if fname.exists():
+                data = json.loads(fname.read_text())
+                info = data.get(model)
+                if info:
+                    return info
+
+        # Do it the slow way...
+        try:
+            return litellm.get_model_info(model)
+        except Exception:
+            return dict()
+
     def configure_model_settings(self, model):
         for ms in MODEL_SETTINGS:
             # direct match, or match "provider/"