diff --git a/aider/commands.py b/aider/commands.py index 1196abdde..05b8e04d0 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -224,13 +224,11 @@ class Commands: return last_commit = self.coder.repo.repo.head.commit - changed_files_last_commit = { + changed_files_last_commit = [ item.a_path for item in last_commit.diff(last_commit.parents[0]) - } - dirty_files = [item.a_path for item in self.coder.repo.repo.index.diff(None)] - dirty_files_in_last_commit = changed_files_last_commit.intersection(dirty_files) + ] - if dirty_files_in_last_commit: + if any(self.coder.repo.repo.is_dirty(path=fname) for fname in changed_files_last_commit): self.io.tool_error( "The repository has uncommitted changes in files that were modified in the last" " commit. Please commit or stash them before undoing." @@ -264,7 +262,13 @@ class Commands: " command!" ) return - self.coder.repo.repo.git.reset("--hard", "HEAD~1") + + # Reset only the files which are part of `last_commit` + for file_path in changed_files_last_commit: + self.coder.repo.repo.git.checkout("HEAD~1", file_path) + # Move the HEAD back before the latest commit + self.coder.repo.repo.git.reset("--soft", "HEAD~1") + self.io.tool_output( f"{last_commit.message.strip()}\n" f"The above commit {self.coder.last_aider_commit_hash} " diff --git a/tests/test_commands.py b/tests/test_commands.py index 7fd24a4f5..dedd4bc53 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -539,6 +539,9 @@ class TestCommands(TestCase): commands.cmd_undo("") self.assertNotEqual(last_commit_hash, repo.head.commit.hexsha[:7]) + self.assertEqual(file_path.read_text(), "first content") + self.assertEqual(other_path.read_text(), "dirty content") + del coder del commands del repo