mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-05 04:05:04 +00:00
feat: Add --junior-model argument and plumb it into Model()
This commit is contained in:
parent
11cdc4175f
commit
926b3c9240
3 changed files with 36 additions and 2 deletions
|
@ -190,6 +190,14 @@ def get_parser(default_config_files, git_root):
|
|||
" depends on --model)"
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--junior-model",
|
||||
metavar="JUNIOR_MODEL",
|
||||
default=None,
|
||||
help=(
|
||||
"Specify the model to use for junior tasks (default depends on --model)"
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--show-model-warnings",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
|
|
|
@ -533,7 +533,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
|
|||
if os.environ.get("ANTHROPIC_API_KEY"):
|
||||
args.model = "claude-3-5-sonnet-20240620"
|
||||
|
||||
main_model = models.Model(args.model, weak_model=args.weak_model)
|
||||
main_model = models.Model(args.model, weak_model=args.weak_model, junior_model=args.junior_model)
|
||||
|
||||
if args.verbose:
|
||||
io.tool_output("Model info:")
|
||||
|
|
|
@ -597,10 +597,11 @@ def get_model_info(model):
|
|||
|
||||
|
||||
class Model(ModelSettings):
|
||||
def __init__(self, model, weak_model=None):
|
||||
def __init__(self, model, weak_model=None, junior_model=None):
|
||||
self.name = model
|
||||
self.max_chat_history_tokens = 1024
|
||||
self.weak_model = None
|
||||
self.junior_model = None
|
||||
|
||||
self.info = self.get_model_info(model)
|
||||
|
||||
|
@ -620,6 +621,11 @@ class Model(ModelSettings):
|
|||
self.weak_model_name = None
|
||||
else:
|
||||
self.get_weak_model(weak_model)
|
||||
|
||||
if junior_model is False:
|
||||
self.junior_model_name = None
|
||||
else:
|
||||
self.get_junior_model(junior_model)
|
||||
|
||||
def get_model_info(self, model):
|
||||
return get_model_info(model)
|
||||
|
@ -692,6 +698,26 @@ class Model(ModelSettings):
|
|||
def commit_message_models(self):
    """Return a two-element list: ``self.weak_model`` followed by this model."""
    candidates = [self.weak_model]
    candidates.append(self)
    return candidates
|
||||
|
||||
def get_junior_model(self, provided_junior_model_name):
    """Resolve the junior model and store it on ``self.junior_model``.

    Args:
        provided_junior_model_name: Optional model name. When truthy it
            overrides ``self.junior_model_name``.

    Returns:
        The resolved junior model (the same object stored on
        ``self.junior_model``). This is ``self`` when no junior model name
        is set or when the name equals ``self.name``; otherwise a new
        ``Model`` built from ``self.junior_model_name``.
    """
    # An explicitly provided name overrides the model settings.
    if provided_junior_model_name:
        self.junior_model_name = provided_junior_model_name

    if not self.junior_model_name or self.junior_model_name == self.name:
        # No distinct junior model configured: reuse this instance rather
        # than constructing a duplicate.
        self.junior_model = self
    else:
        # junior_model=False makes the nested Model skip its own junior
        # resolution (see the `is False` check in __init__); weak_model=False
        # presumably does the same for the weak model -- TODO confirm.
        self.junior_model = Model(
            self.junior_model_name,
            weak_model=False,
            junior_model=False,
        )

    # Every path returns the resolved model; previously the fall-back paths
    # returned None while only the last path returned the model.
    return self.junior_model
|
||||
|
||||
def tokenizer(self, text):
    """Return ``litellm.encode`` applied to ``text`` with this model's name as ``model``."""
    encoded = litellm.encode(model=self.name, text=text)
    return encoded
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue