mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
refactor to enumerate files to be edited, then make the edits
This commit is contained in:
parent
43047c3835
commit
2455676a44
3 changed files with 42 additions and 31 deletions
|
@ -728,12 +728,10 @@ class Coder:
|
||||||
def get_addable_relative_files(self):
|
def get_addable_relative_files(self):
|
||||||
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
|
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
|
||||||
|
|
||||||
def allowed_to_edit(self, path, write_content=None):
|
def allowed_to_edit(self, path):
|
||||||
full_path = self.abs_root_path(path)
|
full_path = self.abs_root_path(path)
|
||||||
|
|
||||||
if full_path in self.abs_fnames:
|
if full_path in self.abs_fnames:
|
||||||
if write_content:
|
|
||||||
self.io.write_text(full_path, write_content)
|
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
if not Path(full_path).exists():
|
if not Path(full_path).exists():
|
||||||
|
@ -758,18 +756,28 @@ class Coder:
|
||||||
if not self.dry_run:
|
if not self.dry_run:
|
||||||
self.repo.repo.git.add(full_path)
|
self.repo.repo.git.add(full_path)
|
||||||
|
|
||||||
if write_content:
|
|
||||||
self.io.write_text(full_path, write_content)
|
|
||||||
|
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
apply_update_errors = 0
|
apply_update_errors = 0
|
||||||
|
|
||||||
|
def prepare_to_edit(self, edits):
|
||||||
|
res = []
|
||||||
|
for edit in edits:
|
||||||
|
path = edit[0]
|
||||||
|
rest = edit[1:]
|
||||||
|
full_path = self.allowed_to_edit(path)
|
||||||
|
edit = [path, full_path] + list(rest)
|
||||||
|
res.append(edit)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def apply_updates(self):
|
def apply_updates(self):
|
||||||
max_apply_update_errors = 3
|
max_apply_update_errors = 3
|
||||||
|
|
||||||
try:
|
try:
|
||||||
edited = self.update_files()
|
edits = self.get_edits()
|
||||||
|
edits = self.prepare_to_edit(edits)
|
||||||
|
self.apply_edits(edits)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
err = err.args[0]
|
err = err.args[0]
|
||||||
self.apply_update_errors += 1
|
self.apply_update_errors += 1
|
||||||
|
@ -795,8 +803,11 @@ class Coder:
|
||||||
|
|
||||||
self.apply_update_errors = 0
|
self.apply_update_errors = 0
|
||||||
|
|
||||||
if edited:
|
# TODO FIXME: make sure
|
||||||
for path in sorted(edited):
|
edited = set()
|
||||||
|
for edit in sorted(edits):
|
||||||
|
path = edit[0]
|
||||||
|
edited.add(path)
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
|
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,22 +13,20 @@ class EditBlockCoder(Coder):
|
||||||
self.gpt_prompts = EditBlockPrompts()
|
self.gpt_prompts = EditBlockPrompts()
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def update_files(self):
|
def get_edits(self):
|
||||||
content = self.partial_response_content
|
content = self.partial_response_content
|
||||||
|
|
||||||
# might raise ValueError for malformed ORIG/UPD blocks
|
# might raise ValueError for malformed ORIG/UPD blocks
|
||||||
edits = list(find_original_update_blocks(content))
|
edits = list(find_original_update_blocks(content))
|
||||||
|
|
||||||
edited = set()
|
return edits
|
||||||
for path, original, updated in edits:
|
|
||||||
full_path = self.allowed_to_edit(path)
|
def apply_edits(self, edits):
|
||||||
if not full_path:
|
for path, full_path, original, updated in edits:
|
||||||
continue
|
|
||||||
content = self.io.read_text(full_path)
|
content = self.io.read_text(full_path)
|
||||||
content = do_replace(full_path, content, original, updated)
|
content = do_replace(full_path, content, original, updated)
|
||||||
if content:
|
if content:
|
||||||
self.io.write_text(full_path, content)
|
self.io.write_text(full_path, content)
|
||||||
edited.add(path)
|
|
||||||
continue
|
continue
|
||||||
raise ValueError(f"""InvalidEditBlock: edit failed!
|
raise ValueError(f"""InvalidEditBlock: edit failed!
|
||||||
|
|
||||||
|
@ -42,8 +40,6 @@ The HEAD block needs to be EXACTLY the same as the lines in {path} with nothing
|
||||||
{original}```
|
{original}```
|
||||||
""")
|
""")
|
||||||
|
|
||||||
return edited
|
|
||||||
|
|
||||||
|
|
||||||
def prep(content):
|
def prep(content):
|
||||||
if content and not content.endswith("\n"):
|
if content and not content.endswith("\n"):
|
||||||
|
|
|
@ -22,11 +22,11 @@ class WholeFileCoder(Coder):
|
||||||
|
|
||||||
def render_incremental_response(self, final):
|
def render_incremental_response(self, final):
|
||||||
try:
|
try:
|
||||||
return self.update_files(mode="diff")
|
return self.get_edits(mode="diff")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
||||||
def update_files(self, mode="update"):
|
def get_edits(self, mode="update"):
|
||||||
content = self.partial_response_content
|
content = self.partial_response_content
|
||||||
|
|
||||||
chat_files = self.get_inchat_relative_files()
|
chat_files = self.get_inchat_relative_files()
|
||||||
|
@ -104,22 +104,26 @@ class WholeFileCoder(Coder):
|
||||||
if fname:
|
if fname:
|
||||||
edits.append((fname, fname_source, new_lines))
|
edits.append((fname, fname_source, new_lines))
|
||||||
|
|
||||||
edited = set()
|
seen = set()
|
||||||
|
refined_edits = []
|
||||||
# process from most reliable filename, to least reliable
|
# process from most reliable filename, to least reliable
|
||||||
for source in ("block", "saw", "chat"):
|
for source in ("block", "saw", "chat"):
|
||||||
for fname, fname_source, new_lines in edits:
|
for fname, fname_source, new_lines in edits:
|
||||||
if fname_source != source:
|
if fname_source != source:
|
||||||
continue
|
continue
|
||||||
# if a higher priority source already edited the file, skip
|
# if a higher priority source already edited the file, skip
|
||||||
if fname in edited:
|
if fname in seen:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# we have a winner
|
seen.add(fname)
|
||||||
new_lines = "".join(new_lines)
|
refined_edits.append((fname, fname_source, new_lines))
|
||||||
if self.allowed_to_edit(fname, new_lines):
|
|
||||||
edited.add(fname)
|
|
||||||
|
|
||||||
return edited
|
return refined_edits
|
||||||
|
|
||||||
|
def apply_edits(self, edits):
|
||||||
|
for path, full_path, fname_source, new_lines in edits:
|
||||||
|
new_lines = "".join(new_lines)
|
||||||
|
self.io.write_text(full_path, new_lines)
|
||||||
|
|
||||||
def do_live_diff(self, full_path, new_lines, final):
|
def do_live_diff(self, full_path, new_lines, final):
|
||||||
if full_path.exists():
|
if full_path.exists():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue