Merge remote-tracking branch 'origin/main'

This commit is contained in:
Paul Gauthier 2024-05-01 11:08:42 -07:00
commit 5ed9e8cb6d
4 changed files with 102 additions and 18 deletions

View file

@ -66,6 +66,7 @@ class Coder:
main_model=None, main_model=None,
edit_format=None, edit_format=None,
io=None, io=None,
from_coder=None,
**kwargs, **kwargs,
): ):
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
@ -76,15 +77,42 @@ class Coder:
if edit_format is None: if edit_format is None:
edit_format = main_model.edit_format edit_format = main_model.edit_format
if from_coder:
use_kwargs = dict(from_coder.original_kwargs) # copy orig kwargs
# If the edit format changes, we can't leave old ASSISTANT
# messages in the chat history. The old edit format will
# confused the new LLM. It may try and imitate it, disobeying
# the system prompt.
done_messages = from_coder.done_messages
if edit_format != from_coder.edit_format and done_messages:
done_messages = from_coder.summarizer.summarize_all(done_messages)
# Bring along context from the old Coder
update = dict(
fnames=from_coder.get_inchat_relative_files(),
done_messages=done_messages,
cur_messages=from_coder.cur_messages,
)
use_kwargs.update(update) # override to complete the switch
use_kwargs.update(kwargs) # override passed kwargs
kwargs = use_kwargs
if edit_format == "diff": if edit_format == "diff":
return EditBlockCoder(main_model, io, **kwargs) res = EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole": elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs) res = WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "udiff": elif edit_format == "udiff":
return UnifiedDiffCoder(main_model, io, **kwargs) res = UnifiedDiffCoder(main_model, io, **kwargs)
else: else:
raise ValueError(f"Unknown edit format {edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
res.original_kwargs = dict(kwargs)
return res
def get_announcements(self): def get_announcements(self):
lines = [] lines = []
lines.append(f"Aider v{__version__}") lines.append(f"Aider v{__version__}")
@ -153,6 +181,8 @@ class Coder:
use_git=True, use_git=True,
voice_language=None, voice_language=None,
aider_ignore_file=None, aider_ignore_file=None,
cur_messages=None,
done_messages=None,
): ):
if not fnames: if not fnames:
fnames = [] fnames = []
@ -166,8 +196,16 @@ class Coder:
self.verbose = verbose self.verbose = verbose
self.abs_fnames = set() self.abs_fnames = set()
self.cur_messages = []
self.done_messages = [] if cur_messages:
self.cur_messages = cur_messages
else:
self.cur_messages = []
if done_messages:
self.done_messages = done_messages
else:
self.done_messages = []
self.io = io self.io = io
self.stream = stream self.stream = stream

View file

@ -5,15 +5,23 @@ import sys
from pathlib import Path from pathlib import Path
import git import git
import litellm
import openai import openai
from prompt_toolkit.completion import Completion from prompt_toolkit.completion import Completion
from aider import prompts, voice from aider import models, prompts, voice
from aider.scrape import Scraper from aider.scrape import Scraper
from aider.utils import is_image_file from aider.utils import is_image_file
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
litellm.suppress_debug_info = True
class SwitchModel(Exception):
def __init__(self, model):
self.model = model
class Commands: class Commands:
voice = None voice = None
@ -28,6 +36,30 @@ class Commands:
self.voice_language = voice_language self.voice_language = voice_language
def cmd_model(self, args):
"Switch to a new LLM"
model_name = args.strip()
model = models.Model(model_name)
models.sanity_check_models(self.io, model)
raise SwitchModel(model)
def completions_model(self, partial):
models = litellm.model_cost.keys()
for model in models:
if partial.lower() in model.lower():
yield Completion(model, start_position=-len(partial))
def cmd_models(self, args):
"Search the list of available models"
args = args.strip()
if args:
models.print_matching_models(self.io, args)
else:
self.io.tool_output("Please provide a partial model name to search for.")
def cmd_web(self, args): def cmd_web(self, args):
"Use headless selenium to scrape a webpage and add the content to the chat" "Use headless selenium to scrape a webpage and add the content to the chat"
url = args.strip() url = args.strip()
@ -99,6 +131,8 @@ class Commands:
matching_commands, first_word, rest_inp = res matching_commands, first_word, rest_inp = res
if len(matching_commands) == 1: if len(matching_commands) == 1:
return self.do_run(matching_commands[0][1:], rest_inp) return self.do_run(matching_commands[0][1:], rest_inp)
elif first_word in matching_commands:
return self.do_run(first_word[1:], rest_inp)
elif len(matching_commands) > 1: elif len(matching_commands) > 1:
self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}") self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}")
else: else:

View file

@ -10,6 +10,7 @@ from streamlit.web import cli
from aider import __version__, models from aider import __version__, models
from aider.args import get_parser from aider.args import get_parser
from aider.coders import Coder from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput from aider.io import InputOutput
from aider.repo import GitRepo from aider.repo import GitRepo
from aider.versioncheck import check_version from aider.versioncheck import check_version
@ -270,17 +271,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 0 if not update_available else 1 return 0 if not update_available else 1
if args.models: if args.models:
matches = models.fuzzy_match_models(args.models) models.print_matching_models(io, args.models)
if matches:
io.tool_output(f'Models which match "{args.models}":')
for model in matches:
fq, m = model
if fq == m:
io.tool_output(f"- {m}")
else:
io.tool_output(f"- {m} ({fq})")
else:
io.tool_output(f'No models match "{args.models}".')
return 0 return 0
if args.git: if args.git:
@ -337,6 +328,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
voice_language=args.voice_language, voice_language=args.voice_language,
aider_ignore_file=args.aiderignore, aider_ignore_file=args.aiderignore,
) )
except ValueError as err: except ValueError as err:
io.tool_error(str(err)) io.tool_error(str(err))
return 1 return 1
@ -398,7 +390,13 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1 return 1
return return
coder.run() while True:
try:
coder.run()
return
except SwitchModel as switch:
coder = Coder.create(main_model=switch.model, io=io, from_coder=coder)
coder.show_announcements()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -409,6 +409,20 @@ def fuzzy_match_models(name):
return list(zip(matching_models, matching_models)) return list(zip(matching_models, matching_models))
def print_matching_models(io, search):
matches = fuzzy_match_models(search)
if matches:
io.tool_output(f'Models which match "{search}":')
for model in matches:
fq, m = model
if fq == m:
io.tool_output(f"- {m}")
else:
io.tool_output(f"- {m} ({fq})")
else:
io.tool_output(f'No models match "{search}".')
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python models.py <model_name>") print("Usage: python models.py <model_name>")