Added --require-model-info

This commit is contained in:
Paul Gauthier 2024-04-19 14:01:02 -07:00
parent aac110f078
commit f81b62dfea
4 changed files with 36 additions and 10 deletions

View file

@ -123,13 +123,15 @@ class Model:
max_chat_history_tokens = 1024
weak_model = None
def __init__(self, model, weak_model=None):
def __init__(self, model, weak_model=None, require_model_info=True):
self.name = model
try:
self.info = litellm.get_model_info(model)
except Exception:
raise NoModelInfo(model)
if require_model_info:
raise NoModelInfo(model)
self.info = dict()
if self.info.get("max_input_tokens", 0) < 32 * 1024:
self.max_chat_history_tokens = 1024
@ -137,7 +139,7 @@ class Model:
self.max_chat_history_tokens = 2 * 1024
self.configure_model_settings(model)
self.get_weak_model(weak_model)
self.get_weak_model(weak_model, require_model_info)
def configure_model_settings(self, model):
for ms in MODEL_SETTINGS:
@ -161,7 +163,7 @@ class Model:
def __str__(self):
return self.name
def get_weak_model(self, provided_weak_model_name):
def get_weak_model(self, provided_weak_model_name, require_model_info):
# If weak_model_name is provided, override the model settings
if provided_weak_model_name:
self.weak_model_name = provided_weak_model_name
@ -170,7 +172,11 @@ class Model:
self.weak_model = self
return
self.weak_model = Model(self.weak_model_name)
self.weak_model = Model(
self.weak_model_name,
weak_model=self.weak_model_name,
require_model_info=require_model_info,
)
return self.weak_model
def commit_message_models(self):