From 304856fc600d3774528481b6f588a7a6bc334440 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Tue, 30 Apr 2024 16:22:13 -0700 Subject: [PATCH] roughed in model switch --- aider/commands.py | 17 ++++++++++++- aider/main.py | 61 ++++++++++++++++++++++++++++++----------------- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/aider/commands.py b/aider/commands.py index ed81c54ca..958902977 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -15,6 +15,11 @@ from aider.utils import is_image_file from .dump import dump # noqa: F401 +class SwitchModel(Exception): + def __init__(self, model): + self.model = model + + class Commands: voice = None scraper = None @@ -29,6 +34,14 @@ class Commands: self.voice_language = voice_language def cmd_model(self, args): + "Switch to a new LLM" + + model_name = args.strip() + model = models.Model(model_name) + models.sanity_check_models(self.io, model) + raise SwitchModel(model) + + def cmd_models(self, args): "Search the list of available models" args = args.strip() @@ -36,7 +49,7 @@ class Commands: if args: models.print_matching_models(self.io, args) else: - self.io.tool_output("Use `/model ` to show models which match") + self.io.tool_output("Please provide a partial model name to search for.") def cmd_web(self, args): "Use headless selenium to scrape a webpage and add the content to the chat" @@ -109,6 +122,8 @@ class Commands: matching_commands, first_word, rest_inp = res if len(matching_commands) == 1: return self.do_run(matching_commands[0][1:], rest_inp) + elif first_word in matching_commands: + return self.do_run(first_word[1:], rest_inp) elif len(matching_commands) > 1: self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}") else: diff --git a/aider/main.py b/aider/main.py index 11e0143cb..fd6195f5b 100644 --- a/aider/main.py +++ b/aider/main.py @@ -10,6 +10,7 @@ from streamlit.web import cli from aider import __version__, models from aider.args import get_parser from aider.coders import Coder +from aider.commands import SwitchModel from aider.io import InputOutput from aider.repo import GitRepo from aider.versioncheck import check_version @@ -305,28 +306,31 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F if args.show_model_warnings: models.sanity_check_models(io, main_model) + coder_kwargs = dict( + main_model=main_model, + edit_format=args.edit_format, + io=io, + ## + fnames=fnames, + git_dname=git_dname, + pretty=args.pretty, + show_diffs=args.show_diffs, + auto_commits=args.auto_commits, + dirty_commits=args.dirty_commits, + dry_run=args.dry_run, + map_tokens=args.map_tokens, + verbose=args.verbose, + assistant_output_color=args.assistant_output_color, + code_theme=args.code_theme, + stream=args.stream, + use_git=args.git, + voice_language=args.voice_language, + aider_ignore_file=args.aiderignore, + ) + try: - coder = Coder.create( - main_model=main_model, - edit_format=args.edit_format, - io=io, - ## - fnames=fnames, - git_dname=git_dname, - pretty=args.pretty, - show_diffs=args.show_diffs, - auto_commits=args.auto_commits, - dirty_commits=args.dirty_commits, - dry_run=args.dry_run, - map_tokens=args.map_tokens, - verbose=args.verbose, - assistant_output_color=args.assistant_output_color, - code_theme=args.code_theme, - stream=args.stream, - use_git=args.git, - voice_language=args.voice_language, - aider_ignore_file=args.aiderignore, - ) + coder = Coder.create(**coder_kwargs) + except ValueError as err: io.tool_error(str(err)) return 1 @@ -388,7 +392,20 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F return 1 return - coder.run() + while True: + try: + coder.run() + return + except SwitchModel as switch: + kwargs = dict( + main_model=switch.model, + edit_format=None, + fnames=coder.get_inchat_relative_files(), + ) + coder_kwargs.update(kwargs) + + coder = Coder.create(**coder_kwargs) + coder.show_announcements() if __name__ == "__main__":