diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py
index 2332d2e60..cfbd9576e 100755
--- a/aider/coders/base_coder.py
+++ b/aider/coders/base_coder.py
@@ -136,7 +136,11 @@ class Coder:
 
         self.main_model = main_model
 
-        self.io.tool_output(f"Model: {main_model.name} using {self.edit_format} edit format")
+        weak_model = main_model.weak_model
+        self.io.tool_output(
+            f"Models: {main_model.name} with {self.edit_format} edit format, weak model"
+            f" {weak_model.name}"
+        )
 
         self.show_diffs = show_diffs
 
@@ -212,7 +216,7 @@ class Coder:
                 self.io.tool_output(f"Added {fname} to the chat.")
 
         self.summarizer = ChatSummary(
-            self.main_model.get_weak_model(),
+            self.main_model.weak_model,
             self.main_model.max_chat_history_tokens,
         )
 
diff --git a/aider/history.py b/aider/history.py
index 9ed866554..3acd368b5 100644
--- a/aider/history.py
+++ b/aider/history.py
@@ -124,7 +124,7 @@ def main():
             assistant.append(line)
 
-    summarizer = ChatSummary(models.Model.get_weak_model())
+    summarizer = ChatSummary(models.Model(models.DEFAULT_WEAK_MODEL_NAME))
 
     summary = summarizer.summarize(messages[-40:])
     dump(summary)
 
diff --git a/aider/main.py b/aider/main.py
index 61c76a8d6..dedff7791 100644
--- a/aider/main.py
+++ b/aider/main.py
@@ -196,7 +196,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         const=default_4_turbo_model,
         help=f"Use {default_4_turbo_model} model for the main chat",
     )
-    default_3_model_name = "gpt-3.5-turbo-0125"
+    default_3_model_name = "gpt-3.5-turbo"
     core_group.add_argument(
         "--35turbo",
         "--35-turbo",
@@ -252,6 +252,15 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         default=None,
         help="Specify what edit format GPT should use (default depends on model)",
     )
+    core_group.add_argument(
+        "--weak-model",
+        metavar="WEAK_MODEL",
+        default=None,
+        help=(
+            "Specify the model to use for commit messages and chat history summarization (default"
+            " depends on --model)"
+        ),
+    )
     model_group.add_argument(
         "--map-tokens",
         type=int,
@@ -583,10 +592,9 @@ def main(argv=None, input=None, output=None, force_git_root=None):
 
     # Check in advance that we have model metadata
     try:
-        main_model = models.Model(args.model)
+        main_model = models.Model(args.model, weak_model=args.weak_model)
     except models.NoModelInfo as err:
-        io.tool_error(f"Unknown model {args.model}.")
-        io.tool_error(str(err))
+        io.tool_error(f"Unknown model {err}.")
         return 1
 
     try:
diff --git a/aider/models.py b/aider/models.py
index bd48d486b..bf79ba7e3 100644
--- a/aider/models.py
+++ b/aider/models.py
@@ -1,7 +1,6 @@
 import json
 import math
 from dataclasses import dataclass, fields
-from typing import Optional
 
 import litellm
 from PIL import Image
@@ -9,6 +8,7 @@ from PIL import Image
 from aider.dump import dump  # noqa: F401
 
 DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
+DEFAULT_WEAK_MODEL_NAME = "gpt-3.5-turbo"
 
 
 class NoModelInfo(Exception):
@@ -16,15 +16,15 @@
     Exception raised when model information cannot be retrieved.
""" - def __init__(self, message: Optional[str] = None): - super().__init__(message or "No model information available.") + def __init__(self, model): + super().__init__(model) @dataclass class ModelSettings: name: str edit_format: str - weak_model_name: str = "gpt-3.5-turbo-0125" + weak_model_name: str = DEFAULT_WEAK_MODEL_NAME use_repo_map: bool = False send_undo_reply: bool = False accepts_images: bool = False @@ -112,23 +112,22 @@ MODEL_SETTINGS = [ class Model: name = None - weak_model_name = "gpt-3.5-turbo-0125" - edit_format = "whole" use_repo_map = False send_undo_reply = False accepts_images = False + weak_model_name = DEFAULT_WEAK_MODEL_NAME max_chat_history_tokens = 1024 weak_model = None - def __init__(self, model): + def __init__(self, model, weak_model=None): self.name = model try: self.info = litellm.get_model_info(model) - except Exception as err: - raise NoModelInfo(str(err)) + except Exception: + raise NoModelInfo(model) if self.info.get("max_input_tokens", 0) < 32 * 1024: self.max_chat_history_tokens = 1024 @@ -136,6 +135,7 @@ class Model: self.max_chat_history_tokens = 2 * 1024 self.configure_model_settings(model) + self.get_weak_model(weak_model) def configure_model_settings(self, model): for ms in MODEL_SETTINGS: @@ -159,13 +159,25 @@ class Model: def __str__(self): return self.name - def get_weak_model(self): - if not self.weak_model: - self.weak_model = Model(self.weak_model_name) + def get_weak_model(self, provided_weak_model_name): + # If weak_model_name is provided, override the model settings + if provided_weak_model_name: + self.weak_model_name = provided_weak_model_name + + if self.weak_model_name == self.name: + self.weak_model = self + return + + try: + self.info = litellm.get_model_info(self.weak_model_name) + except Exception: + raise NoModelInfo(self.weak_model_name) + + self.weak_model = Model(self.weak_model_name) return self.weak_model def commit_message_models(self): - return [self.get_weak_model()] + return [self.weak_model] def tokenizer(self, text): return litellm.encode(model=self.name, text=text)