diff --git a/aider/commands.py b/aider/commands.py index 0f3853d66..d8a90be39 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -821,13 +821,23 @@ class Commands: """ user_msg += "\n".join(self.coder.get_announcements()) + "\n" - assistant_msg = coder.run(user_msg, preproc=False) + coder.run(user_msg, preproc=False) - self.coder.cur_messages += [ - dict(role="user", content=user_msg), - dict(role="assistant", content=assistant_msg), - ] - self.coder.total_cost = coder.total_cost + if self.coder.repo_map: + map_tokens = self.coder.repo_map.max_map_tokens + map_mul_no_files = self.coder.repo_map.map_mul_no_files + else: + map_tokens = 0 + map_mul_no_files = 1 + + raise SwitchCoder( + edit_format=self.coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + map_tokens=map_tokens, + map_mul_no_files=map_mul_no_files, + show_announcements=False, + ) def clone(self): return Commands( @@ -846,7 +856,7 @@ class Commands: from aider.coders import Coder - chat_coder = Coder.create( + coder = Coder.create( io=self.io, from_coder=self.coder, edit_format="ask", @@ -854,13 +864,14 @@ class Commands: ) user_msg = args - assistant_msg = chat_coder.run(user_msg) + coder.run(user_msg) - self.coder.cur_messages += [ - dict(role="user", content=user_msg), - dict(role="assistant", content=assistant_msg), - ] - self.coder.total_cost = chat_coder.total_cost + raise SwitchCoder( + edit_format=self.coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=False, + ) def get_help_md(self): "Show help about all commands in markdown" diff --git a/aider/main.py b/aider/main.py index 33994dc39..de50dbc70 100644 --- a/aider/main.py +++ b/aider/main.py @@ -620,8 +620,15 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F coder.run() return except SwitchCoder as switch: - coder = Coder.create(io=io, from_coder=coder, **switch.kwargs) - coder.show_announcements() + kwargs = dict(io=io, from_coder=coder) + kwargs.update(switch.kwargs) + if "show_announcements" in kwargs: + del kwargs["show_announcements"] + + coder = Coder.create(**kwargs) + + if switch.kwargs.get("show_announcements") is not False: + coder.show_announcements() def load_slow_imports():