From 922559a15a3eb819680241689ce84c97326f0563 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Fri, 19 Apr 2024 11:17:33 -0700 Subject: [PATCH] Refactored error handling to display model name in case of unknown model. --- aider/main.py | 4 ++-- aider/models.py | 28 ++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/aider/main.py b/aider/main.py index dedff7791..93933f4ba 100644 --- a/aider/main.py +++ b/aider/main.py @@ -587,14 +587,14 @@ def main(argv=None, input=None, output=None, force_git_root=None): io.tool_error(f"- {key}") return 1 elif not res["keys_in_environment"]: - io.tool_error(f"Unknown model {args.model}.") + io.tool_error(models.check_model_name(args.model)) return 1 # Check in advance that we have model metadata try: main_model = models.Model(args.model, weak_model=args.weak_model) except models.NoModelInfo as err: - io.tool_error(f"Unknown model {err}.") + io.tool_error(str(err)) return 1 try: diff --git a/aider/models.py b/aider/models.py index f2d85664a..e363c8def 100644 --- a/aider/models.py +++ b/aider/models.py @@ -1,3 +1,5 @@ +import difflib +import sys import json import math from dataclasses import dataclass, fields @@ -17,7 +19,7 @@ class NoModelInfo(Exception): """ def __init__(self, model): - super().__init__(model) + super().__init__(check_model_name(model)) @dataclass @@ -234,24 +236,34 @@ class Model: return img.size -import difflib +def check_model_name(model): + res = f"Unknown model: {model}" + + possible_matches = fuzzy_match_models(model) + + if possible_matches: + res += '\n\nDid you mean one of these:\n' + for match in possible_matches: + res += '\n- ' + match + + return res def fuzzy_match_models(name): - models = litellm.most_cost.keys() - + models = litellm.model_cost.keys() + # Check for exact match first if name in models: return [name] - + # Check for models containing the name matching_models = [model for model in models if name in model] - + # If no matches found, check for slight misspellings if not matching_models: matching_models = difflib.get_close_matches(name, models, n=3, cutoff=0.8) - + return matching_models -import sys + def main(): if len(sys.argv) != 2: