keep markdown stream open across multi response content

This commit is contained in:
Paul Gauthier 2024-06-28 15:10:20 -07:00
parent a3fe3c4dcf
commit d6467a8e30
2 changed files with 28 additions and 21 deletions

View file

@@ -811,7 +811,13 @@ class Coder:
if self.verbose: if self.verbose:
utils.show_messages(messages, functions=self.functions) utils.show_messages(messages, functions=self.functions)
multi_response_content = "" self.multi_response_content = ""
if self.show_pretty():
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
self.mdstream = MarkdownStream(mdargs=mdargs)
else:
self.mdstream = None
exhausted = False exhausted = False
interrupted = False interrupted = False
while True: while True:
@@ -835,19 +841,19 @@ class Coder:
break break
# Use prefill to continue the response # Use prefill to continue the response
multi_response_content += self.partial_response_content self.multi_response_content += self.partial_response_content
if messages[-1]["role"] == "assistant": if messages[-1]["role"] == "assistant":
messages[-1]["content"] = multi_response_content messages[-1]["content"] = self.multi_response_content
else: else:
messages.append(dict(role="assistant", content=multi_response_content)) messages.append(dict(role="assistant", content=self.multi_response_content))
except Exception as err: except Exception as err:
self.io.tool_error(f"Unexpected error: {err}") self.io.tool_error(f"Unexpected error: {err}")
traceback.print_exc() traceback.print_exc()
return return
if multi_response_content: if self.multi_response_content:
multi_response_content += self.partial_response_content self.multi_response_content += self.partial_response_content
self.partial_response_content = multi_response_content self.partial_response_content = self.multi_response_content
if exhausted: if exhausted:
self.show_exhausted_error() self.show_exhausted_error()
@@ -1152,11 +1158,7 @@ class Coder:
raise FinishReasonLength() raise FinishReasonLength()
def show_send_output_stream(self, completion): def show_send_output_stream(self, completion):
if self.show_pretty(): finish_reason_length = False
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdstream = MarkdownStream(mdargs=mdargs)
else:
mdstream = None
try: try:
for chunk in completion: for chunk in completion:
@@ -1167,6 +1169,8 @@ class Coder:
hasattr(chunk.choices[0], "finish_reason") hasattr(chunk.choices[0], "finish_reason")
and chunk.choices[0].finish_reason == "length" and chunk.choices[0].finish_reason == "length"
): ):
if self.main_model.can_prefill:
finish_reason_length = True
raise FinishReasonLength() raise FinishReasonLength()
try: try:
@@ -1188,24 +1192,27 @@ class Coder:
text = None text = None
if self.show_pretty(): if self.show_pretty():
self.live_incremental_response(mdstream, False) self.live_incremental_response(False)
elif text: elif text:
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
yield text yield text
finally: finally:
if mdstream: if self.show_pretty() and not finish_reason_length:
self.live_incremental_response(mdstream, True) self.live_incremental_response(True)
def live_incremental_response(self, mdstream, final): def live_incremental_response(self, final):
show_resp = self.render_incremental_response(final) show_resp = self.render_incremental_response(final)
if not show_resp: if not show_resp:
return return
mdstream.update(show_resp, final=final) self.mdstream.update(show_resp, final=final)
def render_incremental_response(self, final): def render_incremental_response(self, final):
return self.partial_response_content return self.get_multi_response_content()
def get_multi_response_content(self):
return self.multi_response_content + self.partial_response_content
def get_rel_fname(self, fname): def get_rel_fname(self, fname):
return os.path.relpath(fname, self.root) return os.path.relpath(fname, self.root)

View file

@@ -1,6 +1,6 @@
from pathlib import Path
from aider import diffs from aider import diffs
from pathlib import Path
from ..dump import dump # noqa: F401 from ..dump import dump # noqa: F401
from .base_coder import Coder from .base_coder import Coder
@@ -26,10 +26,10 @@ class WholeFileCoder(Coder):
try: try:
return self.get_edits(mode="diff") return self.get_edits(mode="diff")
except ValueError: except ValueError:
return self.partial_response_content return self.get_multi_response_content()
def get_edits(self, mode="update"): def get_edits(self, mode="update"):
content = self.partial_response_content content = self.get_multi_response_content()
chat_files = self.get_inchat_relative_files() chat_files = self.get_inchat_relative_files()