Merge pull request #581 from paul-gauthier/switch-model

New /model command to switch models, /models to search them
This commit is contained in:
paul-gauthier 2024-05-01 07:11:16 -07:00 committed by GitHub
commit 5ea1c5df0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 104 additions and 20 deletions

View file

@ -66,6 +66,7 @@ class Coder:
main_model=None, main_model=None,
edit_format=None, edit_format=None,
io=None, io=None,
from_coder=None,
**kwargs, **kwargs,
): ):
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
@ -76,15 +77,42 @@ class Coder:
if edit_format is None: if edit_format is None:
edit_format = main_model.edit_format edit_format = main_model.edit_format
if from_coder:
use_kwargs = dict(from_coder.original_kwargs) # copy orig kwargs
# If the edit format changes, we can't leave old ASSISTANT
# messages in the chat history. The old edit format will
# confuse the new LLM. It may try to imitate it, disobeying
# the system prompt.
done_messages = from_coder.done_messages
if edit_format != from_coder.edit_format and done_messages:
done_messages = from_coder.summarizer.summarize_all(done_messages)
# Bring along context from the old Coder
update = dict(
fnames=from_coder.get_inchat_relative_files(),
done_messages=done_messages,
cur_messages=from_coder.cur_messages,
)
use_kwargs.update(update) # override to complete the switch
use_kwargs.update(kwargs) # override passed kwargs
kwargs = use_kwargs
if edit_format == "diff": if edit_format == "diff":
return EditBlockCoder(main_model, io, **kwargs) res = EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole": elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs) res = WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "udiff": elif edit_format == "udiff":
return UnifiedDiffCoder(main_model, io, **kwargs) res = UnifiedDiffCoder(main_model, io, **kwargs)
else: else:
raise ValueError(f"Unknown edit format {edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
res.original_kwargs = dict(kwargs)
return res
def get_announcements(self): def get_announcements(self):
lines = [] lines = []
lines.append(f"Aider v{__version__}") lines.append(f"Aider v{__version__}")
@ -153,6 +181,8 @@ class Coder:
use_git=True, use_git=True,
voice_language=None, voice_language=None,
aider_ignore_file=None, aider_ignore_file=None,
cur_messages=None,
done_messages=None,
): ):
if not fnames: if not fnames:
fnames = [] fnames = []
@ -166,8 +196,16 @@ class Coder:
self.verbose = verbose self.verbose = verbose
self.abs_fnames = set() self.abs_fnames = set()
self.cur_messages = []
self.done_messages = [] if cur_messages:
self.cur_messages = cur_messages
else:
self.cur_messages = []
if done_messages:
self.done_messages = done_messages
else:
self.done_messages = []
self.io = io self.io = io
self.stream = stream self.stream = stream

View file

@ -18,14 +18,14 @@ Once you understand the request you MUST:
You MUST use this *file listing* format: You MUST use this *file listing* format:
path/to/filename.js path/to/filename.js
{fence[0]}javascript {fence[0]}
// entire file content ... // entire file content ...
// ... goes in between // ... goes in between
{fence[1]} {fence[1]}
Every *file listing* MUST use this format: Every *file listing* MUST use this format:
- First line: the filename with any originally provided path - First line: the filename with any originally provided path
- Second line: opening {fence[0]} including the code language - Second line: opening {fence[0]}
- ... entire content of the file ... - ... entire content of the file ...
- Final line: closing {fence[1]} - Final line: closing {fence[1]}

View file

@ -5,15 +5,23 @@ import sys
from pathlib import Path from pathlib import Path
import git import git
import litellm
import openai import openai
from prompt_toolkit.completion import Completion from prompt_toolkit.completion import Completion
from aider import prompts, voice from aider import models, prompts, voice
from aider.scrape import Scraper from aider.scrape import Scraper
from aider.utils import is_image_file from aider.utils import is_image_file
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
litellm.suppress_debug_info = True
class SwitchModel(Exception):
    """Control-flow exception raised to request a switch to a different LLM.

    Carries the target model so the caller (the main loop catches this
    and rebuilds the Coder around the new model).
    """

    def __init__(self, model):
        # Forward to Exception.__init__ so str(exc) and exc.args are
        # informative instead of empty (the original skipped this).
        super().__init__(model)
        # The model to switch to.
        self.model = model
