Use SwitchModel to revert from /help and /ask

This commit is contained in:
Paul Gauthier 2024-08-10 14:41:44 -07:00
parent 4f16eb0856
commit 156f11248f
2 changed files with 33 additions and 15 deletions

View file

@ -821,13 +821,23 @@ class Commands:
""" """
user_msg += "\n".join(self.coder.get_announcements()) + "\n" user_msg += "\n".join(self.coder.get_announcements()) + "\n"
assistant_msg = coder.run(user_msg, preproc=False) coder.run(user_msg, preproc=False)
self.coder.cur_messages += [ if self.coder.repo_map:
dict(role="user", content=user_msg), map_tokens = self.coder.repo_map.max_map_tokens
dict(role="assistant", content=assistant_msg), map_mul_no_files = self.coder.repo_map.map_mul_no_files
] else:
self.coder.total_cost = coder.total_cost map_tokens = 0
map_mul_no_files = 1
raise SwitchCoder(
edit_format=self.coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
map_tokens=map_tokens,
map_mul_no_files=map_mul_no_files,
show_announcements=False,
)
def clone(self): def clone(self):
return Commands( return Commands(
@ -846,7 +856,7 @@ class Commands:
from aider.coders import Coder from aider.coders import Coder
chat_coder = Coder.create( coder = Coder.create(
io=self.io, io=self.io,
from_coder=self.coder, from_coder=self.coder,
edit_format="ask", edit_format="ask",
@ -854,13 +864,14 @@ class Commands:
) )
user_msg = args user_msg = args
assistant_msg = chat_coder.run(user_msg) coder.run(user_msg)
self.coder.cur_messages += [ raise SwitchCoder(
dict(role="user", content=user_msg), edit_format=self.coder.edit_format,
dict(role="assistant", content=assistant_msg), summarize_from_coder=False,
] from_coder=coder,
self.coder.total_cost = chat_coder.total_cost show_announcements=False,
)
def get_help_md(self): def get_help_md(self):
"Show help about all commands in markdown" "Show help about all commands in markdown"

View file

@ -620,8 +620,15 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.run() coder.run()
return return
except SwitchCoder as switch: except SwitchCoder as switch:
coder = Coder.create(io=io, from_coder=coder, **switch.kwargs) kwargs = dict(io=io, from_coder=coder)
coder.show_announcements() kwargs.update(switch.kwargs)
if "show_announcements" in kwargs:
del kwargs["show_announcements"]
coder = Coder.create(**kwargs)
if switch.kwargs.get("show_announcements") is not False:
coder.show_announcements()
def load_slow_imports(): def load_slow_imports():