refactor to enumerate files to be edited, then make the edits

This commit is contained in:
Paul Gauthier 2023-08-17 10:07:22 -07:00
parent 43047c3835
commit 2455676a44
3 changed files with 42 additions and 31 deletions

View file

@ -728,12 +728,10 @@ class Coder:
def get_addable_relative_files(self): def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files()) return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def allowed_to_edit(self, path, write_content=None): def allowed_to_edit(self, path):
full_path = self.abs_root_path(path) full_path = self.abs_root_path(path)
if full_path in self.abs_fnames: if full_path in self.abs_fnames:
if write_content:
self.io.write_text(full_path, write_content)
return full_path return full_path
if not Path(full_path).exists(): if not Path(full_path).exists():
@ -758,18 +756,28 @@ class Coder:
if not self.dry_run: if not self.dry_run:
self.repo.repo.git.add(full_path) self.repo.repo.git.add(full_path)
if write_content:
self.io.write_text(full_path, write_content)
return full_path return full_path
apply_update_errors = 0 apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
for edit in edits:
path = edit[0]
rest = edit[1:]
full_path = self.allowed_to_edit(path)
edit = [path, full_path] + list(rest)
res.append(edit)
return res
def apply_updates(self): def apply_updates(self):
max_apply_update_errors = 3 max_apply_update_errors = 3
try: try:
edited = self.update_files() edits = self.get_edits()
edits = self.prepare_to_edit(edits)
self.apply_edits(edits)
except ValueError as err: except ValueError as err:
err = err.args[0] err = err.args[0]
self.apply_update_errors += 1 self.apply_update_errors += 1
@ -795,12 +803,15 @@ class Coder:
self.apply_update_errors = 0 self.apply_update_errors = 0
if edited: # TODO FIXME: make sure
for path in sorted(edited): edited = set()
if self.dry_run: for edit in sorted(edits):
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)") path = edit[0]
else: edited.add(path)
self.io.tool_output(f"Applied edit to {path}") if self.dry_run:
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
else:
self.io.tool_output(f"Applied edit to {path}")
return edited, None return edited, None

View file

@ -13,22 +13,20 @@ class EditBlockCoder(Coder):
self.gpt_prompts = EditBlockPrompts() self.gpt_prompts = EditBlockPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def update_files(self): def get_edits(self):
content = self.partial_response_content content = self.partial_response_content
# might raise ValueError for malformed ORIG/UPD blocks # might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content)) edits = list(find_original_update_blocks(content))
edited = set() return edits
for path, original, updated in edits:
full_path = self.allowed_to_edit(path) def apply_edits(self, edits):
if not full_path: for path, full_path, original, updated in edits:
continue
content = self.io.read_text(full_path) content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated) content = do_replace(full_path, content, original, updated)
if content: if content:
self.io.write_text(full_path, content) self.io.write_text(full_path, content)
edited.add(path)
continue continue
raise ValueError(f"""InvalidEditBlock: edit failed! raise ValueError(f"""InvalidEditBlock: edit failed!
@ -42,8 +40,6 @@ The HEAD block needs to be EXACTLY the same as the lines in {path} with nothing
{original}``` {original}```
""") """)
return edited
def prep(content): def prep(content):
if content and not content.endswith("\n"): if content and not content.endswith("\n"):

View file

@ -22,11 +22,11 @@ class WholeFileCoder(Coder):
def render_incremental_response(self, final): def render_incremental_response(self, final):
try: try:
return self.update_files(mode="diff") return self.get_edits(mode="diff")
except ValueError: except ValueError:
return self.partial_response_content return self.partial_response_content
def update_files(self, mode="update"): def get_edits(self, mode="update"):
content = self.partial_response_content content = self.partial_response_content
chat_files = self.get_inchat_relative_files() chat_files = self.get_inchat_relative_files()
@ -104,22 +104,26 @@ class WholeFileCoder(Coder):
if fname: if fname:
edits.append((fname, fname_source, new_lines)) edits.append((fname, fname_source, new_lines))
edited = set() seen = set()
refined_edits = []
# process from most reliable filename, to least reliable # process from most reliable filename, to least reliable
for source in ("block", "saw", "chat"): for source in ("block", "saw", "chat"):
for fname, fname_source, new_lines in edits: for fname, fname_source, new_lines in edits:
if fname_source != source: if fname_source != source:
continue continue
# if a higher priority source already edited the file, skip # if a higher priority source already edited the file, skip
if fname in edited: if fname in seen:
continue continue
# we have a winner seen.add(fname)
new_lines = "".join(new_lines) refined_edits.append((fname, fname_source, new_lines))
if self.allowed_to_edit(fname, new_lines):
edited.add(fname)
return edited return refined_edits
def apply_edits(self, edits):
for path, full_path, fname_source, new_lines in edits:
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
def do_live_diff(self, full_path, new_lines, final): def do_live_diff(self, full_path, new_lines, final):
if full_path.exists(): if full_path.exists():