standardize on abs_root_path(), simplify get/apply_edit

This commit is contained in:
Paul Gauthier 2023-08-18 07:20:15 -07:00
parent f45bcbf8eb
commit e608a351f0
4 changed files with 19 additions and 17 deletions

View file

@ -715,7 +715,7 @@ class Coder:
if not self.repo.is_dirty(path): if not self.repo.is_dirty(path):
return return
fullp = self.repo.full_path(path) fullp = Path(self.abs_root_path(path))
if not fullp.stat().st_size: if not fullp.stat().st_size:
return return
@ -731,7 +731,7 @@ class Coder:
if full_path in self.abs_fnames: if full_path in self.abs_fnames:
self.check_for_dirty_commit(path) self.check_for_dirty_commit(path)
return full_path return True
if not Path(full_path).exists(): if not Path(full_path).exists():
if not self.io.confirm_ask(f"Allow creation of new file {path}?"): if not self.io.confirm_ask(f"Allow creation of new file {path}?"):
@ -746,7 +746,7 @@ class Coder:
self.repo.repo.git.add(full_path) self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path) self.abs_fnames.add(full_path)
return full_path return True
if not self.io.confirm_ask( if not self.io.confirm_ask(
f"Allow edits to {path} which was not previously added to chat?" f"Allow edits to {path} which was not previously added to chat?"
@ -759,7 +759,7 @@ class Coder:
self.abs_fnames.add(full_path) self.abs_fnames.add(full_path)
self.check_for_dirty_commit(path) self.check_for_dirty_commit(path)
return full_path return True
apply_update_errors = 0 apply_update_errors = 0
@ -771,15 +771,14 @@ class Coder:
for edit in edits: for edit in edits:
path = edit[0] path = edit[0]
rest = edit[1:]
if path in seen: if path in seen:
full_path = seen[path] allowed = seen[path]
else: else:
full_path = self.allowed_to_edit(path) allowed = self.allowed_to_edit(path)
seen[path] = full_path seen[path] = allowed
edit = [path, full_path] + list(rest) if allowed:
res.append(edit) res.append(edit)
self.dirty_commit() self.dirty_commit()
self.need_commit_before_edits = False self.need_commit_before_edits = False

View file

@ -22,7 +22,8 @@ class EditBlockCoder(Coder):
return edits return edits
def apply_edits(self, edits): def apply_edits(self, edits):
for path, full_path, original, updated in edits: for path, original, updated in edits:
full_path = self.abs_root_path(path)
content = self.io.read_text(full_path) content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated) content = do_replace(full_path, content, original, updated)
if content: if content:

View file

@ -46,7 +46,7 @@ class WholeFileCoder(Coder):
# ending an existing block # ending an existing block
saw_fname = None saw_fname = None
full_path = (Path(self.root) / fname).absolute() full_path = self.abs_root_path(fname)
if mode == "diff": if mode == "diff":
output += self.do_live_diff(full_path, new_lines, True) output += self.do_live_diff(full_path, new_lines, True)
@ -121,12 +121,13 @@ class WholeFileCoder(Coder):
return refined_edits return refined_edits
def apply_edits(self, edits): def apply_edits(self, edits):
for path, full_path, fname_source, new_lines in edits: for path, fname_source, new_lines in edits:
full_path = self.abs_root_path(path)
new_lines = "".join(new_lines) new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines) self.io.write_text(full_path, new_lines)
def do_live_diff(self, full_path, new_lines, final): def do_live_diff(self, full_path, new_lines, final):
if full_path.exists(): if Path(full_path).exists():
orig_lines = self.io.read_text(full_path).splitlines(keepends=True) orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update( show_diff = diffs.diff_partial_update(

View file

@ -74,7 +74,7 @@ class GitRepo:
cmd = ["-m", full_commit_message, "--no-verify"] cmd = ["-m", full_commit_message, "--no-verify"]
if fnames: if fnames:
fnames = [str(self.full_path(fn)) for fn in fnames] fnames = [str(self.abs_root_path(fn)) for fn in fnames]
for fname in fnames: for fname in fnames:
self.repo.git.add(fname) self.repo.git.add(fname)
cmd += ["--"] + fnames cmd += ["--"] + fnames
@ -199,8 +199,9 @@ class GitRepo:
tracked_files = set(self.get_tracked_files()) tracked_files = set(self.get_tracked_files())
return path in tracked_files return path in tracked_files
def full_path(self, path): def abs_root_path(self, path):
return (Path(self.root) / path).resolve() res = Path(self.root) / path
return utils.safe_abs_path(res)
def is_dirty(self, path=None): def is_dirty(self, path=None):
if path and not self.path_in_repo(path): if path and not self.path_in_repo(path):