From f81b62dfea83cd253fa2168cc3bf689c28267073 Mon Sep 17 00:00:00 2001
From: Paul Gauthier
Date: Fri, 19 Apr 2024 14:01:02 -0700
Subject: [PATCH] Added --require-model-info

---
 aider/commands.py | 10 ++++++++--
 aider/main.py     | 12 ++++++++++--
 aider/models.py   | 16 +++++++++++-----
 aider/repo.py     |  8 +++++++-
 4 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/aider/commands.py b/aider/commands.py
index 737050962..82fcd1de8 100644
--- a/aider/commands.py
+++ b/aider/commands.py
@@ -197,7 +197,10 @@ class Commands:
         self.io.tool_output("=" * (width + cost_width + 1))
         self.io.tool_output(f"${total_cost:7.4f} {fmt(total)} tokens total")
 
-        limit = self.coder.main_model.info.get("max_input_tokens")
+        limit = self.coder.main_model.info.get("max_input_tokens", 0)
+        if not limit:
+            return
+
         remaining = limit - total
         if remaining > 1024:
             self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
@@ -207,7 +210,10 @@ class Commands:
                 " /clear to make space)"
             )
         else:
-            self.io.tool_error(f"{cost_pad}{fmt(remaining)} tokens remaining, window exhausted!")
+            self.io.tool_error(
+                f"{cost_pad}{fmt(remaining)} tokens remaining, window exhausted (use /drop or"
+                " /clear to make space)"
+            )
         self.io.tool_output(f"{cost_pad}{fmt(limit)} tokens max context window size")
 
     def cmd_undo(self, args):
diff --git a/aider/main.py b/aider/main.py
index 1bca3eae8..d9efd4274 100644
--- a/aider/main.py
+++ b/aider/main.py
@@ -276,6 +276,12 @@ def main(argv=None, input=None, output=None, force_git_root=None):
             " depends on --model)"
         ),
     )
+    model_group.add_argument(
+        "--require-model-info",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Only work with models that have meta-data available (default: True)",
+    )
     model_group.add_argument(
         "--map-tokens",
         type=int,
@@ -606,13 +612,15 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         for key in missing_keys:
             io.tool_error(f"- {key}")
         return 1
-    elif not res["keys_in_environment"]:
+    elif not res["keys_in_environment"] and args.require_model_info:
         io.tool_error(models.check_model_name(args.model))
         return 1
 
     # Check in advance that we have model metadata
     try:
-        main_model = models.Model(args.model, weak_model=args.weak_model)
+        main_model = models.Model(
+            args.model, weak_model=args.weak_model, require_model_info=args.require_model_info
+        )
     except models.NoModelInfo as err:
         io.tool_error(str(err))
         return 1
diff --git a/aider/models.py b/aider/models.py
index 44683dad6..406c77cf7 100644
--- a/aider/models.py
+++ b/aider/models.py
@@ -123,13 +123,15 @@ class Model:
     max_chat_history_tokens = 1024
     weak_model = None
 
-    def __init__(self, model, weak_model=None):
+    def __init__(self, model, weak_model=None, require_model_info=True):
         self.name = model
 
         try:
             self.info = litellm.get_model_info(model)
         except Exception:
-            raise NoModelInfo(model)
+            if require_model_info:
+                raise NoModelInfo(model)
+            self.info = dict()
 
         if self.info.get("max_input_tokens", 0) < 32 * 1024:
             self.max_chat_history_tokens = 1024
@@ -137,7 +139,7 @@ class Model:
             self.max_chat_history_tokens = 2 * 1024
 
         self.configure_model_settings(model)
-        self.get_weak_model(weak_model)
+        self.get_weak_model(weak_model, require_model_info)
 
     def configure_model_settings(self, model):
         for ms in MODEL_SETTINGS:
@@ -161,7 +163,7 @@ class Model:
     def __str__(self):
         return self.name
 
-    def get_weak_model(self, provided_weak_model_name):
+    def get_weak_model(self, provided_weak_model_name, require_model_info):
         # If weak_model_name is provided, override the model settings
         if provided_weak_model_name:
             self.weak_model_name = provided_weak_model_name
@@ -170,7 +172,11 @@ class Model:
             self.weak_model = self
             return
 
-        self.weak_model = Model(self.weak_model_name)
+        self.weak_model = Model(
+            self.weak_model_name,
+            weak_model=self.weak_model_name,
+            require_model_info=require_model_info,
+        )
         return self.weak_model
 
     def commit_message_models(self):
diff --git a/aider/repo.py b/aider/repo.py
index aaa4d4fb9..bc76e6d0d 100644
--- a/aider/repo.py
+++ b/aider/repo.py
@@ -22,7 +22,13 @@ class GitRepo:
         if models:
             self.models = models
         else:
-            self.models = [Model(DEFAULT_WEAK_MODEL_NAME)]
+            self.models = [
+                Model(
+                    DEFAULT_WEAK_MODEL_NAME,
+                    weak_model=DEFAULT_WEAK_MODEL_NAME,
+                    require_model_info=False,
+                )
+            ]
 
         if git_dname:
             check_fnames = [git_dname]
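
For reference, a minimal sketch (not part of the patch itself) of the behavior the new `require_model_info` knob enables; the model name below is a placeholder for something litellm has no metadata for:

```python
# Sketch only: assumes aider's models module as patched above.
# "unknown-model" is a hypothetical name with no litellm metadata.
from aider import models

try:
    # Default, strict behavior: missing metadata raises NoModelInfo.
    models.Model("unknown-model", require_model_info=True)
except models.NoModelInfo as err:
    print(f"no model info: {err}")

# Relaxed behavior added by this patch: fall back to an empty info dict,
# so /tokens skips its context-window math instead of failing.
relaxed = models.Model("unknown-model", require_model_info=False)
print(relaxed.info.get("max_input_tokens", 0))  # 0 -> no remaining-tokens report
```

On the command line, `argparse.BooleanOptionalAction` also generates the complementary `--no-require-model-info` flag, which is how the relaxed path would be selected when running aider.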