use index.entries

This commit is contained in:
Paul Gauthier 2023-07-24 13:47:33 -03:00
parent df096272bc
commit 66da82094d
2 changed files with 13 additions and 15 deletions

View file

@ -157,8 +157,8 @@ class GitRepo:
# Add staged files # Add staged files
index = self.repo.index index = self.repo.index
staged_files = [str(Path(PurePosixPath(path))) for path in index.entries.keys()] staged_files = [path for path, _ in index.entries.keys()]
dump(staged_files)
files.extend(staged_files) files.extend(staged_files)
# convert to appropriate os.sep, since git always normalizes to / # convert to appropriate os.sep, since git always normalizes to /

View file

@ -2,7 +2,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch, MagicMock from unittest.mock import patch
import git import git
@ -11,6 +11,7 @@ from aider.io import InputOutput
from aider.repo import GitRepo from aider.repo import GitRepo
from tests.utils import GitTemporaryDirectory from tests.utils import GitTemporaryDirectory
class TestRepo(unittest.TestCase): class TestRepo(unittest.TestCase):
@patch("aider.repo.simple_send_with_retries") @patch("aider.repo.simple_send_with_retries")
def test_get_commit_message(self, mock_send): def test_get_commit_message(self, mock_send):
@ -82,35 +83,32 @@ class TestRepo(unittest.TestCase):
self.assertEqual(set(tracked_files), set(created_files)) self.assertEqual(set(tracked_files), set(created_files))
def test_get_tracked_files_with_new_staged_file(self): def test_get_tracked_files_with_new_staged_file(self):
# Mock the IO object
mock_io = MagicMock()
with GitTemporaryDirectory(): with GitTemporaryDirectory():
# new repo # new repo
repo = git.Repo() raw_repo = git.Repo()
# add it, but no commits at all in the repo yet # add it, but no commits at all in the raw_repo yet
fname = Path("new.txt") fname = Path("new.txt")
fname.touch() fname.touch()
repo.git.add(str(fname)) raw_repo.git.add(str(fname))
coder = GitRepo(InputOutput(), None) git_repo = GitRepo(InputOutput(), None)
# better be there # better be there
fnames = coder.get_tracked_files() fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames) self.assertIn(str(fname), fnames)
# commit it, better still be there # commit it, better still be there
repo.git.commit("-m", "new") raw_repo.git.commit("-m", "new")
fnames = coder.get_tracked_files() fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames) self.assertIn(str(fname), fnames)
# new file, added but not committed # new file, added but not committed
fname2 = Path("new2.txt") fname2 = Path("new2.txt")
fname2.touch() fname2.touch()
repo.git.add(str(fname2)) raw_repo.git.add(str(fname2))
# both should be there # both should be there
fnames = coder.get_tracked_files() fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames) self.assertIn(str(fname), fnames)
self.assertIn(str(fname2), fnames) self.assertIn(str(fname2), fnames)