diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index ebc6d4241..c119815e8 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -717,7 +717,9 @@ class Coder: def get_addable_relative_files(self): return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files()) - def allowed_to_edit(self, path): + def allowed_to_edit(self, path, write_content=None): + # TODO: respect --dry-run + full_path = os.path.abspath(os.path.join(self.root, path)) if full_path in self.abs_fnames: @@ -731,7 +733,7 @@ class Coder: self.io.tool_error(f"Skipping edit to {path}") return - if not Path(full_path).exists(): + if not Path(full_path).exists() and not self.dry_run: Path(full_path).parent.mkdir(parents=True, exist_ok=True) Path(full_path).touch() @@ -742,7 +744,11 @@ class Coder: tracked_files = set(self.repo.git.ls_files().splitlines()) relative_fname = self.get_rel_fname(full_path) if relative_fname not in tracked_files and self.io.confirm_ask(f"Add {path} to git?"): - self.repo.git.add(full_path) + if not self.dry_run: + self.repo.git.add(full_path) + + if not self.dry_run and write_content: + Path(full_path).write_text(write_content) return full_path @@ -763,7 +769,11 @@ class Coder: if edited: for path in sorted(edited): - self.io.tool_output(f"Applied edit to {path}") + if self.dry_run: + self.io.tool_output(f"Did not apply edit to {path} (--dry-run)") + else: + self.io.tool_output(f"Applied edit to {path}") + return edited, None def parse_partial_args(self): diff --git a/aider/coders/wholefile_func_coder.py b/aider/coders/wholefile_func_coder.py index 85b9ae289..8a0cd8f41 100644 --- a/aider/coders/wholefile_func_coder.py +++ b/aider/coders/wholefile_func_coder.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from aider import diffs @@ -111,24 +110,17 @@ class WholeFileFunctionCoder(Coder): files = args.get("files", []) - chat_files = self.get_inchat_relative_files() - edited = set() for file_upd in files: path = file_upd.get("path") if not path: - raise ValueError(f"Missing path: {file_upd}") - - if path not in chat_files: - raise ValueError(f"File {path} not in chat session.") + raise ValueError(f"Missing path parameter: {file_upd}") content = file_upd.get("content") if not content: - raise ValueError(f"Missing content: {file_upd}") + raise ValueError(f"Missing content parameter: {file_upd}") - edited.add(path) - if not self.dry_run: - full_path = os.path.abspath(os.path.join(self.root, path)) - Path(full_path).write_text(content) + if self.allowed_to_edit(path, content): + edited.add(path) return edited