Refactored the cmd_undo method to handle dirty files not in the last commit.

This commit is contained in:
Paul Gauthier 2024-01-02 09:04:57 -08:00
parent 695299be3d
commit e09a2033e2
2 changed files with 17 additions and 18 deletions

View file

@ -187,13 +187,9 @@ class Commands:
return return
last_commit = self.coder.repo.repo.head.commit last_commit = self.coder.repo.repo.head.commit
dump(last_commit)
changed_files_last_commit = {item.a_path for item in last_commit.diff(last_commit.parents[0])} changed_files_last_commit = {item.a_path for item in last_commit.diff(last_commit.parents[0])}
dirty_files = [item.a_path for item in self.coder.repo.repo.index.diff(None)] dirty_files = [item.a_path for item in self.coder.repo.repo.index.diff(None)]
dirty_files_in_last_commit = changed_files_last_commit.intersection(dirty_files) dirty_files_in_last_commit = changed_files_last_commit.intersection(dirty_files)
dump(changed_files_last_commit)
dump(dirty_files)
dump(dirty_files_in_last_commit)
if dirty_files_in_last_commit: if dirty_files_in_last_commit:
self.io.tool_error( self.io.tool_error(

View file

@ -492,6 +492,7 @@ class TestCommands(TestCase):
commands.cmd_drop(str(fname)) commands.cmd_drop(str(fname))
self.assertEqual(len(coder.abs_fnames), 0) self.assertEqual(len(coder.abs_fnames), 0)
def test_cmd_undo_with_dirty_files_not_in_last_commit(self): def test_cmd_undo_with_dirty_files_not_in_last_commit(self):
with GitTemporaryDirectory() as repo_dir: with GitTemporaryDirectory() as repo_dir:
repo = git.Repo(repo_dir) repo = git.Repo(repo_dir)
@ -499,34 +500,36 @@ class TestCommands(TestCase):
coder = Coder.create(models.GPT35, None, io) coder = Coder.create(models.GPT35, None, io)
commands = Commands(io, coder) commands = Commands(io, coder)
other_path = Path(repo_dir) / "other_file.txt"
other_path.write_text("other content")
repo.git.add(str(other_path))
# Create and commit a file # Create and commit a file
filename = "test_file.txt" filename = "test_file.txt"
file_path = Path(repo_dir) / filename file_path = Path(repo_dir) / filename
file_path.write_text("Initial content") file_path.write_text("first content")
repo.git.add(filename) repo.git.add(filename)
repo.git.commit("-m", "Initial commit") repo.git.commit("-m", "aider: first commit")
# Modify the file and commit again file_path.write_text("second content")
file_path.write_text("Modified content")
repo.git.add(filename) repo.git.add(filename)
repo.git.commit("-m", "aider: Modify test_file.txt") repo.git.commit("-m", "aider: second commit")
# Store the commit hash # Store the commit hash
last_commit_hash = repo.head.commit.hexsha[:7] last_commit_hash = repo.head.commit.hexsha[:7]
coder.last_aider_commit_hash = last_commit_hash coder.last_aider_commit_hash = last_commit_hash
# Create a dirty file that was not in the last commit file_path.write_text("dirty content")
other_file = "other_file.txt"
other_file_path = Path(repo_dir) / other_file
other_file_path.write_text("This is an untracked file")
# Attempt to undo the last commit # Attempt to undo the last commit
output = commands.cmd_undo("") commands.cmd_undo("")
self.assertIsNone(output, "Undo should not produce any output")
# Check that the last commit is still present # Check that the last commit is still present
self.assertEqual(last_commit_hash, repo.head.commit.hexsha[:7]) self.assertEqual(last_commit_hash, repo.head.commit.hexsha[:7])
# Check that the dirty file is still untracked # Put back the initial content (so it's not dirty now)
self.assertTrue(other_file_path.exists()) file_path.write_text("second content")
self.assertIn(other_file, repo.untracked_files) other_path.write_text("dirty content")
commands.cmd_undo("")
self.assertNotEqual(last_commit_hash, repo.head.commit.hexsha[:7])