better live output of multi file edits

This commit is contained in:
Paul Gauthier 2023-06-21 21:02:07 -07:00
parent f187cf0346
commit f9086e66d3
3 changed files with 27 additions and 14 deletions

View file

@ -527,19 +527,23 @@ class Coder:
continue
if self.pretty:
show_resp = self.modify_incremental_response()
if show_resp:
md = Markdown(
show_resp, style=self.assistant_output_color, code_theme="default"
)
live.update(md)
self.live_incremental_response(live, False)
else:
sys.stdout.write(text)
sys.stdout.flush()
finally:
self.live_incremental_response(live, True)
if live:
live.stop()
def live_incremental_response(self, live, final):
show_resp = self.modify_incremental_response(final)
if not show_resp:
return
md = Markdown(show_resp, style=self.assistant_output_color, code_theme="default")
live.update(md)
def modify_incremental_response(self):
return self.partial_response_content

View file

@ -57,7 +57,7 @@ class FunctionCoder(Coder):
else:
self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self):
def modify_incremental_response(self, final=False):
args = self.parse_partial_args()
if not args:
@ -70,7 +70,7 @@ class FunctionCoder(Coder):
if explanation:
res += f"{explanation}\n\n"
for file_upd in files:
for i, file_upd in enumerate(files):
path = file_upd.get("path")
if not path:
continue
@ -79,11 +79,13 @@ class FunctionCoder(Coder):
continue
res += path + ":\n"
res += self.live_diffs(path, content)
this_final = (i < len(files) - 1) or final
res += self.live_diffs(path, content, this_final)
return res
def live_diffs(self, fname, content):
def live_diffs(self, fname, content, final):
lines = content.splitlines(keepends=True)
# ending an existing block
@ -95,7 +97,7 @@ class FunctionCoder(Coder):
show_diff = diffs.diff_partial_update(
orig_lines,
lines,
final=True,
final,
).splitlines()
return "\n".join(show_diff)

View file

@ -43,12 +43,17 @@ def diff_partial_update(lines_orig, lines_updated, final=False):
# dump(lines_orig)
# dump(lines_updated)
last_non_deleted = find_last_non_deleted(lines_orig, lines_updated)
num_orig_lines = len(lines_orig)
if final:
last_non_deleted = num_orig_lines
else:
last_non_deleted = find_last_non_deleted(lines_orig, lines_updated)
# dump(last_non_deleted)
if last_non_deleted is None:
return ""
num_orig_lines = len(lines_orig)
pct = last_non_deleted * 100 / num_orig_lines
bar = create_progress_bar(pct)
bar = f"! {last_non_deleted:3d} / {num_orig_lines:3d} lines [{bar}] {pct:3.0f}%\n\n"
@ -63,12 +68,14 @@ def diff_partial_update(lines_orig, lines_updated, final=False):
diff = list(diff)[2:]
diff = "".join(diff)
if not diff.endswith("\n"):
diff += "\n"
show = "```diff\n"
if not final:
show += bar
show += diff + "```\n"
show += diff + "```\n\n"
# print(diff)