diff --git a/aider/models.py b/aider/models.py index 11139167f..5284f7971 100644 --- a/aider/models.py +++ b/aider/models.py @@ -1,6 +1,7 @@ import difflib import json import math +import os import sys from dataclasses import dataclass, fields from typing import Optional @@ -150,7 +151,7 @@ class Model: self.name = model # Are all needed keys/params available? - res = litellm.validate_environment(model) + res = validate_environment(model) self.missing_keys = res.get("missing_keys") self.keys_in_environment = res.get("keys_in_environment") @@ -284,6 +285,35 @@ class Model: return img.size +def validate_environment(model): + # https://github.com/BerriAI/litellm/issues/3190 + + res = litellm.validate_environment(model) + if res["keys_in_environment"]: + return res + if res["missing_keys"]: + return res + + if model.startswith("command-r"): + return validate_variables(["COHERE_API_KEY"]) + if model.startswith("gemini"): + return validate_variables(["GEMINI_API_KEY"]) + if model.startswith("groq/"): + return validate_variables(["GROQ_API_KEY"]) + + return res + + +def validate_variables(vars): + missing = [] + for var in vars: + if var not in os.environ: + missing.append(var) + if missing: + return dict(keys_in_environment=False, missing_keys=missing) + return dict(keys_in_environment=True, missing_keys=missing) + + def sanity_check_models(io, main_model): missing_model_info = False if not sanity_check_model(io, main_model):