roughed in model switch

This commit is contained in:
Paul Gauthier 2024-04-30 16:22:13 -07:00
parent 050d35790e
commit 304856fc60
2 changed files with 55 additions and 23 deletions

View file

@ -15,6 +15,11 @@ from aider.utils import is_image_file
from .dump import dump # noqa: F401
class SwitchModel(Exception):
def __init__(self, model):
self.model = model
class Commands:
voice = None
scraper = None
@ -29,6 +34,14 @@ class Commands:
self.voice_language = voice_language
def cmd_model(self, args):
"Switch to a new LLM"
model_name = args.strip()
model = models.Model(model_name)
models.sanity_check_models(self.io, model)
raise SwitchModel(model)
def cmd_models(self, args):
"Search the list of available models"
args = args.strip()
@ -36,7 +49,7 @@ class Commands:
if args:
models.print_matching_models(self.io, args)
else:
self.io.tool_output("Use `/model <SEARCH>` to show models which match")
self.io.tool_output("Please provide a partial model name to search for.")
def cmd_web(self, args):
"Use headless selenium to scrape a webpage and add the content to the chat"
@ -109,6 +122,8 @@ class Commands:
matching_commands, first_word, rest_inp = res
if len(matching_commands) == 1:
return self.do_run(matching_commands[0][1:], rest_inp)
elif first_word in matching_commands:
return self.do_run(first_word[1:], rest_inp)
elif len(matching_commands) > 1:
self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}")
else:

View file

@ -10,6 +10,7 @@ from streamlit.web import cli
from aider import __version__, models
from aider.args import get_parser
from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput
from aider.repo import GitRepo
from aider.versioncheck import check_version
@ -305,8 +306,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
if args.show_model_warnings:
models.sanity_check_models(io, main_model)
try:
coder = Coder.create(
coder_kwargs = dict(
main_model=main_model,
edit_format=args.edit_format,
io=io,
@ -327,6 +327,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
)
try:
coder = Coder.create(**coder_kwargs)
except ValueError as err:
io.tool_error(str(err))
return 1
@ -388,7 +392,20 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1
return
while True:
try:
coder.run()
return
except SwitchModel as switch:
kwargs = dict(
main_model=switch.model,
edit_format=None,
fnames=coder.get_inchat_relative_files(),
)
coder_kwargs.update(kwargs)
coder = Coder.create(**coder_kwargs)
coder.show_announcements()
if __name__ == "__main__":