feat: Add --junior-model argument and plumb it into Model()

This commit is contained in:
Paul Gauthier (aider) 2024-09-25 11:03:15 -07:00
parent 11cdc4175f
commit 926b3c9240
3 changed files with 36 additions and 2 deletions

View file

@ -597,10 +597,11 @@ def get_model_info(model):
class Model(ModelSettings):
def __init__(self, model, weak_model=None):
def __init__(self, model, weak_model=None, junior_model=None):
self.name = model
self.max_chat_history_tokens = 1024
self.weak_model = None
self.junior_model = None
self.info = self.get_model_info(model)
@ -620,6 +621,11 @@ class Model(ModelSettings):
self.weak_model_name = None
else:
self.get_weak_model(weak_model)
if junior_model is False:
self.junior_model_name = None
else:
self.get_junior_model(junior_model)
def get_model_info(self, model):
return get_model_info(model)
@ -692,6 +698,26 @@ class Model(ModelSettings):
def commit_message_models(self):
return [self.weak_model, self]
def get_junior_model(self, provided_junior_model_name):
# If junior_model_name is provided, override the model settings
if provided_junior_model_name:
self.junior_model_name = provided_junior_model_name
if not self.junior_model_name:
self.junior_model = self
return
if self.junior_model_name == self.name:
self.junior_model = self
return
self.junior_model = Model(
self.junior_model_name,
weak_model=False,
junior_model=False,
)
return self.junior_model
def tokenizer(self, text):
return litellm.encode(model=self.name, text=text)