diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 585e208a6..e54fb1087 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -17,7 +17,7 @@ from rich.markdown import Markdown from aider import models, prompts, utils from aider.commands import Commands -from aider.repo import AiderRepo +from aider.repo import GitRepo from aider.repomap import RepoMap from aider.sendchat import send_with_retries @@ -159,7 +159,7 @@ class Coder: if use_git: try: - self.repo = AiderRepo(self.io, fnames) + self.repo = GitRepo(self.io, fnames) self.root = self.repo.root except FileNotFoundError: self.repo = None diff --git a/aider/repo.py b/aider/repo.py index 0016c332d..0b01e9978 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -10,7 +10,7 @@ from aider.sendchat import send_with_retries from .dump import dump # noqa: F401 -class AiderRepo: +class GitRepo: repo = None def __init__(self, io, fnames): diff --git a/tests/test_main.py b/tests/test_main.py index 8b08c61b1..6bda03c7f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -38,7 +38,7 @@ class TestMain(TestCase): main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput()) self.assertTrue(os.path.exists("foo.txt")) - @patch("aider.repo.AiderRepo.get_commit_message", return_value="mock commit message") + @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") def test_main_with_empty_git_dir_new_file(self, _): make_repo() main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput()) diff --git a/tests/test_repo.py b/tests/test_repo.py index de59a4fe4..ca0985886 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -5,13 +5,10 @@ from pathlib import Path from unittest.mock import MagicMock, patch import git -import openai -from aider import models from aider.dump import dump # noqa: F401 from aider.io import InputOutput -from aider.repo import AiderRepo -from tests.utils import GitTemporaryDirectory +from aider.repo import GitRepo class TestRepo(unittest.TestCase): @@ -21,12 +18,9 @@ class TestRepo(unittest.TestCase): mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "a good commit message" - mock_send.return_value = ( - None, - mock_response - ) + mock_send.return_value = (None, mock_response) - repo = AiderRepo(InputOutput(), None) + repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context result = repo.get_commit_message("dummy diff", "dummy context") @@ -39,12 +33,9 @@ class TestRepo(unittest.TestCase): mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = '"a good commit message"' - mock_send.return_value = ( - None, - mock_response - ) + mock_send.return_value = (None, mock_response) - repo = AiderRepo(InputOutput(), None) + repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context result = repo.get_commit_message("dummy diff", "dummy context") @@ -57,19 +48,15 @@ class TestRepo(unittest.TestCase): mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = 'a good "commit message"' - mock_send.return_value = ( - None, - mock_response - ) + mock_send.return_value = (None, mock_response) - repo = AiderRepo(InputOutput(), None) + repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context result = repo.get_commit_message("dummy diff", "dummy context") # Assert that the returned message is the expected one self.assertEqual(result, 'a good "commit message"') - def test_get_tracked_files(self): # Create a temporary directory tempdir = Path(tempfile.mkdtemp()) @@ -98,8 +85,7 @@ class TestRepo(unittest.TestCase): repo.git.commit("-m", "added") - - tracked_files = AiderRepo(InputOutput(), [tempdir]).get_tracked_files() + tracked_files = GitRepo(InputOutput(), [tempdir]).get_tracked_files() # On windows, paths will come back \like\this, so normalize them back to Paths tracked_files = [Path(fn) for fn in tracked_files]