diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index fe1415f02..d8f53a2cc 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -728,12 +728,10 @@ class Coder: def get_addable_relative_files(self): return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files()) - def allowed_to_edit(self, path, write_content=None): + def allowed_to_edit(self, path): full_path = self.abs_root_path(path) if full_path in self.abs_fnames: - if write_content: - self.io.write_text(full_path, write_content) return full_path if not Path(full_path).exists(): @@ -758,18 +756,28 @@ class Coder: if not self.dry_run: self.repo.repo.git.add(full_path) - if write_content: - self.io.write_text(full_path, write_content) - return full_path apply_update_errors = 0 + def prepare_to_edit(self, edits): + res = [] + for edit in edits: + path = edit[0] + rest = edit[1:] + full_path = self.allowed_to_edit(path) + edit = [path, full_path] + list(rest) + res.append(edit) + + return res + def apply_updates(self): max_apply_update_errors = 3 try: - edited = self.update_files() + edits = self.get_edits() + edits = self.prepare_to_edit(edits) + self.apply_edits(edits) except ValueError as err: err = err.args[0] self.apply_update_errors += 1 @@ -795,12 +803,15 @@ class Coder: self.apply_update_errors = 0 - if edited: - for path in sorted(edited): - if self.dry_run: - self.io.tool_output(f"Did not apply edit to {path} (--dry-run)") - else: - self.io.tool_output(f"Applied edit to {path}") + # TODO FIXME: make sure + edited = set() + for edit in sorted(edits): + path = edit[0] + edited.add(path) + if self.dry_run: + self.io.tool_output(f"Did not apply edit to {path} (--dry-run)") + else: + self.io.tool_output(f"Applied edit to {path}") return edited, None diff --git a/aider/coders/editblock_coder.py b/aider/coders/editblock_coder.py index 5388bbac7..d461d85c6 100644 --- a/aider/coders/editblock_coder.py +++ b/aider/coders/editblock_coder.py @@ -13,22 +13,20 @@ class EditBlockCoder(Coder): self.gpt_prompts = EditBlockPrompts() super().__init__(*args, **kwargs) - def update_files(self): + def get_edits(self): content = self.partial_response_content # might raise ValueError for malformed ORIG/UPD blocks edits = list(find_original_update_blocks(content)) - edited = set() - for path, original, updated in edits: - full_path = self.allowed_to_edit(path) - if not full_path: - continue + return edits + + def apply_edits(self, edits): + for path, full_path, original, updated in edits: content = self.io.read_text(full_path) content = do_replace(full_path, content, original, updated) if content: self.io.write_text(full_path, content) - edited.add(path) continue raise ValueError(f"""InvalidEditBlock: edit failed! @@ -42,8 +40,6 @@ The HEAD block needs to be EXACTLY the same as the lines in {path} with nothing {original}``` """) - return edited - def prep(content): if content and not content.endswith("\n"): diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index c70576dae..b8c3b24cd 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -22,11 +22,11 @@ class WholeFileCoder(Coder): def render_incremental_response(self, final): try: - return self.update_files(mode="diff") + return self.get_edits(mode="diff") except ValueError: return self.partial_response_content - def update_files(self, mode="update"): + def get_edits(self, mode="update"): content = self.partial_response_content chat_files = self.get_inchat_relative_files() @@ -104,22 +104,26 @@ class WholeFileCoder(Coder): if fname: edits.append((fname, fname_source, new_lines)) - edited = set() + seen = set() + refined_edits = [] # process from most reliable filename, to least reliable for source in ("block", "saw", "chat"): for fname, fname_source, new_lines in edits: if fname_source != source: continue # if a higher priority source already edited the file, skip - if fname in edited: + if fname in seen: continue - # we have a winner - new_lines = "".join(new_lines) - if self.allowed_to_edit(fname, new_lines): - edited.add(fname) + seen.add(fname) + refined_edits.append((fname, fname_source, new_lines)) - return edited + return refined_edits + + def apply_edits(self, edits): + for path, full_path, fname_source, new_lines in edits: + new_lines = "".join(new_lines) + self.io.write_text(full_path, new_lines) def do_live_diff(self, full_path, new_lines, final): if full_path.exists():