use index.entries

This commit is contained in:
Paul Gauthier 2023-07-24 13:47:33 -03:00
parent df096272bc
commit 66da82094d
2 changed files with 13 additions and 15 deletions

View file

@ -157,8 +157,8 @@ class GitRepo:
# Add staged files
index = self.repo.index
staged_files = [str(Path(PurePosixPath(path))) for path in index.entries.keys()]
dump(staged_files)
staged_files = [path for path, _ in index.entries.keys()]
files.extend(staged_files)
# convert to appropriate os.sep, since git always normalizes to /

View file

@ -2,7 +2,7 @@ import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock
from unittest.mock import patch
import git
@ -11,6 +11,7 @@ from aider.io import InputOutput
from aider.repo import GitRepo
from tests.utils import GitTemporaryDirectory
class TestRepo(unittest.TestCase):
@patch("aider.repo.simple_send_with_retries")
def test_get_commit_message(self, mock_send):
@ -82,35 +83,32 @@ class TestRepo(unittest.TestCase):
self.assertEqual(set(tracked_files), set(created_files))
def test_get_tracked_files_with_new_staged_file(self):
# Mock the IO object
mock_io = MagicMock()
with GitTemporaryDirectory():
# new repo
repo = git.Repo()
raw_repo = git.Repo()
# add it, but no commits at all in the repo yet
# add it, but no commits at all in the raw_repo yet
fname = Path("new.txt")
fname.touch()
repo.git.add(str(fname))
raw_repo.git.add(str(fname))
coder = GitRepo(InputOutput(), None)
git_repo = GitRepo(InputOutput(), None)
# better be there
fnames = coder.get_tracked_files()
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
# commit it, better still be there
repo.git.commit("-m", "new")
fnames = coder.get_tracked_files()
raw_repo.git.commit("-m", "new")
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
# new file, added but not committed
fname2 = Path("new2.txt")
fname2.touch()
repo.git.add(str(fname2))
raw_repo.git.add(str(fname2))
# both should be there
fnames = coder.get_tracked_files()
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
self.assertIn(str(fname2), fnames)