diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 9896705d8..aa46e3019 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -48,7 +48,6 @@ class Coder: total_cost = 0.0 num_exhausted_context_windows = 0 last_keyboard_interrupt = None - need_commit_before_edits = False @classmethod def create( @@ -106,6 +105,7 @@ class Coder: self.chat_completion_call_hashes = [] self.chat_completion_response_hashes = [] + self.need_commit_before_edits = set() self.verbose = verbose self.abs_fnames = set() @@ -720,8 +720,7 @@ class Coder: return self.io.tool_output(f"Committing {path} before applying edits.") - self.repo.repo.git.add(path) - self.need_commit_before_edits = True + self.need_commit_before_edits.add(path) return def allowed_to_edit(self, path): @@ -772,7 +771,7 @@ class Coder: res = [] seen = dict() - self.need_commit_before_edits = False + self.need_commit_before_edits = set() for edit in edits: path = edit[0] @@ -785,9 +784,8 @@ class Coder: if allowed: res.append(edit) - fnames = [edit[0] for edit in res] - self.dirty_commit(fnames) - self.need_commit_before_edits = False + self.dirty_commit() + self.need_commit_before_edits = set() return res @@ -886,14 +884,14 @@ class Coder: self.io.tool_output("No changes made to git tracked files.") return self.gpt_prompts.files_content_gpt_no_edits - def dirty_commit(self, fnames): + def dirty_commit(self): if not self.need_commit_before_edits: return if not self.dirty_commits: return if not self.repo: return - self.repo.commit(fnames=fnames) + self.repo.commit(fnames=self.need_commit_before_edits) # files changed, move cur messages back behind the files messages self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) diff --git a/tests/test_coder.py b/tests/test_coder.py index e782af36e..b375bd176 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -227,7 +227,6 @@ class TestCoder(unittest.TestCase): mock.return_value = set([str(fname)]) coder.repo.get_tracked_files = mock - dump(fname) # Call the check_for_file_mentions method coder.check_for_file_mentions(f"Please check `{fname}`") @@ -524,14 +523,12 @@ three self.assertEqual(num_commits, 3) diff = repo.git.diff(["HEAD~2", "HEAD~1"]) - dump(diff) self.assertIn("one", diff) self.assertIn("two", diff) self.assertNotIn("three", diff) self.assertNotIn("other", diff) self.assertNotIn("OTHER", diff) - dump(saved_diffs) diff = saved_diffs[0] self.assertIn("one", diff) self.assertIn("two", diff) @@ -553,7 +550,6 @@ three self.assertNotIn("other", diff) self.assertNotIn("OTHER", diff) - dump(saved_diffs) self.assertEqual(len(saved_diffs), 2)