roughed in model switch

This commit is contained in:
Paul Gauthier 2024-04-30 16:22:13 -07:00
parent 050d35790e
commit 304856fc60
2 changed files with 55 additions and 23 deletions

View file

@ -15,6 +15,11 @@ from aider.utils import is_image_file
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
class SwitchModel(Exception):
def __init__(self, model):
self.model = model
class Commands: class Commands:
voice = None voice = None
scraper = None scraper = None
@ -29,6 +34,14 @@ class Commands:
self.voice_language = voice_language self.voice_language = voice_language
def cmd_model(self, args): def cmd_model(self, args):
"Switch to a new LLM"
model_name = args.strip()
model = models.Model(model_name)
models.sanity_check_models(self.io, model)
raise SwitchModel(model)
def cmd_models(self, args):
"Search the list of available models" "Search the list of available models"
args = args.strip() args = args.strip()
@ -36,7 +49,7 @@ class Commands:
if args: if args:
models.print_matching_models(self.io, args) models.print_matching_models(self.io, args)
else: else:
self.io.tool_output("Use `/model <SEARCH>` to show models which match") self.io.tool_output("Please provide a partial model name to search for.")
def cmd_web(self, args): def cmd_web(self, args):
"Use headless selenium to scrape a webpage and add the content to the chat" "Use headless selenium to scrape a webpage and add the content to the chat"
@ -109,6 +122,8 @@ class Commands:
matching_commands, first_word, rest_inp = res matching_commands, first_word, rest_inp = res
if len(matching_commands) == 1: if len(matching_commands) == 1:
return self.do_run(matching_commands[0][1:], rest_inp) return self.do_run(matching_commands[0][1:], rest_inp)
elif first_word in matching_commands:
return self.do_run(first_word[1:], rest_inp)
elif len(matching_commands) > 1: elif len(matching_commands) > 1:
self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}") self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}")
else: else:

View file

@ -10,6 +10,7 @@ from streamlit.web import cli
from aider import __version__, models from aider import __version__, models
from aider.args import get_parser from aider.args import get_parser
from aider.coders import Coder from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput from aider.io import InputOutput
from aider.repo import GitRepo from aider.repo import GitRepo
from aider.versioncheck import check_version from aider.versioncheck import check_version
@ -305,28 +306,31 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
if args.show_model_warnings: if args.show_model_warnings:
models.sanity_check_models(io, main_model) models.sanity_check_models(io, main_model)
coder_kwargs = dict(
main_model=main_model,
edit_format=args.edit_format,
io=io,
##
fnames=fnames,
git_dname=git_dname,
pretty=args.pretty,
show_diffs=args.show_diffs,
auto_commits=args.auto_commits,
dirty_commits=args.dirty_commits,
dry_run=args.dry_run,
map_tokens=args.map_tokens,
verbose=args.verbose,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
stream=args.stream,
use_git=args.git,
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
)
try: try:
coder = Coder.create( coder = Coder.create(**coder_kwargs)
main_model=main_model,
edit_format=args.edit_format,
io=io,
##
fnames=fnames,
git_dname=git_dname,
pretty=args.pretty,
show_diffs=args.show_diffs,
auto_commits=args.auto_commits,
dirty_commits=args.dirty_commits,
dry_run=args.dry_run,
map_tokens=args.map_tokens,
verbose=args.verbose,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
stream=args.stream,
use_git=args.git,
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
)
except ValueError as err: except ValueError as err:
io.tool_error(str(err)) io.tool_error(str(err))
return 1 return 1
@ -388,7 +392,20 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1 return 1
return return
coder.run() while True:
try:
coder.run()
return
except SwitchModel as switch:
kwargs = dict(
main_model=switch.model,
edit_format=None,
fnames=coder.get_inchat_relative_files(),
)
coder_kwargs.update(kwargs)
coder = Coder.create(**coder_kwargs)
coder.show_announcements()
if __name__ == "__main__": if __name__ == "__main__":