standardize on abs_root_path(), simplify get/apply_edit

This commit is contained in:
Paul Gauthier 2023-08-18 07:20:15 -07:00
parent f45bcbf8eb
commit e608a351f0
4 changed files with 19 additions and 17 deletions

View file

@ -715,7 +715,7 @@ class Coder:
if not self.repo.is_dirty(path):
return
fullp = self.repo.full_path(path)
fullp = Path(self.abs_root_path(path))
if not fullp.stat().st_size:
return
@ -731,7 +731,7 @@ class Coder:
if full_path in self.abs_fnames:
self.check_for_dirty_commit(path)
return full_path
return True
if not Path(full_path).exists():
if not self.io.confirm_ask(f"Allow creation of new file {path}?"):
@ -746,7 +746,7 @@ class Coder:
self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path)
return full_path
return True
if not self.io.confirm_ask(
f"Allow edits to {path} which was not previously added to chat?"
@ -759,7 +759,7 @@ class Coder:
self.abs_fnames.add(full_path)
self.check_for_dirty_commit(path)
return full_path
return True
apply_update_errors = 0
@ -771,14 +771,13 @@ class Coder:
for edit in edits:
path = edit[0]
rest = edit[1:]
if path in seen:
full_path = seen[path]
allowed = seen[path]
else:
full_path = self.allowed_to_edit(path)
seen[path] = full_path
allowed = self.allowed_to_edit(path)
seen[path] = allowed
edit = [path, full_path] + list(rest)
if allowed:
res.append(edit)
self.dirty_commit()

View file

@ -22,7 +22,8 @@ class EditBlockCoder(Coder):
return edits
def apply_edits(self, edits):
for path, full_path, original, updated in edits:
for path, original, updated in edits:
full_path = self.abs_root_path(path)
content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated)
if content:

View file

@ -46,7 +46,7 @@ class WholeFileCoder(Coder):
# ending an existing block
saw_fname = None
full_path = (Path(self.root) / fname).absolute()
full_path = self.abs_root_path(fname)
if mode == "diff":
output += self.do_live_diff(full_path, new_lines, True)
@ -121,12 +121,13 @@ class WholeFileCoder(Coder):
return refined_edits
def apply_edits(self, edits):
for path, full_path, fname_source, new_lines in edits:
for path, fname_source, new_lines in edits:
full_path = self.abs_root_path(path)
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
def do_live_diff(self, full_path, new_lines, final):
if full_path.exists():
if Path(full_path).exists():
orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update(

View file

@ -74,7 +74,7 @@ class GitRepo:
cmd = ["-m", full_commit_message, "--no-verify"]
if fnames:
fnames = [str(self.full_path(fn)) for fn in fnames]
fnames = [str(self.abs_root_path(fn)) for fn in fnames]
for fname in fnames:
self.repo.git.add(fname)
cmd += ["--"] + fnames
@ -199,8 +199,9 @@ class GitRepo:
tracked_files = set(self.get_tracked_files())
return path in tracked_files
def full_path(self, path):
return (Path(self.root) / path).resolve()
def abs_root_path(self, path):
res = Path(self.root) / path
return utils.safe_abs_path(res)
def is_dirty(self, path=None):
if path and not self.path_in_repo(path):