mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-31 01:35:00 +00:00
Use provider in the validate_environment workaround
This commit is contained in:
parent
89a7b3470a
commit
c21118ce5c
1 changed files with 25 additions and 21 deletions
|
@ -150,17 +150,17 @@ class Model:
|
||||||
def __init__(self, model, weak_model=None):
|
def __init__(self, model, weak_model=None):
|
||||||
self.name = model
|
self.name = model
|
||||||
|
|
||||||
# Are all needed keys/params available?
|
|
||||||
res = validate_environment(model)
|
|
||||||
self.missing_keys = res.get("missing_keys")
|
|
||||||
self.keys_in_environment = res.get("keys_in_environment")
|
|
||||||
|
|
||||||
# Do we have the model_info?
|
# Do we have the model_info?
|
||||||
try:
|
try:
|
||||||
self.info = litellm.get_model_info(model)
|
self.info = litellm.get_model_info(model)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.info = dict()
|
self.info = dict()
|
||||||
|
|
||||||
|
# Are all needed keys/params available?
|
||||||
|
res = self.validate_environment()
|
||||||
|
self.missing_keys = res.get("missing_keys")
|
||||||
|
self.keys_in_environment = res.get("keys_in_environment")
|
||||||
|
|
||||||
if self.info.get("max_input_tokens", 0) < 32 * 1024:
|
if self.info.get("max_input_tokens", 0) < 32 * 1024:
|
||||||
self.max_chat_history_tokens = 1024
|
self.max_chat_history_tokens = 1024
|
||||||
else:
|
else:
|
||||||
|
@ -284,24 +284,25 @@ class Model:
|
||||||
with Image.open(fname) as img:
|
with Image.open(fname) as img:
|
||||||
return img.size
|
return img.size
|
||||||
|
|
||||||
|
def validate_environment(self):
|
||||||
|
# https://github.com/BerriAI/litellm/issues/3190
|
||||||
|
|
||||||
def validate_environment(model):
|
model = self.name
|
||||||
# https://github.com/BerriAI/litellm/issues/3190
|
res = litellm.validate_environment(model)
|
||||||
|
if res["keys_in_environment"]:
|
||||||
|
return res
|
||||||
|
if res["missing_keys"]:
|
||||||
|
return res
|
||||||
|
|
||||||
|
provider = self.info.get("litellm_provider").lower()
|
||||||
|
if provider == "cohere_chat":
|
||||||
|
return validate_variables(["COHERE_API_KEY"])
|
||||||
|
if provider == "gemini":
|
||||||
|
return validate_variables(["GEMINI_API_KEY"])
|
||||||
|
if provider == "groq":
|
||||||
|
return validate_variables(["GROQ_API_KEY"])
|
||||||
|
|
||||||
res = litellm.validate_environment(model)
|
|
||||||
if res["keys_in_environment"]:
|
|
||||||
return res
|
return res
|
||||||
if res["missing_keys"]:
|
|
||||||
return res
|
|
||||||
|
|
||||||
if model.startswith("command-r"):
|
|
||||||
return validate_variables(["COHERE_API_KEY"])
|
|
||||||
if model.startswith("gemini"):
|
|
||||||
return validate_variables(["GEMINI_API_KEY"])
|
|
||||||
if model.startswith("groq/"):
|
|
||||||
return validate_variables(["GROQ_API_KEY"])
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def validate_variables(vars):
|
def validate_variables(vars):
|
||||||
|
@ -359,11 +360,14 @@ def sanity_check_model(io, model):
|
||||||
|
|
||||||
|
|
||||||
def fuzzy_match_models(name):
|
def fuzzy_match_models(name):
|
||||||
|
name = name.lower()
|
||||||
|
|
||||||
chat_models = []
|
chat_models = []
|
||||||
for model, attrs in litellm.model_cost.items():
|
for model, attrs in litellm.model_cost.items():
|
||||||
|
model = model.lower()
|
||||||
if attrs.get("mode") != "chat":
|
if attrs.get("mode") != "chat":
|
||||||
continue
|
continue
|
||||||
provider = attrs["litellm_provider"] + "/"
|
provider = (attrs["litellm_provider"] + "/").lower()
|
||||||
|
|
||||||
if model.startswith(provider):
|
if model.startswith(provider):
|
||||||
fq_model = model
|
fq_model = model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue