From 3ce3799f8dab8a7106dba0a4dd2f406dee5f48fa Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Thu, 18 Jul 2024 16:32:45 +0100 Subject: [PATCH] Implemented checks to ensure files with uncommitted changes or not present in previous commit cannot be undone safely. --- aider/commands.py | 27 ++++++++++++++++++--------- tests/basic/test_commands.py | 7 ++++--- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/aider/commands.py b/aider/commands.py index 702deaf70..5c016381a 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -306,16 +306,24 @@ class Commands: return last_commit = self.coder.repo.repo.head.commit - changed_files_last_commit = [ - item.a_path for item in last_commit.diff(last_commit.parents[0]) - ] + prev_commit = last_commit.parents[0] + changed_files_last_commit = [item.a_path for item in last_commit.diff(prev_commit)] - if any(self.coder.repo.repo.is_dirty(path=fname) for fname in changed_files_last_commit): - self.io.tool_error( - "The repository has uncommitted changes in files that were modified in the last" - " commit. Please commit or stash them before undoing." - ) - return + for fname in changed_files_last_commit: + if self.coder.repo.repo.is_dirty(path=fname): + self.io.tool_error( + f"The file {fname} has uncommitted changes. Please stash them before undoing." + ) + return + + # Check if the file was in the repo in the previous commit + try: + prev_commit.tree[fname] + except KeyError: + self.io.tool_error( + f"The file {fname} was not in the repository in the previous commit. Cannot undo safely." + ) + return local_head = self.coder.repo.repo.git.rev_parse("HEAD") current_branch = self.coder.repo.repo.active_branch.name @@ -346,6 +354,7 @@ class Commands: # Reset only the files which are part of `last_commit` for file_path in changed_files_last_commit: self.coder.repo.repo.git.checkout("HEAD~1", file_path) + # Move the HEAD back before the latest commit self.coder.repo.repo.git.reset("--soft", "HEAD~1") diff --git a/tests/basic/test_commands.py b/tests/basic/test_commands.py index fd091de48..f0be96232 100644 --- a/tests/basic/test_commands.py +++ b/tests/basic/test_commands.py @@ -585,11 +585,12 @@ class TestCommands(TestCase): last_commit_hash = repo.head.commit.hexsha[:7] coder.aider_commit_hashes.add(last_commit_hash) - # Attempt to undo the last commit + # Attempt to undo the last commit, should refuse commands.cmd_undo("") - # Check that the last commit was undone - self.assertNotEqual(last_commit_hash, repo.head.commit.hexsha[:7]) + # Check that the last commit was not undone + self.assertEqual(last_commit_hash, repo.head.commit.hexsha[:7]) + self.assertTrue(file_path.exists()) del coder del commands