Better unknown model warnings

This commit is contained in:
Paul Gauthier 2024-04-22 14:07:32 -07:00
parent f1ce673f78
commit efd3c39e50
4 changed files with 125 additions and 69 deletions

View file

@ -13,24 +13,6 @@ from aider.dump import dump # noqa: F401
DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
class NoModelInfo(Exception):
"""
Exception raised when model information cannot be retrieved.
"""
def __init__(self, model):
super().__init__(check_model_name(model))
class ModelEnvironmentError(Exception):
"""
Exception raised when the environment isn't setup for the model
"""
def __init__(self, message):
super().__init__(message)
@dataclass
class ModelSettings:
name: str
@ -164,30 +146,18 @@ class Model:
max_chat_history_tokens = 1024
weak_model = None
def __init__(self, model, weak_model=None, require_model_info=True, validate_environment=True):
def __init__(self, model, weak_model=None):
self.name = model
# Are all needed keys/params available?
res = litellm.validate_environment(model)
missing_keys = res.get("missing_keys")
keys_in_environment = res.get("keys_in_environment")
if missing_keys:
if validate_environment:
res = f"To use model {model}, please set these environment variables:"
for key in missing_keys:
res += f"- {key}"
raise ModelEnvironmentError(res)
elif not keys_in_environment:
# https://github.com/BerriAI/litellm/issues/3190
print(f"Unable to check environment variables for model {model}")
self.missing_keys = res.get("missing_keys")
self.keys_in_environment = res.get("keys_in_environment")
# Do we have the model_info?
try:
self.info = litellm.get_model_info(model)
except Exception:
if require_model_info:
raise NoModelInfo(model)
self.info = dict()
if self.info.get("max_input_tokens", 0) < 32 * 1024:
@ -199,7 +169,7 @@ class Model:
if weak_model is False:
self.weak_model_name = None
else:
self.get_weak_model(weak_model, require_model_info)
self.get_weak_model(weak_model)
def configure_model_settings(self, model):
for ms in MODEL_SETTINGS:
@ -210,7 +180,9 @@ class Model:
setattr(self, field.name, val)
return # <--
if "llama3" in model and "70b" in model:
model = model.lower()
if ("llama3" in model or "llama-3" in model) and "70b" in model:
self.edit_format = "diff"
self.use_repo_map = True
self.send_undo_reply = True
@ -235,7 +207,7 @@ class Model:
def __str__(self):
return self.name
def get_weak_model(self, provided_weak_model_name, require_model_info):
def get_weak_model(self, provided_weak_model_name):
# If weak_model_name is provided, override the model settings
if provided_weak_model_name:
self.weak_model_name = provided_weak_model_name
@ -251,7 +223,6 @@ class Model:
self.weak_model = Model(
self.weak_model_name,
weak_model=False,
require_model_info=require_model_info,
)
return self.weak_model
@ -313,19 +284,6 @@ class Model:
return img.size
def check_model_name(model):
res = f"Unknown model {model}"
possible_matches = fuzzy_match_models(model)
if possible_matches:
res += ", did you mean one of these?"
for match in possible_matches:
res += "\n- " + match
return res
def fuzzy_match_models(name):
models = litellm.model_cost.keys()