diff --git a/tests/test_coder.py b/tests/test_coder.py index 445aab345..80a92f902 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -12,6 +12,7 @@ from aider import models from aider.coders import Coder from aider.dump import dump # noqa: F401 from aider.io import InputOutput +from tests.utils import GitTemporaryDirectory class TestCoder(unittest.TestCase): @@ -95,47 +96,44 @@ class TestCoder(unittest.TestCase): self.assertEqual(coder.abs_fnames, expected_files) def test_check_for_ambiguous_filename_mentions_of_longer_paths(self): - # Mock the IO object - mock_io = MagicMock() + with GitTemporaryDirectory(): + io = InputOutput(pretty=False, yes=True) + coder = Coder.create(models.GPT4, None, io) - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) + fname = Path("file1.txt") + fname.touch() - fname = Path("file1.txt") - fname.touch() + other_fname = Path("other") / "file1.txt" + other_fname.parent.mkdir(parents=True, exist_ok=True) + other_fname.touch() - other_fname = Path("other") / "file1.txt" - other_fname.parent.mkdir(parents=True, exist_ok=True) - other_fname.touch() + mock = MagicMock() + mock.return_value = set([str(fname), str(other_fname)]) + coder.get_tracked_files = mock - mock = MagicMock() - mock.return_value = set([str(fname), str(other_fname)]) - coder.get_tracked_files = mock + # Call the check_for_file_mentions method + coder.check_for_file_mentions(f"Please check {fname}!") - # Call the check_for_file_mentions method - coder.check_for_file_mentions(f"Please check {fname}!") - - self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) + self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) def test_check_for_subdir_mention(self): - # Mock the IO object - mock_io = MagicMock() + with GitTemporaryDirectory(): + io = InputOutput(pretty=False, yes=True) + coder = Coder.create(models.GPT4, None, io) - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) + fname = Path("other") / "file1.txt" + fname.parent.mkdir(parents=True, exist_ok=True) + fname.touch() - fname = Path("other") / "file1.txt" - fname.parent.mkdir(parents=True, exist_ok=True) - fname.touch() + mock = MagicMock() + mock.return_value = set([str(fname)]) + coder.get_tracked_files = mock - mock = MagicMock() - mock.return_value = set([str(fname)]) - coder.get_tracked_files = mock + dump(fname) + # Call the check_for_file_mentions method + coder.check_for_file_mentions(f"Please check `{fname}`") - # Call the check_for_file_mentions method - coder.check_for_file_mentions(f"Please check `{fname}`") - - self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) + self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) def test_get_commit_message(self): # Mock the IO object diff --git a/tests/utils.py b/tests/utils.py index 1629c4425..c1b589eea 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,10 @@ +import os import tempfile +import git + +from aider.dump import dump # noqa: F401 + class IgnorantTemporaryDirectory: def __init__(self): @@ -13,3 +18,29 @@ class IgnorantTemporaryDirectory: self.temp_dir.__exit__(exc_type, exc_val, exc_tb) except OSError: pass # Ignore errors (Windows) + + +class ChdirTemporaryDirectory(IgnorantTemporaryDirectory): + def __init__(self): + self.cwd = os.getcwd() + super().__init__() + + def __enter__(self): + res = super().__enter__() + os.chdir(self.temp_dir.name) + return res + + def __exit__(self, exc_type, exc_val, exc_tb): + os.chdir(self.cwd) + super().__exit__(exc_type, exc_val, exc_tb) + + +class GitTemporaryDirectory(ChdirTemporaryDirectory): + def __enter__(self): + res = super().__enter__() + + repo = git.Repo.init() + repo.config_writer().set_value("user", "name", "Test User").release() + repo.config_writer().set_value("user", "email", "testuser@example.com").release() + + return res