mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-28 16:25:00 +00:00
Use provider in the validate_environment workaround
This commit is contained in:
parent
89a7b3470a
commit
c21118ce5c
1 changed files with 25 additions and 21 deletions
|
@ -150,17 +150,17 @@ class Model:
|
|||
def __init__(self, model, weak_model=None):
|
||||
self.name = model
|
||||
|
||||
# Are all needed keys/params available?
|
||||
res = validate_environment(model)
|
||||
self.missing_keys = res.get("missing_keys")
|
||||
self.keys_in_environment = res.get("keys_in_environment")
|
||||
|
||||
# Do we have the model_info?
|
||||
try:
|
||||
self.info = litellm.get_model_info(model)
|
||||
except Exception:
|
||||
self.info = dict()
|
||||
|
||||
# Are all needed keys/params available?
|
||||
res = self.validate_environment()
|
||||
self.missing_keys = res.get("missing_keys")
|
||||
self.keys_in_environment = res.get("keys_in_environment")
|
||||
|
||||
if self.info.get("max_input_tokens", 0) < 32 * 1024:
|
||||
self.max_chat_history_tokens = 1024
|
||||
else:
|
||||
|
@ -284,24 +284,25 @@ class Model:
|
|||
with Image.open(fname) as img:
|
||||
return img.size
|
||||
|
||||
def validate_environment(self):
|
||||
# https://github.com/BerriAI/litellm/issues/3190
|
||||
|
||||
def validate_environment(model):
|
||||
# https://github.com/BerriAI/litellm/issues/3190
|
||||
model = self.name
|
||||
res = litellm.validate_environment(model)
|
||||
if res["keys_in_environment"]:
|
||||
return res
|
||||
if res["missing_keys"]:
|
||||
return res
|
||||
|
||||
provider = self.info.get("litellm_provider").lower()
|
||||
if provider == "cohere_chat":
|
||||
return validate_variables(["COHERE_API_KEY"])
|
||||
if provider == "gemini":
|
||||
return validate_variables(["GEMINI_API_KEY"])
|
||||
if provider == "groq":
|
||||
return validate_variables(["GROQ_API_KEY"])
|
||||
|
||||
res = litellm.validate_environment(model)
|
||||
if res["keys_in_environment"]:
|
||||
return res
|
||||
if res["missing_keys"]:
|
||||
return res
|
||||
|
||||
if model.startswith("command-r"):
|
||||
return validate_variables(["COHERE_API_KEY"])
|
||||
if model.startswith("gemini"):
|
||||
return validate_variables(["GEMINI_API_KEY"])
|
||||
if model.startswith("groq/"):
|
||||
return validate_variables(["GROQ_API_KEY"])
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def validate_variables(vars):
|
||||
|
@ -359,11 +360,14 @@ def sanity_check_model(io, model):
|
|||
|
||||
|
||||
def fuzzy_match_models(name):
|
||||
name = name.lower()
|
||||
|
||||
chat_models = []
|
||||
for model, attrs in litellm.model_cost.items():
|
||||
model = model.lower()
|
||||
if attrs.get("mode") != "chat":
|
||||
continue
|
||||
provider = attrs["litellm_provider"] + "/"
|
||||
provider = (attrs["litellm_provider"] + "/").lower()
|
||||
|
||||
if model.startswith(provider):
|
||||
fq_model = model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue