From 9ceb766a67365ea2edb7800595c05ce01138a69d Mon Sep 17 00:00:00 2001 From: "Paul Gauthier (aider)" Date: Wed, 5 Mar 2025 18:39:17 -0800 Subject: [PATCH] feat: Add comprehensive tests for SSL verification in ModelInfoManager and main --- tests/basic/test_main.py | 10 +++ tests/basic/test_model_info_manager.py | 86 ++++++++++++++++++++++++++ tests/basic/test_ssl_verification.py | 61 ++++++++++++++++++ 3 files changed, 157 insertions(+) create mode 100644 tests/basic/test_model_info_manager.py create mode 100644 tests/basic/test_ssl_verification.py diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 6ea29f5dd..c01553903 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -683,6 +683,16 @@ class TestMain(TestCase): return_coder=True, ) self.assertTrue(coder.detect_urls) + + @patch("aider.models.ModelInfoManager.set_verify_ssl") + def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): + with GitTemporaryDirectory(): + main( + ["--no-verify-ssl", "--exit", "--yes"], + input=DummyInput(), + output=DummyOutput(), + ) + mock_set_verify_ssl.assert_called_once_with(False) def test_pytest_env_vars(self): # Verify that environment variables from pytest.ini are properly set diff --git a/tests/basic/test_model_info_manager.py b/tests/basic/test_model_info_manager.py new file mode 100644 index 000000000..4e90dfe71 --- /dev/null +++ b/tests/basic/test_model_info_manager.py @@ -0,0 +1,86 @@ +import os +import tempfile +from pathlib import Path +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from aider.models import ModelInfoManager + + +class TestModelInfoManager(TestCase): + def setUp(self): + self.original_env = os.environ.copy() + self.manager = ModelInfoManager() + # Create a temporary directory for cache + self.temp_dir = tempfile.TemporaryDirectory() + self.manager.cache_dir = Path(self.temp_dir.name) + self.manager.cache_file = self.manager.cache_dir / "model_prices_and_context_window.json" + self.manager.cache_dir.mkdir(exist_ok=True) + + def tearDown(self): + self.temp_dir.cleanup() + os.environ.clear() + os.environ.update(self.original_env) + + @patch("requests.get") + def test_update_cache_respects_verify_ssl(self, mock_get): + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"test_model": {"max_tokens": 4096}} + mock_get.return_value = mock_response + + # Test with default verify_ssl=True + self.manager._update_cache() + mock_get.assert_called_with( + self.manager.MODEL_INFO_URL, timeout=5, verify=True + ) + + # Test with verify_ssl=False + mock_get.reset_mock() + self.manager.set_verify_ssl(False) + self.manager._update_cache() + mock_get.assert_called_with( + self.manager.MODEL_INFO_URL, timeout=5, verify=False + ) + + def test_lazy_loading_cache(self): + # Create a cache file + self.manager.cache_file.write_text('{"test_model": {"max_tokens": 4096}}') + + # Verify cache is not loaded on initialization + self.assertFalse(self.manager._cache_loaded) + self.assertIsNone(self.manager.content) + + # Access content through get_model_from_cached_json_db + with patch.object(self.manager, "_update_cache") as mock_update: + result = self.manager.get_model_from_cached_json_db("test_model") + + # Verify cache was loaded + self.assertTrue(self.manager._cache_loaded) + self.assertIsNotNone(self.manager.content) + self.assertEqual(result, {"max_tokens": 4096}) + + # Verify _update_cache was not called since cache exists and is valid + mock_update.assert_not_called() + + @patch("requests.get") + def test_verify_ssl_setting_before_cache_loading(self, mock_get): + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"test_model": {"max_tokens": 4096}} + mock_get.return_value = mock_response + + # Set verify_ssl to False before any cache operations + self.manager.set_verify_ssl(False) + + # Force cache update by making it look expired + with patch("time.time", return_value=9999999999): + # This should trigger _update_cache + result = self.manager.get_model_from_cached_json_db("test_model") + + # Verify _update_cache was called with verify=False + mock_get.assert_called_with( + self.manager.MODEL_INFO_URL, timeout=5, verify=False + ) diff --git a/tests/basic/test_ssl_verification.py b/tests/basic/test_ssl_verification.py new file mode 100644 index 000000000..93a8e8fbf --- /dev/null +++ b/tests/basic/test_ssl_verification.py @@ -0,0 +1,61 @@ +import os +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from prompt_toolkit.input import DummyInput +from prompt_toolkit.output import DummyOutput + +from aider.main import main + + +class TestSSLVerification(TestCase): + def setUp(self): + self.original_env = os.environ.copy() + os.environ["OPENAI_API_KEY"] = "test-key" + os.environ["AIDER_CHECK_UPDATE"] = "false" + os.environ["AIDER_ANALYTICS"] = "false" + + def tearDown(self): + os.environ.clear() + os.environ.update(self.original_env) + + @patch("aider.models.model_info_manager.set_verify_ssl") + @patch("aider.llm.litellm._load_litellm") + @patch("httpx.Client") + @patch("httpx.AsyncClient") + def test_no_verify_ssl_flag_sets_model_info_manager( + self, mock_async_client, mock_client, mock_load_litellm, mock_set_verify_ssl + ): + # Mock the litellm._lazy_module to avoid AttributeError + mock_load_litellm.return_value = None + mock_module = MagicMock() + with patch("aider.llm.litellm._lazy_module", mock_module): + # Run main with --no-verify-ssl flag + main( + ["--no-verify-ssl", "--exit", "--yes"], + input=DummyInput(), + output=DummyOutput(), + ) + + # Verify model_info_manager.set_verify_ssl was called with False + mock_set_verify_ssl.assert_called_once_with(False) + + # Verify httpx clients were created with verify=False + mock_client.assert_called_once_with(verify=False) + mock_async_client.assert_called_once_with(verify=False) + + # Verify SSL_VERIFY environment variable was set to empty string + self.assertEqual(os.environ.get("SSL_VERIFY"), "") + + @patch("aider.models.model_info_manager.set_verify_ssl") + def test_default_ssl_verification(self, mock_set_verify_ssl): + # Run main without --no-verify-ssl flag + with patch("aider.main.InputOutput"): + with patch("aider.coders.Coder.create"): + main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + + # Verify model_info_manager.set_verify_ssl was not called + mock_set_verify_ssl.assert_not_called() + + # Verify SSL_VERIFY environment variable was not set + self.assertNotIn("SSL_VERIFY", os.environ)