From 752e47a886831369dac222dfe8e55fe8842c68f8 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Fri, 18 Aug 2023 09:43:40 -0700 Subject: [PATCH] better tests, small cleanups --- aider/coders/base_coder.py | 18 ++++++++----- aider/repo.py | 3 +++ tests/test_coder.py | 55 +++++++++++++++++++++++++++++++++----- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index af351a42c..9896705d8 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -727,7 +727,9 @@ class Coder: def allowed_to_edit(self, path): full_path = self.abs_root_path(path) if self.repo: - is_in_repo = self.repo.path_in_repo(path) + need_to_add = not self.repo.path_in_repo(path) + else: + need_to_add = False if full_path in self.abs_fnames: self.check_for_dirty_commit(path) @@ -742,7 +744,10 @@ class Coder: Path(full_path).parent.mkdir(parents=True, exist_ok=True) Path(full_path).touch() - if self.repo and not is_in_repo: + # Seems unlikely that we needed to create the file, but it was + # actually already part of the repo. + # But let's handle this obscure corner case anyway. + if need_to_add: self.repo.repo.git.add(full_path) self.abs_fnames.add(full_path) @@ -754,7 +759,7 @@ class Coder: self.io.tool_error(f"Skipping edits to {path}") return - if self.repo and not is_in_repo: + if need_to_add: self.repo.repo.git.add(full_path) self.abs_fnames.add(full_path) @@ -780,7 +785,8 @@ class Coder: if allowed: res.append(edit) - self.dirty_commit() + fnames = [edit[0] for edit in res] + self.dirty_commit(fnames) self.need_commit_before_edits = False return res @@ -880,14 +886,14 @@ class Coder: self.io.tool_output("No changes made to git tracked files.") return self.gpt_prompts.files_content_gpt_no_edits - def dirty_commit(self): + def dirty_commit(self, fnames): if not self.need_commit_before_edits: return if not self.dirty_commits: return if not self.repo: return - self.repo.commit() + self.repo.commit(fnames=fnames) # files changed, move cur messages back behind the files messages self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) diff --git a/aider/repo.py b/aider/repo.py index 2614eab76..4e6f73e4b 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -59,7 +59,9 @@ class GitRepo: diff_args = [] if fnames: diff_args += ["--"] + list(fnames) + dump(diff_args) diffs = self.get_diffs(False, *diff_args) + dump(diffs) commit_message = self.get_commit_message(diffs, context) if not commit_message: @@ -134,6 +136,7 @@ class GitRepo: if args: if pretty: args = ["--color"] + args + dump(args) return self.repo.git.diff(*args) # otherwise, we always want diffs of index and working dir diff --git a/tests/test_coder.py b/tests/test_coder.py index 24bb39c9d..e782af36e 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -474,9 +474,16 @@ TWO fname = Path("file.txt") fname.write_text("one\n") repo.git.add(str(fname)) + + fname2 = Path("other.txt") + fname2.write_text("other\n") + repo.git.add(str(fname2)) + repo.git.commit("-m", "new") + # dirty fname.write_text("two\n") + fname2.write_text("OTHER\n") io = InputOutput(yes=True) coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname)]) @@ -495,23 +502,59 @@ three """ coder.partial_response_function_call = dict() + saved_diffs = [] + + def mock_get_commit_message(diffs, context): + saved_diffs.append(diffs) + return "commit message" + + coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message) coder.send = MagicMock(side_effect=mock_send) - coder.repo.get_commit_message = MagicMock() - coder.repo.get_commit_message.return_value = "commit message" coder.run(with_message="hi") + print("=" * 20) + print(repo.git.log(["-p"])) + print("=" * 20) + content = fname.read_text() self.assertEqual(content, "three\n") num_commits = len(list(repo.iter_commits(repo.active_branch.name))) self.assertEqual(num_commits, 3) - self.assertIn("two", repo.git.diff(["HEAD~1", "HEAD"])) - self.assertIn("three", repo.git.diff(["HEAD~1", "HEAD"])) + diff = repo.git.diff(["HEAD~2", "HEAD~1"]) + dump(diff) + self.assertIn("one", diff) + self.assertIn("two", diff) + self.assertNotIn("three", diff) + self.assertNotIn("other", diff) + self.assertNotIn("OTHER", diff) - self.assertIn("one", repo.git.diff(["HEAD~2", "HEAD~1"])) - self.assertIn("two", repo.git.diff(["HEAD~2", "HEAD~1"])) + dump(saved_diffs) + diff = saved_diffs[0] + self.assertIn("one", diff) + self.assertIn("two", diff) + self.assertNotIn("three", diff) + self.assertNotIn("other", diff) + self.assertNotIn("OTHER", diff) + + diff = repo.git.diff(["HEAD~1", "HEAD"]) + self.assertNotIn("one", diff) + self.assertIn("two", diff) + self.assertIn("three", diff) + self.assertNotIn("other", diff) + self.assertNotIn("OTHER", diff) + + diff = saved_diffs[1] + self.assertNotIn("one", diff) + self.assertIn("two", diff) + self.assertIn("three", diff) + self.assertNotIn("other", diff) + self.assertNotIn("OTHER", diff) + + dump(saved_diffs) + self.assertEqual(len(saved_diffs), 2) if __name__ == "__main__":