feat: add junior_edit_format parameter to get_junior_model method

This commit is contained in:
Paul Gauthier 2024-09-25 11:37:41 -07:00 committed by Paul Gauthier (aider)
parent 0ded63cd31
commit 24c15db8d7

View file

@ -602,7 +602,6 @@ class Model(ModelSettings):
self.max_chat_history_tokens = 1024 self.max_chat_history_tokens = 1024
self.weak_model = None self.weak_model = None
self.junior_model = None self.junior_model = None
self.junior_edit_format = junior_edit_format
self.info = self.get_model_info(model) self.info = self.get_model_info(model)
@ -626,7 +625,7 @@ class Model(ModelSettings):
if junior_model is False: if junior_model is False:
self.junior_model_name = None self.junior_model_name = None
else: else:
self.get_junior_model(junior_model) self.get_junior_model(junior_model, junior_edit_format)
def get_model_info(self, model): def get_model_info(self, model):
return get_model_info(model) return get_model_info(model)
@ -699,10 +698,12 @@ class Model(ModelSettings):
def commit_message_models(self): def commit_message_models(self):
return [self.weak_model, self] return [self.weak_model, self]
def get_junior_model(self, provided_junior_model_name): def get_junior_model(self, provided_junior_model_name, junior_edit_format):
# If junior_model_name is provided, override the model settings # If junior_model_name is provided, override the model settings
if provided_junior_model_name: if provided_junior_model_name:
self.junior_model_name = provided_junior_model_name self.junior_model_name = provided_junior_model_name
if junior_edit_format:
self.junior_edit_format = junior_edit_format
if not self.junior_model_name: if not self.junior_model_name:
self.junior_model = self self.junior_model = self
@ -714,14 +715,9 @@ class Model(ModelSettings):
self.junior_model = Model( self.junior_model = Model(
self.junior_model_name, self.junior_model_name,
weak_model=False,
junior_model=False, junior_model=False,
) )
# Use the provided junior_edit_format if available, otherwise use the ModelSettings value
if self.junior_edit_format is None:
self.junior_edit_format = self.junior_edit_format or self.junior_model.edit_format
return self.junior_model return self.junior_model
def tokenizer(self, text): def tokenizer(self, text):