diff --git a/aider/models.py b/aider/models.py index 816e816f7..f158b840a 100644 --- a/aider/models.py +++ b/aider/models.py @@ -150,17 +150,17 @@ class Model: def __init__(self, model, weak_model=None): self.name = model - # Are all needed keys/params available? - res = validate_environment(model) - self.missing_keys = res.get("missing_keys") - self.keys_in_environment = res.get("keys_in_environment") - # Do we have the model_info? try: self.info = litellm.get_model_info(model) except Exception: self.info = dict() + # Are all needed keys/params available? + res = self.validate_environment() + self.missing_keys = res.get("missing_keys") + self.keys_in_environment = res.get("keys_in_environment") + if self.info.get("max_input_tokens", 0) < 32 * 1024: self.max_chat_history_tokens = 1024 else: @@ -284,24 +284,25 @@ class Model: with Image.open(fname) as img: return img.size + def validate_environment(self): + # https://github.com/BerriAI/litellm/issues/3190 -def validate_environment(model): - # https://github.com/BerriAI/litellm/issues/3190 + model = self.name + res = litellm.validate_environment(model) + if res["keys_in_environment"]: + return res + if res["missing_keys"]: + return res + + provider = self.info.get("litellm_provider").lower() + if provider == "cohere_chat": + return validate_variables(["COHERE_API_KEY"]) + if provider == "gemini": + return validate_variables(["GEMINI_API_KEY"]) + if provider == "groq": + return validate_variables(["GROQ_API_KEY"]) - res = litellm.validate_environment(model) - if res["keys_in_environment"]: return res - if res["missing_keys"]: - return res - - if model.startswith("command-r"): - return validate_variables(["COHERE_API_KEY"]) - if model.startswith("gemini"): - return validate_variables(["GEMINI_API_KEY"]) - if model.startswith("groq/"): - return validate_variables(["GROQ_API_KEY"]) - - return res def validate_variables(vars): @@ -359,11 +360,14 @@ def sanity_check_model(io, model): def fuzzy_match_models(name): + name = name.lower() + chat_models = [] for model, attrs in litellm.model_cost.items(): + model = model.lower() if attrs.get("mode") != "chat": continue - provider = attrs["litellm_provider"] + "/" + provider = (attrs["litellm_provider"] + "/").lower() if model.startswith(provider): fq_model = model