roughed in model switch

This commit is contained in:
Paul Gauthier 2024-04-30 16:22:13 -07:00
parent 050d35790e
commit 304856fc60
2 changed files with 55 additions and 23 deletions

View file

@ -10,6 +10,7 @@ from streamlit.web import cli
from aider import __version__, models
from aider.args import get_parser
from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput
from aider.repo import GitRepo
from aider.versioncheck import check_version
@ -305,28 +306,31 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
if args.show_model_warnings:
models.sanity_check_models(io, main_model)
coder_kwargs = dict(
main_model=main_model,
edit_format=args.edit_format,
io=io,
##
fnames=fnames,
git_dname=git_dname,
pretty=args.pretty,
show_diffs=args.show_diffs,
auto_commits=args.auto_commits,
dirty_commits=args.dirty_commits,
dry_run=args.dry_run,
map_tokens=args.map_tokens,
verbose=args.verbose,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
stream=args.stream,
use_git=args.git,
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
)
try:
coder = Coder.create(
main_model=main_model,
edit_format=args.edit_format,
io=io,
##
fnames=fnames,
git_dname=git_dname,
pretty=args.pretty,
show_diffs=args.show_diffs,
auto_commits=args.auto_commits,
dirty_commits=args.dirty_commits,
dry_run=args.dry_run,
map_tokens=args.map_tokens,
verbose=args.verbose,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
stream=args.stream,
use_git=args.git,
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
)
coder = Coder.create(**coder_kwargs)
except ValueError as err:
io.tool_error(str(err))
return 1
@ -388,7 +392,20 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1
return
coder.run()
while True:
try:
coder.run()
return
except SwitchModel as switch:
kwargs = dict(
main_model=switch.model,
edit_format=None,
fnames=coder.get_inchat_relative_files(),
)
coder_kwargs.update(kwargs)
coder = Coder.create(**coder_kwargs)
coder.show_announcements()
if __name__ == "__main__":