refactor: Extract get_model_info into a standalone function

This commit is contained in:
Paul Gauthier (aider) 2024-08-25 08:19:15 -07:00
parent 539a657624
commit 5473d99e13

View file

@ -427,33 +427,7 @@ MODEL_SETTINGS = [
]
class Model(ModelSettings):
def __init__(self, model, weak_model=None):
self.name = model
self.max_chat_history_tokens = 1024
self.weak_model = None
self.info = self.get_model_info(model)
dump(self.info)
# Are all needed keys/params available?
res = self.validate_environment()
self.missing_keys = res.get("missing_keys")
self.keys_in_environment = res.get("keys_in_environment")
max_input_tokens = self.info.get("max_input_tokens") or 0
if max_input_tokens < 32 * 1024:
self.max_chat_history_tokens = 1024
else:
self.max_chat_history_tokens = 2 * 1024
self.configure_model_settings(model)
if weak_model is False:
self.weak_model_name = None
else:
self.get_weak_model(weak_model)
def get_model_info(self, model):
def get_model_info(model):
if litellm._lazy_module:
# Do it the slow way...
try:
@ -511,6 +485,36 @@ class Model(ModelSettings):
except Exception:
return dict()
class Model(ModelSettings):
def __init__(self, model, weak_model=None):
self.name = model
self.max_chat_history_tokens = 1024
self.weak_model = None
self.info = self.get_model_info(model)
dump(self.info)
# Are all needed keys/params available?
res = self.validate_environment()
self.missing_keys = res.get("missing_keys")
self.keys_in_environment = res.get("keys_in_environment")
max_input_tokens = self.info.get("max_input_tokens") or 0
if max_input_tokens < 32 * 1024:
self.max_chat_history_tokens = 1024
else:
self.max_chat_history_tokens = 2 * 1024
self.configure_model_settings(model)
if weak_model is False:
self.weak_model_name = None
else:
self.get_weak_model(weak_model)
def get_model_info(self, model):
return get_model_info(model)
def configure_model_settings(self, model):
for ms in MODEL_SETTINGS:
# direct match, or match "provider/<model>"