diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index c50301de9..0836c75d7 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -14,7 +14,7 @@ from prompt_toolkit.output import DummyOutput from aider.coders import Coder from aider.dump import dump # noqa: F401 from aider.io import InputOutput -from aider.main import check_gitignore, main, setup_git +from aider.main import check_gitignore, load_dotenv_files, main, setup_git from aider.utils import GitTemporaryDirectory, IgnorantTemporaryDirectory, make_repo @@ -1275,6 +1275,65 @@ class TestMain(TestCase): for call in mock_io_instance.tool_warning.call_args_list: self.assertNotIn("Cost estimates may be inaccurate", call[0][0]) + def test_load_dotenv_files_override(self): + with GitTemporaryDirectory() as git_dir: + git_dir = Path(git_dir) + + # Create fake home and .aider directory + fake_home = git_dir / "fake_home" + fake_home.mkdir() + aider_dir = fake_home / ".aider" + aider_dir.mkdir() + + # Create oauth keys file + oauth_keys_file = aider_dir / "oauth-keys.env" + oauth_keys_file.write_text("OAUTH_VAR=oauth_val\nSHARED_VAR=oauth_shared\n") + + # Create git root .env file + git_root_env = git_dir / ".env" + git_root_env.write_text("GIT_VAR=git_val\nSHARED_VAR=git_shared\n") + + # Create CWD .env file in a subdir + cwd_subdir = git_dir / "subdir" + cwd_subdir.mkdir() + cwd_env = cwd_subdir / ".env" + cwd_env.write_text("CWD_VAR=cwd_val\nSHARED_VAR=cwd_shared\n") + + # Change to subdir + original_cwd = os.getcwd() + os.chdir(cwd_subdir) + + # Clear relevant env vars before test + for var in ["OAUTH_VAR", "SHARED_VAR", "GIT_VAR", "CWD_VAR"]: + if var in os.environ: + del os.environ[var] + + with patch("pathlib.Path.home", return_value=fake_home): + loaded_files = load_dotenv_files(str(git_dir), None) + + # Assert files were loaded in expected order (oauth first) + self.assertIn(str(oauth_keys_file.resolve()), loaded_files) + self.assertIn(str(git_root_env.resolve()), loaded_files) + self.assertIn(str(cwd_env.resolve()), loaded_files) + self.assertLess( + loaded_files.index(str(oauth_keys_file.resolve())), + loaded_files.index(str(git_root_env.resolve())), + ) + self.assertLess( + loaded_files.index(str(git_root_env.resolve())), + loaded_files.index(str(cwd_env.resolve())), + ) + + # Assert environment variables reflect the override order + self.assertEqual(os.environ.get("OAUTH_VAR"), "oauth_val") + self.assertEqual(os.environ.get("GIT_VAR"), "git_val") + self.assertEqual(os.environ.get("CWD_VAR"), "cwd_val") + # SHARED_VAR should be overridden by the last loaded file (cwd .env) + self.assertEqual(os.environ.get("SHARED_VAR"), "cwd_shared") + + # Restore CWD + os.chdir(original_cwd) + @patch("aider.main.InputOutput") def test_cache_without_stream_no_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value