From d6467a8e30388d5f0a2fa9a79a9af805950a20fb Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Fri, 28 Jun 2024 15:10:20 -0700 Subject: [PATCH] keep markdown stream open across multi response content --- aider/coders/base_coder.py | 43 +++++++++++++++++++-------------- aider/coders/wholefile_coder.py | 6 ++--- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index bf82b9ee6..11697d071 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -811,7 +811,13 @@ class Coder: if self.verbose: utils.show_messages(messages, functions=self.functions) - multi_response_content = "" + self.multi_response_content = "" + if self.show_pretty(): + mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme) + self.mdstream = MarkdownStream(mdargs=mdargs) + else: + self.mdstream = None + exhausted = False interrupted = False while True: @@ -835,19 +841,19 @@ class Coder: break # Use prefill to continue the response - multi_response_content += self.partial_response_content + self.multi_response_content += self.partial_response_content if messages[-1]["role"] == "assistant": - messages[-1]["content"] = multi_response_content + messages[-1]["content"] = self.multi_response_content else: - messages.append(dict(role="assistant", content=multi_response_content)) + messages.append(dict(role="assistant", content=self.multi_response_content)) except Exception as err: self.io.tool_error(f"Unexpected error: {err}") traceback.print_exc() return - if multi_response_content: - multi_response_content += self.partial_response_content - self.partial_response_content = multi_response_content + if self.multi_response_content: + self.multi_response_content += self.partial_response_content + self.partial_response_content = self.multi_response_content if exhausted: self.show_exhausted_error() @@ -1152,11 +1158,7 @@ class Coder: raise FinishReasonLength() def show_send_output_stream(self, completion): - if self.show_pretty(): - mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme) - mdstream = MarkdownStream(mdargs=mdargs) - else: - mdstream = None + finish_reason_length = False try: for chunk in completion: @@ -1167,6 +1169,8 @@ class Coder: hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "length" ): + if self.main_model.can_prefill: + finish_reason_length = True raise FinishReasonLength() try: @@ -1188,24 +1192,27 @@ class Coder: text = None if self.show_pretty(): - self.live_incremental_response(mdstream, False) + self.live_incremental_response(False) elif text: sys.stdout.write(text) sys.stdout.flush() yield text finally: - if mdstream: - self.live_incremental_response(mdstream, True) + if self.show_pretty() and not finish_reason_length: + self.live_incremental_response(True) - def live_incremental_response(self, mdstream, final): + def live_incremental_response(self, final): show_resp = self.render_incremental_response(final) if not show_resp: return - mdstream.update(show_resp, final=final) + self.mdstream.update(show_resp, final=final) def render_incremental_response(self, final): - return self.partial_response_content + return self.get_multi_response_content() + + def get_multi_response_content(self): + return self.multi_response_content + self.partial_response_content def get_rel_fname(self, fname): return os.path.relpath(fname, self.root) diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index b9ecc60b5..a4420c462 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -1,6 +1,6 @@ +from pathlib import Path from aider import diffs -from pathlib import Path from ..dump import dump # noqa: F401 from .base_coder import Coder @@ -26,10 +26,10 @@ class WholeFileCoder(Coder): try: return self.get_edits(mode="diff") except ValueError: - return self.partial_response_content + return self.get_multi_response_content() def get_edits(self, mode="update"): - content = self.partial_response_content + content = self.get_multi_response_content() chat_files = self.get_inchat_relative_files()