class Commands: class Commands:
voice = None voice = None
@ -28,6 +36,30 @@ class Commands:
self.voice_language = voice_language self.voice_language = voice_language
def cmd_model(self, args):
    "Switch to a new LLM"

    # Build the requested model, warn about any configuration problems,
    # then unwind to the main loop via SwitchModel so it can rebuild
    # the coder around the new model.
    switch_to = models.Model(args.strip())
    models.sanity_check_models(self.io, switch_to)
    raise SwitchModel(switch_to)
def completions_model(self, partial):
    """Yield tab-completions for /model: every litellm-known model name
    that contains the typed text, case-insensitively."""
    # Hoist the lowercase conversion out of the loop.
    needle = partial.lower()
    # Use a distinct local name: the original bound `models`, shadowing
    # the imported aider.models module within this method.
    for model_name in litellm.model_cost.keys():
        if needle in model_name.lower():
            yield Completion(model_name, start_position=-len(partial))
def cmd_models(self, args):
    "Search the list of available models"

    search = args.strip()
    # Guard clause: nothing to search for, prompt the user and bail out.
    if not search:
        self.io.tool_output("Please provide a partial model name to search for.")
        return
    models.print_matching_models(self.io, search)
def cmd_web(self, args): def cmd_web(self, args):
"Use headless selenium to scrape a webpage and add the content to the chat" "Use headless selenium to scrape a webpage and add the content to the chat"
url = args.strip() url = args.strip()
@ -99,6 +131,8 @@ class Commands:
matching_commands, first_word, rest_inp = res matching_commands, first_word, rest_inp = res
if len(matching_commands) == 1: if len(matching_commands) == 1:
return self.do_run(matching_commands[0][1:], rest_inp) return self.do_run(matching_commands[0][1:], rest_inp)
elif first_word in matching_commands:
return self.do_run(first_word[1:], rest_inp)
elif len(matching_commands) > 1: elif len(matching_commands) > 1:
self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}") self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}")
else: else:

View file

@ -10,6 +10,7 @@ from streamlit.web import cli
from aider import __version__, models from aider import __version__, models
from aider.args import get_parser from aider.args import get_parser
from aider.coders import Coder from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput from aider.io import InputOutput
from aider.repo import GitRepo from aider.repo import GitRepo
from aider.versioncheck import check_version from aider.versioncheck import check_version
@ -270,17 +271,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 0 if not update_available else 1 return 0 if not update_available else 1
if args.models: if args.models:
matches = models.fuzzy_match_models(args.models) models.print_matching_models(io, args.models)
if matches:
io.tool_output(f'Models which match "{args.models}":')
for model in matches:
fq, m = model
if fq == m:
io.tool_output(f"- {m}")
else:
io.tool_output(f"- {m} ({fq})")
else:
io.tool_output(f'No models match "{args.models}".')
return 0 return 0
if args.git: if args.git:
@ -337,6 +328,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
voice_language=args.voice_language, voice_language=args.voice_language,
aider_ignore_file=args.aiderignore, aider_ignore_file=args.aiderignore,
) )
except ValueError as err: except ValueError as err:
io.tool_error(str(err)) io.tool_error(str(err))
return 1 return 1
@ -398,7 +390,13 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1 return 1
return return
coder.run() while True:
try:
coder.run()
return
except SwitchModel as switch:
coder = Coder.create(main_model=switch.model, io=io, from_coder=coder)
coder.show_announcements()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -409,6 +409,20 @@ def fuzzy_match_models(name):
return list(zip(matching_models, matching_models)) return list(zip(matching_models, matching_models))
def print_matching_models(io, search):
    """Print every model whose name fuzzily matches `search` to `io`."""
    matches = fuzzy_match_models(search)
    if not matches:
        io.tool_output(f'No models match "{search}".')
        return

    io.tool_output(f'Models which match "{search}":')
    for fq, m in matches:
        # Show the short name, plus the fully-qualified form when they differ.
        suffix = "" if fq == m else f" ({fq})"
        io.tool_output(f"- {m}{suffix}")
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python models.py <model_name>") print("Usage: python models.py <model_name>")