Use provider in the validate_environment workaround

This commit is contained in:
Paul Gauthier 2024-04-23 06:28:37 -07:00
parent 89a7b3470a
commit c21118ce5c

View file

@ -150,17 +150,17 @@ class Model:
def __init__(self, model, weak_model=None):
self.name = model
# Are all needed keys/params available?
res = validate_environment(model)
self.missing_keys = res.get("missing_keys")
self.keys_in_environment = res.get("keys_in_environment")
# Do we have the model_info?
try:
self.info = litellm.get_model_info(model)
except Exception:
self.info = dict()
# Are all needed keys/params available?
res = self.validate_environment()
self.missing_keys = res.get("missing_keys")
self.keys_in_environment = res.get("keys_in_environment")
if self.info.get("max_input_tokens", 0) < 32 * 1024:
self.max_chat_history_tokens = 1024
else:
@ -284,24 +284,25 @@ class Model:
with Image.open(fname) as img:
return img.size
def validate_environment(self):
# https://github.com/BerriAI/litellm/issues/3190
def validate_environment(model):
# https://github.com/BerriAI/litellm/issues/3190
model = self.name
res = litellm.validate_environment(model)
if res["keys_in_environment"]:
return res
if res["missing_keys"]:
return res
provider = self.info.get("litellm_provider").lower()
if provider == "cohere_chat":
return validate_variables(["COHERE_API_KEY"])
if provider == "gemini":
return validate_variables(["GEMINI_API_KEY"])
if provider == "groq":
return validate_variables(["GROQ_API_KEY"])
res = litellm.validate_environment(model)
if res["keys_in_environment"]:
return res
if res["missing_keys"]:
return res
if model.startswith("command-r"):
return validate_variables(["COHERE_API_KEY"])
if model.startswith("gemini"):
return validate_variables(["GEMINI_API_KEY"])
if model.startswith("groq/"):
return validate_variables(["GROQ_API_KEY"])
return res
def validate_variables(vars):
@ -359,11 +360,14 @@ def sanity_check_model(io, model):
def fuzzy_match_models(name):
name = name.lower()
chat_models = []
for model, attrs in litellm.model_cost.items():
model = model.lower()
if attrs.get("mode") != "chat":
continue
provider = attrs["litellm_provider"] + "/"
provider = (attrs["litellm_provider"] + "/").lower()
if model.startswith(provider):
fq_model = model