style: Apply linter fixes to onboarding tests

This commit is contained in:
Paul Gauthier (aider) 2025-03-31 08:41:11 +13:00
parent b54629addb
commit 24c074eeaa

View file

@ -1,41 +1,54 @@
import unittest
from unittest.mock import MagicMock, patch, mock_open
import os
import requests
import socketserver
import secrets
import hashlib
import base64
import argparse
import base64
import hashlib
import os
import secrets
import socketserver
import unittest
from unittest.mock import MagicMock, mock_open, patch
import requests
# Mock the Analytics class as it's used in some functions
class DummyAnalytics:
def event(self, *args, **kwargs):
pass
# Mock the InputOutput class
class DummyIO:
def tool_output(self, *args, **kwargs): pass
def tool_warning(self, *args, **kwargs): pass
def tool_error(self, *args, **kwargs): pass
def confirm_ask(self, *args, **kwargs): return False # Default to no confirmation
def offer_url(self, *args, **kwargs): pass
def tool_output(self, *args, **kwargs):
pass
def tool_warning(self, *args, **kwargs):
pass
def tool_error(self, *args, **kwargs):
pass
def confirm_ask(self, *args, **kwargs):
return False # Default to no confirmation
def offer_url(self, *args, **kwargs):
pass
# Import the functions to be tested
from aider.onboarding import (
check_openrouter_tier,
try_to_select_default_model,
select_default_model,
offer_openrouter_oauth,
exchange_code_for_key,
find_available_port,
generate_pkce_codes,
exchange_code_for_key,
offer_openrouter_oauth,
select_default_model,
start_openrouter_oauth_flow,
try_to_select_default_model,
)
class TestOnboarding(unittest.TestCase):
@patch('requests.get')
class TestOnboarding(unittest.TestCase):
@patch("requests.get")
def test_check_openrouter_tier_free(self, mock_get):
"""Test check_openrouter_tier identifies free tier."""
mock_response = MagicMock()
@ -49,7 +62,7 @@ class TestOnboarding(unittest.TestCase):
timeout=5,
)
@patch('requests.get')
@patch("requests.get")
def test_check_openrouter_tier_paid(self, mock_get):
"""Test check_openrouter_tier identifies paid tier."""
mock_response = MagicMock()
@ -58,125 +71,129 @@ class TestOnboarding(unittest.TestCase):
mock_get.return_value = mock_response
self.assertFalse(check_openrouter_tier("fake_key"))
@patch('requests.get')
@patch("requests.get")
def test_check_openrouter_tier_api_error(self, mock_get):
"""Test check_openrouter_tier defaults to free on API error."""
mock_get.side_effect = requests.exceptions.RequestException("API Error")
self.assertTrue(check_openrouter_tier("fake_key"))
@patch('requests.get')
@patch("requests.get")
def test_check_openrouter_tier_missing_key(self, mock_get):
"""Test check_openrouter_tier defaults to free if key is missing in response."""
mock_response = MagicMock()
mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier'
mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier'
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
self.assertTrue(check_openrouter_tier("fake_key"))
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {}, clear=True)
def test_try_select_default_model_no_keys(self, mock_check_tier):
"""Test no model is selected when no keys are present."""
self.assertIsNone(try_to_select_default_model())
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier', return_value=True) # Assume free tier
@patch("aider.onboarding.check_openrouter_tier", return_value=True) # Assume free tier
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
def test_try_select_default_model_openrouter_free(self, mock_check_tier):
"""Test OpenRouter free model selection."""
self.assertEqual(try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free")
self.assertEqual(
try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free"
)
mock_check_tier.assert_called_once_with("or_key")
@patch('aider.onboarding.check_openrouter_tier', return_value=False) # Assume paid tier
@patch("aider.onboarding.check_openrouter_tier", return_value=False) # Assume paid tier
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
def test_try_select_default_model_openrouter_paid(self, mock_check_tier):
"""Test OpenRouter paid model selection."""
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
mock_check_tier.assert_called_once_with("or_key")
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key"}, clear=True)
def test_try_select_default_model_anthropic(self, mock_check_tier):
"""Test Anthropic model selection."""
self.assertEqual(try_to_select_default_model(), "sonnet")
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds_key"}, clear=True)
def test_try_select_default_model_deepseek(self, mock_check_tier):
"""Test Deepseek model selection."""
self.assertEqual(try_to_select_default_model(), "deepseek")
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"OPENAI_API_KEY": "oa_key"}, clear=True)
def test_try_select_default_model_openai(self, mock_check_tier):
"""Test OpenAI model selection."""
self.assertEqual(try_to_select_default_model(), "gpt-4o")
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"GEMINI_API_KEY": "gm_key"}, clear=True)
def test_try_select_default_model_gemini(self, mock_check_tier):
"""Test Gemini model selection."""
self.assertEqual(try_to_select_default_model(), "gemini/gemini-2.5-pro-exp-03-25")
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"VERTEXAI_PROJECT": "vx_proj"}, clear=True)
def test_try_select_default_model_vertex(self, mock_check_tier):
"""Test Vertex AI model selection."""
self.assertEqual(try_to_select_default_model(), "vertex_ai/gemini-2.5-pro-exp-03-25")
mock_check_tier.assert_not_called()
@patch('aider.onboarding.check_openrouter_tier', return_value=False) # Paid
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True)
@patch("aider.onboarding.check_openrouter_tier", return_value=False) # Paid
@patch.dict(
os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True
)
def test_try_select_default_model_priority_openrouter(self, mock_check_tier):
"""Test OpenRouter key takes priority."""
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
mock_check_tier.assert_called_once_with("or_key")
@patch('aider.onboarding.check_openrouter_tier')
@patch("aider.onboarding.check_openrouter_tier")
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key", "OPENAI_API_KEY": "oa_key"}, clear=True)
def test_try_select_default_model_priority_anthropic(self, mock_check_tier):
"""Test Anthropic key takes priority over OpenAI."""
self.assertEqual(try_to_select_default_model(), "sonnet")
mock_check_tier.assert_not_called()
@patch('socketserver.TCPServer')
@patch("socketserver.TCPServer")
def test_find_available_port_success(self, mock_tcp_server):
"""Test finding an available port."""
# Simulate port 8484 being available
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
port = find_available_port(start_port=8484, end_port=8484)
self.assertEqual(port, 8484)
mock_tcp_server.assert_called_once_with(("localhost", 8484), None)
@patch('socketserver.TCPServer')
@patch("socketserver.TCPServer")
def test_find_available_port_in_use(self, mock_tcp_server):
"""Test finding the next available port if the first is in use."""
# Simulate port 8484 raising OSError, 8485 being available
mock_tcp_server.side_effect = [OSError, MagicMock()]
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
port = find_available_port(start_port=8484, end_port=8485)
self.assertEqual(port, 8485)
self.assertEqual(mock_tcp_server.call_count, 2)
mock_tcp_server.assert_any_call(("localhost", 8484), None)
mock_tcp_server.assert_any_call(("localhost", 8485), None)
@patch('socketserver.TCPServer', side_effect=OSError)
@patch("socketserver.TCPServer", side_effect=OSError)
def test_find_available_port_none_available(self, mock_tcp_server):
"""Test returning None if no ports are available in the range."""
port = find_available_port(start_port=8484, end_port=8485)
self.assertIsNone(port)
self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485
self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485
def test_generate_pkce_codes(self):
"""Test PKCE code generation."""
verifier, challenge = generate_pkce_codes()
self.assertIsInstance(verifier, str)
self.assertIsInstance(challenge, str)
self.assertGreater(len(verifier), 40) # Check reasonable length
self.assertGreater(len(verifier), 40) # Check reasonable length
self.assertGreater(len(challenge), 40)
# Verify the challenge is the SHA256 hash of the verifier, base64 encoded
hasher = hashlib.sha256()
@ -184,7 +201,7 @@ class TestOnboarding(unittest.TestCase):
expected_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8")
self.assertEqual(challenge, expected_challenge)
@patch('requests.post')
@patch("requests.post")
def test_exchange_code_for_key_success(self, mock_post):
"""Test successful code exchange for API key."""
mock_response = MagicMock()
@ -207,16 +224,16 @@ class TestOnboarding(unittest.TestCase):
timeout=30,
)
@patch('requests.post')
@patch("requests.post")
def test_exchange_code_for_key_missing_key(self, mock_post):
"""Test code exchange when 'key' is missing in response."""
mock_response = MagicMock()
mock_response.json.return_value = {"other_data": "value"} # Missing 'key'
mock_response.json.return_value = {"other_data": "value"} # Missing 'key'
mock_response.raise_for_status.return_value = None
mock_response.text = '{"other_data": "value"}'
mock_post.return_value = mock_response
io_mock = DummyIO()
io_mock.tool_error = MagicMock() # Track error output
io_mock.tool_error = MagicMock() # Track error output
api_key = exchange_code_for_key("auth_code", "verifier", io_mock)
@ -224,7 +241,7 @@ class TestOnboarding(unittest.TestCase):
io_mock.tool_error.assert_any_call("Error: 'key' not found in OpenRouter response.")
io_mock.tool_error.assert_any_call('Response: {"other_data": "value"}')
@patch('requests.post')
@patch("requests.post")
def test_exchange_code_for_key_http_error(self, mock_post):
"""Test code exchange with HTTP error."""
mock_response = MagicMock()
@ -244,7 +261,7 @@ class TestOnboarding(unittest.TestCase):
)
io_mock.tool_error.assert_any_call('Response: {"error": "invalid_code"}')
@patch('requests.post')
@patch("requests.post")
def test_exchange_code_for_key_timeout(self, mock_post):
"""Test code exchange with timeout."""
mock_post.side_effect = requests.exceptions.Timeout("Timeout")
@ -258,7 +275,7 @@ class TestOnboarding(unittest.TestCase):
"Error: Request to OpenRouter timed out during code exchange."
)
@patch('requests.post')
@patch("requests.post")
def test_exchange_code_for_key_request_exception(self, mock_post):
"""Test code exchange with general request exception."""
req_exception = requests.exceptions.RequestException("Network Error")
@ -275,8 +292,8 @@ class TestOnboarding(unittest.TestCase):
# --- Tests for select_default_model ---
@patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o")
@patch('aider.onboarding.offer_openrouter_oauth')
@patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o")
@patch("aider.onboarding.offer_openrouter_oauth")
def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select):
"""Test select_default_model returns args.model if provided."""
args = argparse.Namespace(model="specific-model")
@ -287,15 +304,15 @@ class TestOnboarding(unittest.TestCase):
mock_try_select.assert_not_called()
mock_offer_oauth.assert_not_called()
@patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o")
@patch('aider.onboarding.offer_openrouter_oauth')
@patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o")
@patch("aider.onboarding.offer_openrouter_oauth")
def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select):
"""Test select_default_model returns model found by try_to_select."""
args = argparse.Namespace(model=None) # No model specified
args = argparse.Namespace(model=None) # No model specified
io_mock = DummyIO()
io_mock.tool_warning = MagicMock() # Track warnings
io_mock.tool_warning = MagicMock() # Track warnings
analytics_mock = DummyAnalytics()
analytics_mock.event = MagicMock() # Track events
analytics_mock.event = MagicMock() # Track events
selected_model = select_default_model(args, io_mock, analytics_mock)
@ -307,8 +324,12 @@ class TestOnboarding(unittest.TestCase):
analytics_mock.event.assert_called_once_with("auto_model_selection", model="gpt-4o")
mock_offer_oauth.assert_not_called()
@patch('aider.onboarding.try_to_select_default_model', side_effect=[None, None]) # Fails first, fails after oauth attempt
@patch('aider.onboarding.offer_openrouter_oauth', return_value=False) # OAuth offered but fails/declined
@patch(
"aider.onboarding.try_to_select_default_model", side_effect=[None, None]
) # Fails first, fails after oauth attempt
@patch(
"aider.onboarding.offer_openrouter_oauth", return_value=False
) # OAuth offered but fails/declined
def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select):
"""Test select_default_model offers OAuth when no keys, but OAuth fails."""
args = argparse.Namespace(model=None)
@ -320,15 +341,20 @@ class TestOnboarding(unittest.TestCase):
selected_model = select_default_model(args, io_mock, analytics_mock)
self.assertIsNone(selected_model)
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
io_mock.tool_warning.assert_called_once_with(
"No LLM model was specified and no API keys were provided."
)
io_mock.offer_url.assert_called_once() # Should offer docs URL
io_mock.offer_url.assert_called_once() # Should offer docs URL
@patch('aider.onboarding.try_to_select_default_model', side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"]) # Fails first, succeeds after oauth
@patch('aider.onboarding.offer_openrouter_oauth', return_value=True) # OAuth offered and succeeds
@patch(
"aider.onboarding.try_to_select_default_model",
side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"],
) # Fails first, succeeds after oauth
@patch(
"aider.onboarding.offer_openrouter_oauth", return_value=True
) # OAuth offered and succeeds
def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock_try_select):
"""Test select_default_model offers OAuth, which succeeds."""
args = argparse.Namespace(model=None)
@ -339,7 +365,7 @@ class TestOnboarding(unittest.TestCase):
selected_model = select_default_model(args, io_mock, analytics_mock)
self.assertEqual(selected_model, "openrouter/google/gemini-2.5-pro-exp-03-25:free")
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
# First warning about no keys, second about using the model found after OAuth
self.assertEqual(io_mock.tool_warning.call_count, 2)
@ -353,12 +379,12 @@ class TestOnboarding(unittest.TestCase):
# Instead, let's verify the final state and model returned.
# --- Tests for offer_openrouter_oauth ---
@patch('aider.onboarding.start_openrouter_oauth_flow', return_value="new_or_key")
@patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially
@patch("aider.onboarding.start_openrouter_oauth_flow", return_value="new_or_key")
@patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially
def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth):
"""Test offer_openrouter_oauth when user confirms and OAuth succeeds."""
io_mock = DummyIO()
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
analytics_mock = DummyAnalytics()
analytics_mock.event = MagicMock()
@ -373,12 +399,12 @@ class TestOnboarding(unittest.TestCase):
# Clean up env var
del os.environ["OPENROUTER_API_KEY"]
@patch('aider.onboarding.start_openrouter_oauth_flow', return_value=None) # OAuth fails
@patch("aider.onboarding.start_openrouter_oauth_flow", return_value=None) # OAuth fails
@patch.dict(os.environ, {}, clear=True)
def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth):
"""Test offer_openrouter_oauth when user confirms but OAuth fails."""
io_mock = DummyIO()
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
io_mock.tool_error = MagicMock()
analytics_mock = DummyAnalytics()
analytics_mock.event = MagicMock()
@ -389,16 +415,17 @@ class TestOnboarding(unittest.TestCase):
io_mock.confirm_ask.assert_called_once()
mock_start_oauth.assert_called_once_with(io_mock, analytics_mock)
self.assertNotIn("OPENROUTER_API_KEY", os.environ)
io_mock.tool_error.assert_called_once_with("OpenRouter authentication did not complete successfully.")
io_mock.tool_error.assert_called_once_with(
"OpenRouter authentication did not complete successfully."
)
analytics_mock.event.assert_any_call("oauth_flow_initiated", provider="openrouter")
analytics_mock.event.assert_any_call("oauth_flow_failure")
@patch('aider.onboarding.start_openrouter_oauth_flow')
@patch("aider.onboarding.start_openrouter_oauth_flow")
def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth):
"""Test offer_openrouter_oauth when user declines."""
io_mock = DummyIO()
io_mock.confirm_ask = MagicMock(return_value=False) # User says no
io_mock.confirm_ask = MagicMock(return_value=False) # User says no
analytics_mock = DummyAnalytics()
analytics_mock.event = MagicMock()
@ -407,23 +434,32 @@ class TestOnboarding(unittest.TestCase):
self.assertFalse(result)
io_mock.confirm_ask.assert_called_once()
mock_start_oauth.assert_not_called()
analytics_mock.event.assert_not_called() # No OAuth events if declined
analytics_mock.event.assert_not_called() # No OAuth events if declined
# --- More complex test for start_openrouter_oauth_flow (simplified) ---
# This test focuses on the successful path, mocking heavily
@patch('aider.onboarding.check_pip_install_extra', return_value=True) # Assume requests is installed
@patch('aider.onboarding.find_available_port', return_value=8484)
@patch('threading.Thread')
@patch('threading.Event')
@patch('webbrowser.open')
@patch('aider.onboarding.exchange_code_for_key', return_value="oauth_api_key")
@patch('os.makedirs')
@patch('builtins.open', new_callable=mock_open)
@patch.dict(os.environ, {}, clear=True) # Start with clean env
@patch(
"aider.onboarding.check_pip_install_extra", return_value=True
) # Assume requests is installed
@patch("aider.onboarding.find_available_port", return_value=8484)
@patch("threading.Thread")
@patch("threading.Event")
@patch("webbrowser.open")
@patch("aider.onboarding.exchange_code_for_key", return_value="oauth_api_key")
@patch("os.makedirs")
@patch("builtins.open", new_callable=mock_open)
@patch.dict(os.environ, {}, clear=True) # Start with clean env
def test_start_openrouter_oauth_flow_success_path(
self, mock_env, mock_open_file, mock_makedirs, mock_exchange, mock_webbrowser,
mock_event_cls, mock_thread_cls, mock_find_port, mock_check_pip
self,
mock_env,
mock_open_file,
mock_makedirs,
mock_exchange,
mock_webbrowser,
mock_event_cls,
mock_thread_cls,
mock_find_port,
mock_check_pip,
):
"""Test the successful path of start_openrouter_oauth_flow."""
io_mock = DummyIO()
@ -432,10 +468,10 @@ class TestOnboarding(unittest.TestCase):
# Mock threading Events: pretend server starts and callback happens quickly
mock_server_started_event = MagicMock()
mock_server_started_event.wait.return_value = True # Server started
mock_server_started_event.wait.return_value = True # Server started
mock_shutdown_event = MagicMock()
mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown
mock_shutdown_event.wait.return_value = True # Callback received before timeout
mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown
mock_shutdown_event.wait.return_value = True # Callback received before timeout
# Need to simulate the callback setting the auth_code *within* the flow
# This is tricky because it happens in a separate thread in reality.
@ -459,17 +495,19 @@ class TestOnboarding(unittest.TestCase):
self.assertEqual(api_key, "oauth_api_key")
mock_check_pip.assert_called_once()
mock_find_port.assert_called_once()
mock_thread_cls.assert_called_once() # Server thread created
mock_server_thread.start.assert_called_once() # Server thread started
mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start
mock_webbrowser.assert_called_once() # Browser should be opened
mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout
mock_exchange.assert_called_once() # Code exchange attempted
mock_thread_cls.assert_called_once() # Server thread created
mock_server_thread.start.assert_called_once() # Server thread started
mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start
mock_webbrowser.assert_called_once() # Browser should be opened
mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout
mock_exchange.assert_called_once() # Code exchange attempted
# Check that the key was set in the environment
self.assertEqual(os.environ.get("OPENROUTER_API_KEY"), "oauth_api_key")
# Check that saving the key was attempted
mock_makedirs.assert_called_once()
mock_open_file.assert_called_once_with(os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8")
mock_open_file.assert_called_once_with(
os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8"
)
mock_open_file().write.assert_called_once_with('OPENROUTER_API_KEY="oauth_api_key"\n')
# Check analytics events
analytics_mock.event.assert_any_call("oauth_flow_code_received", provider="openrouter")
@ -478,5 +516,5 @@ class TestOnboarding(unittest.TestCase):
del os.environ["OPENROUTER_API_KEY"]
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()