diff --git a/tests/basic/test_onboarding.py b/tests/basic/test_onboarding.py index 5854a640d..bd2005edc 100644 --- a/tests/basic/test_onboarding.py +++ b/tests/basic/test_onboarding.py @@ -1,41 +1,54 @@ -import unittest -from unittest.mock import MagicMock, patch, mock_open -import os -import requests -import socketserver -import secrets -import hashlib -import base64 import argparse +import base64 +import hashlib +import os +import secrets +import socketserver +import unittest +from unittest.mock import MagicMock, mock_open, patch + +import requests + # Mock the Analytics class as it's used in some functions class DummyAnalytics: def event(self, *args, **kwargs): pass + # Mock the InputOutput class class DummyIO: - def tool_output(self, *args, **kwargs): pass - def tool_warning(self, *args, **kwargs): pass - def tool_error(self, *args, **kwargs): pass - def confirm_ask(self, *args, **kwargs): return False # Default to no confirmation - def offer_url(self, *args, **kwargs): pass + def tool_output(self, *args, **kwargs): + pass + + def tool_warning(self, *args, **kwargs): + pass + + def tool_error(self, *args, **kwargs): + pass + + def confirm_ask(self, *args, **kwargs): + return False # Default to no confirmation + + def offer_url(self, *args, **kwargs): + pass + # Import the functions to be tested from aider.onboarding import ( check_openrouter_tier, - try_to_select_default_model, - select_default_model, - offer_openrouter_oauth, + exchange_code_for_key, find_available_port, generate_pkce_codes, - exchange_code_for_key, + offer_openrouter_oauth, + select_default_model, start_openrouter_oauth_flow, + try_to_select_default_model, ) -class TestOnboarding(unittest.TestCase): - @patch('requests.get') +class TestOnboarding(unittest.TestCase): + @patch("requests.get") def test_check_openrouter_tier_free(self, mock_get): """Test check_openrouter_tier identifies free tier.""" mock_response = MagicMock() @@ -49,7 +62,7 @@ class TestOnboarding(unittest.TestCase): timeout=5, ) - @patch('requests.get') + @patch("requests.get") def test_check_openrouter_tier_paid(self, mock_get): """Test check_openrouter_tier identifies paid tier.""" mock_response = MagicMock() @@ -58,125 +71,129 @@ class TestOnboarding(unittest.TestCase): mock_get.return_value = mock_response self.assertFalse(check_openrouter_tier("fake_key")) - @patch('requests.get') + @patch("requests.get") def test_check_openrouter_tier_api_error(self, mock_get): """Test check_openrouter_tier defaults to free on API error.""" mock_get.side_effect = requests.exceptions.RequestException("API Error") self.assertTrue(check_openrouter_tier("fake_key")) - @patch('requests.get') + @patch("requests.get") def test_check_openrouter_tier_missing_key(self, mock_get): """Test check_openrouter_tier defaults to free if key is missing in response.""" mock_response = MagicMock() - mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier' + mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier' mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response self.assertTrue(check_openrouter_tier("fake_key")) - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {}, clear=True) def test_try_select_default_model_no_keys(self, mock_check_tier): """Test no model is selected when no keys are present.""" self.assertIsNone(try_to_select_default_model()) mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier', return_value=True) # Assume free tier + @patch("aider.onboarding.check_openrouter_tier", return_value=True) # Assume free tier @patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True) def test_try_select_default_model_openrouter_free(self, mock_check_tier): """Test OpenRouter free model selection.""" - self.assertEqual(try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free") + self.assertEqual( + try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free" + ) mock_check_tier.assert_called_once_with("or_key") - @patch('aider.onboarding.check_openrouter_tier', return_value=False) # Assume paid tier + @patch("aider.onboarding.check_openrouter_tier", return_value=False) # Assume paid tier @patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True) def test_try_select_default_model_openrouter_paid(self, mock_check_tier): """Test OpenRouter paid model selection.""" self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet") mock_check_tier.assert_called_once_with("or_key") - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key"}, clear=True) def test_try_select_default_model_anthropic(self, mock_check_tier): """Test Anthropic model selection.""" self.assertEqual(try_to_select_default_model(), "sonnet") mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds_key"}, clear=True) def test_try_select_default_model_deepseek(self, mock_check_tier): """Test Deepseek model selection.""" self.assertEqual(try_to_select_default_model(), "deepseek") mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"OPENAI_API_KEY": "oa_key"}, clear=True) def test_try_select_default_model_openai(self, mock_check_tier): """Test OpenAI model selection.""" self.assertEqual(try_to_select_default_model(), "gpt-4o") mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"GEMINI_API_KEY": "gm_key"}, clear=True) def test_try_select_default_model_gemini(self, mock_check_tier): """Test Gemini model selection.""" self.assertEqual(try_to_select_default_model(), "gemini/gemini-2.5-pro-exp-03-25") mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"VERTEXAI_PROJECT": "vx_proj"}, clear=True) def test_try_select_default_model_vertex(self, mock_check_tier): """Test Vertex AI model selection.""" self.assertEqual(try_to_select_default_model(), "vertex_ai/gemini-2.5-pro-exp-03-25") mock_check_tier.assert_not_called() - @patch('aider.onboarding.check_openrouter_tier', return_value=False) # Paid - @patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True) + @patch("aider.onboarding.check_openrouter_tier", return_value=False) # Paid + @patch.dict( + os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True + ) def test_try_select_default_model_priority_openrouter(self, mock_check_tier): """Test OpenRouter key takes priority.""" self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet") mock_check_tier.assert_called_once_with("or_key") - @patch('aider.onboarding.check_openrouter_tier') + @patch("aider.onboarding.check_openrouter_tier") @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key", "OPENAI_API_KEY": "oa_key"}, clear=True) def test_try_select_default_model_priority_anthropic(self, mock_check_tier): """Test Anthropic key takes priority over OpenAI.""" self.assertEqual(try_to_select_default_model(), "sonnet") mock_check_tier.assert_not_called() - @patch('socketserver.TCPServer') + @patch("socketserver.TCPServer") def test_find_available_port_success(self, mock_tcp_server): """Test finding an available port.""" # Simulate port 8484 being available - mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager + mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager port = find_available_port(start_port=8484, end_port=8484) self.assertEqual(port, 8484) mock_tcp_server.assert_called_once_with(("localhost", 8484), None) - @patch('socketserver.TCPServer') + @patch("socketserver.TCPServer") def test_find_available_port_in_use(self, mock_tcp_server): """Test finding the next available port if the first is in use.""" # Simulate port 8484 raising OSError, 8485 being available mock_tcp_server.side_effect = [OSError, MagicMock()] - mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager + mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager port = find_available_port(start_port=8484, end_port=8485) self.assertEqual(port, 8485) self.assertEqual(mock_tcp_server.call_count, 2) mock_tcp_server.assert_any_call(("localhost", 8484), None) mock_tcp_server.assert_any_call(("localhost", 8485), None) - @patch('socketserver.TCPServer', side_effect=OSError) + @patch("socketserver.TCPServer", side_effect=OSError) def test_find_available_port_none_available(self, mock_tcp_server): """Test returning None if no ports are available in the range.""" port = find_available_port(start_port=8484, end_port=8485) self.assertIsNone(port) - self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485 + self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485 def test_generate_pkce_codes(self): """Test PKCE code generation.""" verifier, challenge = generate_pkce_codes() self.assertIsInstance(verifier, str) self.assertIsInstance(challenge, str) - self.assertGreater(len(verifier), 40) # Check reasonable length + self.assertGreater(len(verifier), 40) # Check reasonable length self.assertGreater(len(challenge), 40) # Verify the challenge is the SHA256 hash of the verifier, base64 encoded hasher = hashlib.sha256() @@ -184,7 +201,7 @@ class TestOnboarding(unittest.TestCase): expected_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8") self.assertEqual(challenge, expected_challenge) - @patch('requests.post') + @patch("requests.post") def test_exchange_code_for_key_success(self, mock_post): """Test successful code exchange for API key.""" mock_response = MagicMock() @@ -207,16 +224,16 @@ class TestOnboarding(unittest.TestCase): timeout=30, ) - @patch('requests.post') + @patch("requests.post") def test_exchange_code_for_key_missing_key(self, mock_post): """Test code exchange when 'key' is missing in response.""" mock_response = MagicMock() - mock_response.json.return_value = {"other_data": "value"} # Missing 'key' + mock_response.json.return_value = {"other_data": "value"} # Missing 'key' mock_response.raise_for_status.return_value = None mock_response.text = '{"other_data": "value"}' mock_post.return_value = mock_response io_mock = DummyIO() - io_mock.tool_error = MagicMock() # Track error output + io_mock.tool_error = MagicMock() # Track error output api_key = exchange_code_for_key("auth_code", "verifier", io_mock) @@ -224,7 +241,7 @@ class TestOnboarding(unittest.TestCase): io_mock.tool_error.assert_any_call("Error: 'key' not found in OpenRouter response.") io_mock.tool_error.assert_any_call('Response: {"other_data": "value"}') - @patch('requests.post') + @patch("requests.post") def test_exchange_code_for_key_http_error(self, mock_post): """Test code exchange with HTTP error.""" mock_response = MagicMock() @@ -244,7 +261,7 @@ class TestOnboarding(unittest.TestCase): ) io_mock.tool_error.assert_any_call('Response: {"error": "invalid_code"}') - @patch('requests.post') + @patch("requests.post") def test_exchange_code_for_key_timeout(self, mock_post): """Test code exchange with timeout.""" mock_post.side_effect = requests.exceptions.Timeout("Timeout") @@ -258,7 +275,7 @@ class TestOnboarding(unittest.TestCase): "Error: Request to OpenRouter timed out during code exchange." ) - @patch('requests.post') + @patch("requests.post") def test_exchange_code_for_key_request_exception(self, mock_post): """Test code exchange with general request exception.""" req_exception = requests.exceptions.RequestException("Network Error") @@ -275,8 +292,8 @@ class TestOnboarding(unittest.TestCase): # --- Tests for select_default_model --- - @patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o") - @patch('aider.onboarding.offer_openrouter_oauth') + @patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o") + @patch("aider.onboarding.offer_openrouter_oauth") def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select): """Test select_default_model returns args.model if provided.""" args = argparse.Namespace(model="specific-model") @@ -287,15 +304,15 @@ class TestOnboarding(unittest.TestCase): mock_try_select.assert_not_called() mock_offer_oauth.assert_not_called() - @patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o") - @patch('aider.onboarding.offer_openrouter_oauth') + @patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o") + @patch("aider.onboarding.offer_openrouter_oauth") def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select): """Test select_default_model returns model found by try_to_select.""" - args = argparse.Namespace(model=None) # No model specified + args = argparse.Namespace(model=None) # No model specified io_mock = DummyIO() - io_mock.tool_warning = MagicMock() # Track warnings + io_mock.tool_warning = MagicMock() # Track warnings analytics_mock = DummyAnalytics() - analytics_mock.event = MagicMock() # Track events + analytics_mock.event = MagicMock() # Track events selected_model = select_default_model(args, io_mock, analytics_mock) @@ -307,8 +324,12 @@ class TestOnboarding(unittest.TestCase): analytics_mock.event.assert_called_once_with("auto_model_selection", model="gpt-4o") mock_offer_oauth.assert_not_called() - @patch('aider.onboarding.try_to_select_default_model', side_effect=[None, None]) # Fails first, fails after oauth attempt - @patch('aider.onboarding.offer_openrouter_oauth', return_value=False) # OAuth offered but fails/declined + @patch( + "aider.onboarding.try_to_select_default_model", side_effect=[None, None] + ) # Fails first, fails after oauth attempt + @patch( + "aider.onboarding.offer_openrouter_oauth", return_value=False + ) # OAuth offered but fails/declined def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select): """Test select_default_model offers OAuth when no keys, but OAuth fails.""" args = argparse.Namespace(model=None) @@ -320,15 +341,20 @@ class TestOnboarding(unittest.TestCase): selected_model = select_default_model(args, io_mock, analytics_mock) self.assertIsNone(selected_model) - self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt + self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock) io_mock.tool_warning.assert_called_once_with( "No LLM model was specified and no API keys were provided." ) - io_mock.offer_url.assert_called_once() # Should offer docs URL + io_mock.offer_url.assert_called_once() # Should offer docs URL - @patch('aider.onboarding.try_to_select_default_model', side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"]) # Fails first, succeeds after oauth - @patch('aider.onboarding.offer_openrouter_oauth', return_value=True) # OAuth offered and succeeds + @patch( + "aider.onboarding.try_to_select_default_model", + side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"], + ) # Fails first, succeeds after oauth + @patch( + "aider.onboarding.offer_openrouter_oauth", return_value=True + ) # OAuth offered and succeeds def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock_try_select): """Test select_default_model offers OAuth, which succeeds.""" args = argparse.Namespace(model=None) @@ -339,7 +365,7 @@ class TestOnboarding(unittest.TestCase): selected_model = select_default_model(args, io_mock, analytics_mock) self.assertEqual(selected_model, "openrouter/google/gemini-2.5-pro-exp-03-25:free") - self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth + self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock) # First warning about no keys, second about using the model found after OAuth self.assertEqual(io_mock.tool_warning.call_count, 2) @@ -353,12 +379,12 @@ class TestOnboarding(unittest.TestCase): # Instead, let's verify the final state and model returned. # --- Tests for offer_openrouter_oauth --- - @patch('aider.onboarding.start_openrouter_oauth_flow', return_value="new_or_key") - @patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially + @patch("aider.onboarding.start_openrouter_oauth_flow", return_value="new_or_key") + @patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth): """Test offer_openrouter_oauth when user confirms and OAuth succeeds.""" io_mock = DummyIO() - io_mock.confirm_ask = MagicMock(return_value=True) # User says yes + io_mock.confirm_ask = MagicMock(return_value=True) # User says yes analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() @@ -373,12 +399,12 @@ class TestOnboarding(unittest.TestCase): # Clean up env var del os.environ["OPENROUTER_API_KEY"] - @patch('aider.onboarding.start_openrouter_oauth_flow', return_value=None) # OAuth fails + @patch("aider.onboarding.start_openrouter_oauth_flow", return_value=None) # OAuth fails @patch.dict(os.environ, {}, clear=True) def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth): """Test offer_openrouter_oauth when user confirms but OAuth fails.""" io_mock = DummyIO() - io_mock.confirm_ask = MagicMock(return_value=True) # User says yes + io_mock.confirm_ask = MagicMock(return_value=True) # User says yes io_mock.tool_error = MagicMock() analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() @@ -389,16 +415,17 @@ class TestOnboarding(unittest.TestCase): io_mock.confirm_ask.assert_called_once() mock_start_oauth.assert_called_once_with(io_mock, analytics_mock) self.assertNotIn("OPENROUTER_API_KEY", os.environ) - io_mock.tool_error.assert_called_once_with("OpenRouter authentication did not complete successfully.") + io_mock.tool_error.assert_called_once_with( + "OpenRouter authentication did not complete successfully." + ) analytics_mock.event.assert_any_call("oauth_flow_initiated", provider="openrouter") analytics_mock.event.assert_any_call("oauth_flow_failure") - - @patch('aider.onboarding.start_openrouter_oauth_flow') + @patch("aider.onboarding.start_openrouter_oauth_flow") def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth): """Test offer_openrouter_oauth when user declines.""" io_mock = DummyIO() - io_mock.confirm_ask = MagicMock(return_value=False) # User says no + io_mock.confirm_ask = MagicMock(return_value=False) # User says no analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() @@ -407,23 +434,32 @@ class TestOnboarding(unittest.TestCase): self.assertFalse(result) io_mock.confirm_ask.assert_called_once() mock_start_oauth.assert_not_called() - analytics_mock.event.assert_not_called() # No OAuth events if declined - + analytics_mock.event.assert_not_called() # No OAuth events if declined # --- More complex test for start_openrouter_oauth_flow (simplified) --- # This test focuses on the successful path, mocking heavily - @patch('aider.onboarding.check_pip_install_extra', return_value=True) # Assume requests is installed - @patch('aider.onboarding.find_available_port', return_value=8484) - @patch('threading.Thread') - @patch('threading.Event') - @patch('webbrowser.open') - @patch('aider.onboarding.exchange_code_for_key', return_value="oauth_api_key") - @patch('os.makedirs') - @patch('builtins.open', new_callable=mock_open) - @patch.dict(os.environ, {}, clear=True) # Start with clean env + @patch( + "aider.onboarding.check_pip_install_extra", return_value=True + ) # Assume requests is installed + @patch("aider.onboarding.find_available_port", return_value=8484) + @patch("threading.Thread") + @patch("threading.Event") + @patch("webbrowser.open") + @patch("aider.onboarding.exchange_code_for_key", return_value="oauth_api_key") + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open) + @patch.dict(os.environ, {}, clear=True) # Start with clean env def test_start_openrouter_oauth_flow_success_path( - self, mock_env, mock_open_file, mock_makedirs, mock_exchange, mock_webbrowser, - mock_event_cls, mock_thread_cls, mock_find_port, mock_check_pip + self, + mock_env, + mock_open_file, + mock_makedirs, + mock_exchange, + mock_webbrowser, + mock_event_cls, + mock_thread_cls, + mock_find_port, + mock_check_pip, ): """Test the successful path of start_openrouter_oauth_flow.""" io_mock = DummyIO() @@ -432,10 +468,10 @@ class TestOnboarding(unittest.TestCase): # Mock threading Events: pretend server starts and callback happens quickly mock_server_started_event = MagicMock() - mock_server_started_event.wait.return_value = True # Server started + mock_server_started_event.wait.return_value = True # Server started mock_shutdown_event = MagicMock() - mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown - mock_shutdown_event.wait.return_value = True # Callback received before timeout + mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown + mock_shutdown_event.wait.return_value = True # Callback received before timeout # Need to simulate the callback setting the auth_code *within* the flow # This is tricky because it happens in a separate thread in reality. @@ -459,17 +495,19 @@ class TestOnboarding(unittest.TestCase): self.assertEqual(api_key, "oauth_api_key") mock_check_pip.assert_called_once() mock_find_port.assert_called_once() - mock_thread_cls.assert_called_once() # Server thread created - mock_server_thread.start.assert_called_once() # Server thread started - mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start - mock_webbrowser.assert_called_once() # Browser should be opened - mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout - mock_exchange.assert_called_once() # Code exchange attempted + mock_thread_cls.assert_called_once() # Server thread created + mock_server_thread.start.assert_called_once() # Server thread started + mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start + mock_webbrowser.assert_called_once() # Browser should be opened + mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout + mock_exchange.assert_called_once() # Code exchange attempted # Check that the key was set in the environment self.assertEqual(os.environ.get("OPENROUTER_API_KEY"), "oauth_api_key") # Check that saving the key was attempted mock_makedirs.assert_called_once() - mock_open_file.assert_called_once_with(os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8") + mock_open_file.assert_called_once_with( + os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8" + ) mock_open_file().write.assert_called_once_with('OPENROUTER_API_KEY="oauth_api_key"\n') # Check analytics events analytics_mock.event.assert_any_call("oauth_flow_code_received", provider="openrouter") @@ -478,5 +516,5 @@ class TestOnboarding(unittest.TestCase): del os.environ["OPENROUTER_API_KEY"] -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()