This commit is contained in:
Paul Gauthier 2023-06-05 19:36:12 -07:00
parent 0be54922da
commit 2f0d9279d2
2 changed files with 57 additions and 13 deletions

View file

@ -14,12 +14,12 @@ from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from aider import prompts, utils
from aider import diffs, prompts, utils
from aider.commands import Commands
from aider.repomap import RepoMap
from aider.utils import Models
# from .dump import dump
from .dump import dump # noqa: F401
class MissingAPIKeyError(ValueError):
@ -462,7 +462,11 @@ class Coder:
continue
if self.pretty:
md = Markdown(self.resp, style="blue", code_theme="default")
if self.main_model == Models.GPT35:
show_resp = self.update_files_gpt35(self.resp, just_diffs=True)
else:
show_resp = self.resp
md = Markdown(show_resp, style="blue", code_theme="default")
live.update(md)
else:
sys.stdout.write(text)
@ -471,13 +475,16 @@ class Coder:
if live:
live.stop()
def update_files_gpt35(self, content):
def update_files_gpt35(self, content, just_diffs=False):
edited = set()
chat_files = self.get_inchat_relative_files()
if not chat_files:
if just_diffs:
return content
return
lines = content.splitlines()
output = []
lines = content.splitlines(keepends=True)
fname = None
new_lines = []
for i, line in enumerate(lines):
@ -485,10 +492,21 @@ class Coder:
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
new_lines = "\n".join(new_lines) + "\n"
Path(full_path).write_text(new_lines)
edited.add(fname)
if just_diffs:
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
).splitlines()
# dump(show_diff)
output += show_diff
else:
new_lines = "\n".join(new_lines) + "\n"
Path(full_path).write_text(new_lines)
edited.add(fname)
fname = None
new_lines = []
@ -505,6 +523,25 @@ class Coder:
elif fname:
new_lines.append(line)
else:
output.append(line)
if just_diffs:
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if just_diffs:
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
).splitlines()
output += show_diff
return "\n".join(output)
if fname:
raise ValueError("Started a ``` block without closing it")

View file

@ -1,7 +1,7 @@
import difflib
import sys
from .dump import dump
from .dump import dump # noqa: F401
def main():
@ -18,7 +18,8 @@ def main():
lines_updated = f.readlines()
for i in range(len(file_updated)):
diff_partial_update(lines_orig, lines_updated[:i])
res = diff_partial_update(lines_orig, lines_updated[:i])
print(res)
input()
@ -29,8 +30,11 @@ def diff_partial_update(lines_orig, lines_updated):
partially complete update.
"""
# dump(lines_orig)
# dump(lines_updated)
last_non_deleted = find_last_non_deleted(lines_orig, lines_updated)
dump(last_non_deleted)
# dump(last_non_deleted)
if last_non_deleted is None:
return ""
@ -38,10 +42,13 @@ def diff_partial_update(lines_orig, lines_updated):
diff = difflib.unified_diff(lines_orig, lines_updated)
# unified_diff = list(unified_diff)[2:]
# dump(repr(list(diff)))
diff = "".join(diff)
diff = "".join(diff) + "\n"
print(diff)
diff = "```diff\n" + diff + "```\n"
# print(diff)
return diff