From 926b3c9240d271fb8379951578d4bb831e3739d9 Mon Sep 17 00:00:00 2001 From: "Paul Gauthier (aider)" Date: Wed, 25 Sep 2024 11:03:15 -0700 Subject: [PATCH] feat: Add --junior-model argument and plumb it into Model() --- aider/args.py | 8 ++++++++ aider/main.py | 2 +- aider/models.py | 28 +++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/aider/args.py b/aider/args.py index cdcf8531d..94e814804 100644 --- a/aider/args.py +++ b/aider/args.py @@ -190,6 +190,14 @@ def get_parser(default_config_files, git_root): " depends on --model)" ), ) + group.add_argument( + "--junior-model", + metavar="JUNIOR_MODEL", + default=None, + help=( + "Specify the model to use for junior tasks (default depends on --model)" + ), + ) group.add_argument( "--show-model-warnings", action=argparse.BooleanOptionalAction, diff --git a/aider/main.py b/aider/main.py index e27948f43..a1d64b5b7 100644 --- a/aider/main.py +++ b/aider/main.py @@ -533,7 +533,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F if os.environ.get("ANTHROPIC_API_KEY"): args.model = "claude-3-5-sonnet-20240620" - main_model = models.Model(args.model, weak_model=args.weak_model) + main_model = models.Model(args.model, weak_model=args.weak_model, junior_model=args.junior_model) if args.verbose: io.tool_output("Model info:") diff --git a/aider/models.py b/aider/models.py index e4aafaee4..47b3eb436 100644 --- a/aider/models.py +++ b/aider/models.py @@ -597,10 +597,11 @@ def get_model_info(model): class Model(ModelSettings): - def __init__(self, model, weak_model=None): + def __init__(self, model, weak_model=None, junior_model=None): self.name = model self.max_chat_history_tokens = 1024 self.weak_model = None + self.junior_model = None self.info = self.get_model_info(model) @@ -620,6 +621,11 @@ class Model(ModelSettings): self.weak_model_name = None else: self.get_weak_model(weak_model) + + if junior_model is False: + self.junior_model_name = None + else: + self.get_junior_model(junior_model) def get_model_info(self, model): return get_model_info(model) @@ -692,6 +698,26 @@ class Model(ModelSettings): def commit_message_models(self): return [self.weak_model, self] + def get_junior_model(self, provided_junior_model_name): + # If junior_model_name is provided, override the model settings + if provided_junior_model_name: + self.junior_model_name = provided_junior_model_name + + if not self.junior_model_name: + self.junior_model = self + return + + if self.junior_model_name == self.name: + self.junior_model = self + return + + self.junior_model = Model( + self.junior_model_name, + weak_model=False, + junior_model=False, + ) + return self.junior_model + def tokenizer(self, text): return litellm.encode(model=self.name, text=text)