keep markdown stream open across multi response content

This commit is contained in:
Paul Gauthier 2024-06-28 15:10:20 -07:00
parent a3fe3c4dcf
commit d6467a8e30
2 changed files with 28 additions and 21 deletions

View file

@ -811,7 +811,13 @@ class Coder:
if self.verbose:
utils.show_messages(messages, functions=self.functions)
multi_response_content = ""
self.multi_response_content = ""
if self.show_pretty():
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
self.mdstream = MarkdownStream(mdargs=mdargs)
else:
self.mdstream = None
exhausted = False
interrupted = False
while True:
@ -835,19 +841,19 @@ class Coder:
break
# Use prefill to continue the response
multi_response_content += self.partial_response_content
self.multi_response_content += self.partial_response_content
if messages[-1]["role"] == "assistant":
messages[-1]["content"] = multi_response_content
messages[-1]["content"] = self.multi_response_content
else:
messages.append(dict(role="assistant", content=multi_response_content))
messages.append(dict(role="assistant", content=self.multi_response_content))
except Exception as err:
self.io.tool_error(f"Unexpected error: {err}")
traceback.print_exc()
return
if multi_response_content:
multi_response_content += self.partial_response_content
self.partial_response_content = multi_response_content
if self.multi_response_content:
self.multi_response_content += self.partial_response_content
self.partial_response_content = self.multi_response_content
if exhausted:
self.show_exhausted_error()
@ -1152,11 +1158,7 @@ class Coder:
raise FinishReasonLength()
def show_send_output_stream(self, completion):
if self.show_pretty():
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdstream = MarkdownStream(mdargs=mdargs)
else:
mdstream = None
finish_reason_length = False
try:
for chunk in completion:
@ -1167,6 +1169,8 @@ class Coder:
hasattr(chunk.choices[0], "finish_reason")
and chunk.choices[0].finish_reason == "length"
):
if self.main_model.can_prefill:
finish_reason_length = True
raise FinishReasonLength()
try:
@ -1188,24 +1192,27 @@ class Coder:
text = None
if self.show_pretty():
self.live_incremental_response(mdstream, False)
self.live_incremental_response(False)
elif text:
sys.stdout.write(text)
sys.stdout.flush()
yield text
finally:
if mdstream:
self.live_incremental_response(mdstream, True)
if self.show_pretty() and not finish_reason_length:
self.live_incremental_response(True)
def live_incremental_response(self, mdstream, final):
def live_incremental_response(self, final):
show_resp = self.render_incremental_response(final)
if not show_resp:
return
mdstream.update(show_resp, final=final)
self.mdstream.update(show_resp, final=final)
def render_incremental_response(self, final):
return self.partial_response_content
return self.get_multi_response_content()
def get_multi_response_content(self):
    """Return the full response text assembled so far.

    Joins the text accumulated in ``self.multi_response_content`` with the
    text of the response currently in progress,
    ``self.partial_response_content``, in that order.
    """
    accumulated = self.multi_response_content
    in_progress = self.partial_response_content
    return accumulated + in_progress
def get_rel_fname(self, fname):
    """Return ``fname`` as a path relative to ``self.root``."""
    base_dir = self.root
    return os.path.relpath(fname, start=base_dir)