Use SwitchModel to revert from /help and /ask

This commit is contained in:
Paul Gauthier 2024-08-10 14:41:44 -07:00
parent 4f16eb0856
commit 156f11248f
2 changed files with 33 additions and 15 deletions

View file

@ -821,13 +821,23 @@ class Commands:
"""
user_msg += "\n".join(self.coder.get_announcements()) + "\n"
assistant_msg = coder.run(user_msg, preproc=False)
coder.run(user_msg, preproc=False)
self.coder.cur_messages += [
dict(role="user", content=user_msg),
dict(role="assistant", content=assistant_msg),
]
self.coder.total_cost = coder.total_cost
if self.coder.repo_map:
map_tokens = self.coder.repo_map.max_map_tokens
map_mul_no_files = self.coder.repo_map.map_mul_no_files
else:
map_tokens = 0
map_mul_no_files = 1
raise SwitchCoder(
edit_format=self.coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
map_tokens=map_tokens,
map_mul_no_files=map_mul_no_files,
show_announcements=False,
)
def clone(self):
return Commands(
@ -846,7 +856,7 @@ class Commands:
from aider.coders import Coder
chat_coder = Coder.create(
coder = Coder.create(
io=self.io,
from_coder=self.coder,
edit_format="ask",
@ -854,13 +864,14 @@ class Commands:
)
user_msg = args
assistant_msg = chat_coder.run(user_msg)
coder.run(user_msg)
self.coder.cur_messages += [
dict(role="user", content=user_msg),
dict(role="assistant", content=assistant_msg),
]
self.coder.total_cost = chat_coder.total_cost
raise SwitchCoder(
edit_format=self.coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=False,
)
def get_help_md(self):
"Show help about all commands in markdown"

View file

@ -620,8 +620,15 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.run()
return
except SwitchCoder as switch:
coder = Coder.create(io=io, from_coder=coder, **switch.kwargs)
coder.show_announcements()
kwargs = dict(io=io, from_coder=coder)
kwargs.update(switch.kwargs)
if "show_announcements" in kwargs:
del kwargs["show_announcements"]
coder = Coder.create(**kwargs)
if switch.kwargs.get("show_announcements") is not False:
coder.show_announcements()
def load_slow_imports():