mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 16:54:59 +00:00
style: Apply linter fixes to onboarding tests
This commit is contained in:
parent
b54629addb
commit
24c074eeaa
1 changed files with 134 additions and 96 deletions
|
@ -1,41 +1,54 @@
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch, mock_open
|
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
import socketserver
|
|
||||||
import secrets
|
|
||||||
import hashlib
|
|
||||||
import base64
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import socketserver
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, mock_open, patch
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
# Mock the Analytics class as it's used in some functions
|
# Mock the Analytics class as it's used in some functions
|
||||||
class DummyAnalytics:
|
class DummyAnalytics:
|
||||||
def event(self, *args, **kwargs):
|
def event(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Mock the InputOutput class
|
# Mock the InputOutput class
|
||||||
class DummyIO:
|
class DummyIO:
|
||||||
def tool_output(self, *args, **kwargs): pass
|
def tool_output(self, *args, **kwargs):
|
||||||
def tool_warning(self, *args, **kwargs): pass
|
pass
|
||||||
def tool_error(self, *args, **kwargs): pass
|
|
||||||
def confirm_ask(self, *args, **kwargs): return False # Default to no confirmation
|
def tool_warning(self, *args, **kwargs):
|
||||||
def offer_url(self, *args, **kwargs): pass
|
pass
|
||||||
|
|
||||||
|
def tool_error(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def confirm_ask(self, *args, **kwargs):
|
||||||
|
return False # Default to no confirmation
|
||||||
|
|
||||||
|
def offer_url(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Import the functions to be tested
|
# Import the functions to be tested
|
||||||
from aider.onboarding import (
|
from aider.onboarding import (
|
||||||
check_openrouter_tier,
|
check_openrouter_tier,
|
||||||
try_to_select_default_model,
|
exchange_code_for_key,
|
||||||
select_default_model,
|
|
||||||
offer_openrouter_oauth,
|
|
||||||
find_available_port,
|
find_available_port,
|
||||||
generate_pkce_codes,
|
generate_pkce_codes,
|
||||||
exchange_code_for_key,
|
offer_openrouter_oauth,
|
||||||
|
select_default_model,
|
||||||
start_openrouter_oauth_flow,
|
start_openrouter_oauth_flow,
|
||||||
|
try_to_select_default_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
class TestOnboarding(unittest.TestCase):
|
|
||||||
|
|
||||||
@patch('requests.get')
|
class TestOnboarding(unittest.TestCase):
|
||||||
|
@patch("requests.get")
|
||||||
def test_check_openrouter_tier_free(self, mock_get):
|
def test_check_openrouter_tier_free(self, mock_get):
|
||||||
"""Test check_openrouter_tier identifies free tier."""
|
"""Test check_openrouter_tier identifies free tier."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
@ -49,7 +62,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
timeout=5,
|
timeout=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_check_openrouter_tier_paid(self, mock_get):
|
def test_check_openrouter_tier_paid(self, mock_get):
|
||||||
"""Test check_openrouter_tier identifies paid tier."""
|
"""Test check_openrouter_tier identifies paid tier."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
@ -58,125 +71,129 @@ class TestOnboarding(unittest.TestCase):
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
self.assertFalse(check_openrouter_tier("fake_key"))
|
self.assertFalse(check_openrouter_tier("fake_key"))
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_check_openrouter_tier_api_error(self, mock_get):
|
def test_check_openrouter_tier_api_error(self, mock_get):
|
||||||
"""Test check_openrouter_tier defaults to free on API error."""
|
"""Test check_openrouter_tier defaults to free on API error."""
|
||||||
mock_get.side_effect = requests.exceptions.RequestException("API Error")
|
mock_get.side_effect = requests.exceptions.RequestException("API Error")
|
||||||
self.assertTrue(check_openrouter_tier("fake_key"))
|
self.assertTrue(check_openrouter_tier("fake_key"))
|
||||||
|
|
||||||
@patch('requests.get')
|
@patch("requests.get")
|
||||||
def test_check_openrouter_tier_missing_key(self, mock_get):
|
def test_check_openrouter_tier_missing_key(self, mock_get):
|
||||||
"""Test check_openrouter_tier defaults to free if key is missing in response."""
|
"""Test check_openrouter_tier defaults to free if key is missing in response."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier'
|
mock_response.json.return_value = {"data": {}} # Missing 'is_free_tier'
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
self.assertTrue(check_openrouter_tier("fake_key"))
|
self.assertTrue(check_openrouter_tier("fake_key"))
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {}, clear=True)
|
@patch.dict(os.environ, {}, clear=True)
|
||||||
def test_try_select_default_model_no_keys(self, mock_check_tier):
|
def test_try_select_default_model_no_keys(self, mock_check_tier):
|
||||||
"""Test no model is selected when no keys are present."""
|
"""Test no model is selected when no keys are present."""
|
||||||
self.assertIsNone(try_to_select_default_model())
|
self.assertIsNone(try_to_select_default_model())
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier', return_value=True) # Assume free tier
|
@patch("aider.onboarding.check_openrouter_tier", return_value=True) # Assume free tier
|
||||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
|
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
|
||||||
def test_try_select_default_model_openrouter_free(self, mock_check_tier):
|
def test_try_select_default_model_openrouter_free(self, mock_check_tier):
|
||||||
"""Test OpenRouter free model selection."""
|
"""Test OpenRouter free model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free")
|
self.assertEqual(
|
||||||
|
try_to_select_default_model(), "openrouter/google/gemini-2.5-pro-exp-03-25:free"
|
||||||
|
)
|
||||||
mock_check_tier.assert_called_once_with("or_key")
|
mock_check_tier.assert_called_once_with("or_key")
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier', return_value=False) # Assume paid tier
|
@patch("aider.onboarding.check_openrouter_tier", return_value=False) # Assume paid tier
|
||||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
|
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key"}, clear=True)
|
||||||
def test_try_select_default_model_openrouter_paid(self, mock_check_tier):
|
def test_try_select_default_model_openrouter_paid(self, mock_check_tier):
|
||||||
"""Test OpenRouter paid model selection."""
|
"""Test OpenRouter paid model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
|
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
|
||||||
mock_check_tier.assert_called_once_with("or_key")
|
mock_check_tier.assert_called_once_with("or_key")
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key"}, clear=True)
|
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key"}, clear=True)
|
||||||
def test_try_select_default_model_anthropic(self, mock_check_tier):
|
def test_try_select_default_model_anthropic(self, mock_check_tier):
|
||||||
"""Test Anthropic model selection."""
|
"""Test Anthropic model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "sonnet")
|
self.assertEqual(try_to_select_default_model(), "sonnet")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds_key"}, clear=True)
|
@patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds_key"}, clear=True)
|
||||||
def test_try_select_default_model_deepseek(self, mock_check_tier):
|
def test_try_select_default_model_deepseek(self, mock_check_tier):
|
||||||
"""Test Deepseek model selection."""
|
"""Test Deepseek model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "deepseek")
|
self.assertEqual(try_to_select_default_model(), "deepseek")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"OPENAI_API_KEY": "oa_key"}, clear=True)
|
@patch.dict(os.environ, {"OPENAI_API_KEY": "oa_key"}, clear=True)
|
||||||
def test_try_select_default_model_openai(self, mock_check_tier):
|
def test_try_select_default_model_openai(self, mock_check_tier):
|
||||||
"""Test OpenAI model selection."""
|
"""Test OpenAI model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "gpt-4o")
|
self.assertEqual(try_to_select_default_model(), "gpt-4o")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "gm_key"}, clear=True)
|
@patch.dict(os.environ, {"GEMINI_API_KEY": "gm_key"}, clear=True)
|
||||||
def test_try_select_default_model_gemini(self, mock_check_tier):
|
def test_try_select_default_model_gemini(self, mock_check_tier):
|
||||||
"""Test Gemini model selection."""
|
"""Test Gemini model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "gemini/gemini-2.5-pro-exp-03-25")
|
self.assertEqual(try_to_select_default_model(), "gemini/gemini-2.5-pro-exp-03-25")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"VERTEXAI_PROJECT": "vx_proj"}, clear=True)
|
@patch.dict(os.environ, {"VERTEXAI_PROJECT": "vx_proj"}, clear=True)
|
||||||
def test_try_select_default_model_vertex(self, mock_check_tier):
|
def test_try_select_default_model_vertex(self, mock_check_tier):
|
||||||
"""Test Vertex AI model selection."""
|
"""Test Vertex AI model selection."""
|
||||||
self.assertEqual(try_to_select_default_model(), "vertex_ai/gemini-2.5-pro-exp-03-25")
|
self.assertEqual(try_to_select_default_model(), "vertex_ai/gemini-2.5-pro-exp-03-25")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier', return_value=False) # Paid
|
@patch("aider.onboarding.check_openrouter_tier", return_value=False) # Paid
|
||||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True)
|
@patch.dict(
|
||||||
|
os.environ, {"OPENROUTER_API_KEY": "or_key", "OPENAI_API_KEY": "oa_key"}, clear=True
|
||||||
|
)
|
||||||
def test_try_select_default_model_priority_openrouter(self, mock_check_tier):
|
def test_try_select_default_model_priority_openrouter(self, mock_check_tier):
|
||||||
"""Test OpenRouter key takes priority."""
|
"""Test OpenRouter key takes priority."""
|
||||||
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
|
self.assertEqual(try_to_select_default_model(), "openrouter/anthropic/claude-3.7-sonnet")
|
||||||
mock_check_tier.assert_called_once_with("or_key")
|
mock_check_tier.assert_called_once_with("or_key")
|
||||||
|
|
||||||
@patch('aider.onboarding.check_openrouter_tier')
|
@patch("aider.onboarding.check_openrouter_tier")
|
||||||
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key", "OPENAI_API_KEY": "oa_key"}, clear=True)
|
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "an_key", "OPENAI_API_KEY": "oa_key"}, clear=True)
|
||||||
def test_try_select_default_model_priority_anthropic(self, mock_check_tier):
|
def test_try_select_default_model_priority_anthropic(self, mock_check_tier):
|
||||||
"""Test Anthropic key takes priority over OpenAI."""
|
"""Test Anthropic key takes priority over OpenAI."""
|
||||||
self.assertEqual(try_to_select_default_model(), "sonnet")
|
self.assertEqual(try_to_select_default_model(), "sonnet")
|
||||||
mock_check_tier.assert_not_called()
|
mock_check_tier.assert_not_called()
|
||||||
|
|
||||||
@patch('socketserver.TCPServer')
|
@patch("socketserver.TCPServer")
|
||||||
def test_find_available_port_success(self, mock_tcp_server):
|
def test_find_available_port_success(self, mock_tcp_server):
|
||||||
"""Test finding an available port."""
|
"""Test finding an available port."""
|
||||||
# Simulate port 8484 being available
|
# Simulate port 8484 being available
|
||||||
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
|
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
|
||||||
port = find_available_port(start_port=8484, end_port=8484)
|
port = find_available_port(start_port=8484, end_port=8484)
|
||||||
self.assertEqual(port, 8484)
|
self.assertEqual(port, 8484)
|
||||||
mock_tcp_server.assert_called_once_with(("localhost", 8484), None)
|
mock_tcp_server.assert_called_once_with(("localhost", 8484), None)
|
||||||
|
|
||||||
@patch('socketserver.TCPServer')
|
@patch("socketserver.TCPServer")
|
||||||
def test_find_available_port_in_use(self, mock_tcp_server):
|
def test_find_available_port_in_use(self, mock_tcp_server):
|
||||||
"""Test finding the next available port if the first is in use."""
|
"""Test finding the next available port if the first is in use."""
|
||||||
# Simulate port 8484 raising OSError, 8485 being available
|
# Simulate port 8484 raising OSError, 8485 being available
|
||||||
mock_tcp_server.side_effect = [OSError, MagicMock()]
|
mock_tcp_server.side_effect = [OSError, MagicMock()]
|
||||||
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
|
mock_tcp_server.return_value.__enter__.return_value = None # Allow context manager
|
||||||
port = find_available_port(start_port=8484, end_port=8485)
|
port = find_available_port(start_port=8484, end_port=8485)
|
||||||
self.assertEqual(port, 8485)
|
self.assertEqual(port, 8485)
|
||||||
self.assertEqual(mock_tcp_server.call_count, 2)
|
self.assertEqual(mock_tcp_server.call_count, 2)
|
||||||
mock_tcp_server.assert_any_call(("localhost", 8484), None)
|
mock_tcp_server.assert_any_call(("localhost", 8484), None)
|
||||||
mock_tcp_server.assert_any_call(("localhost", 8485), None)
|
mock_tcp_server.assert_any_call(("localhost", 8485), None)
|
||||||
|
|
||||||
@patch('socketserver.TCPServer', side_effect=OSError)
|
@patch("socketserver.TCPServer", side_effect=OSError)
|
||||||
def test_find_available_port_none_available(self, mock_tcp_server):
|
def test_find_available_port_none_available(self, mock_tcp_server):
|
||||||
"""Test returning None if no ports are available in the range."""
|
"""Test returning None if no ports are available in the range."""
|
||||||
port = find_available_port(start_port=8484, end_port=8485)
|
port = find_available_port(start_port=8484, end_port=8485)
|
||||||
self.assertIsNone(port)
|
self.assertIsNone(port)
|
||||||
self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485
|
self.assertEqual(mock_tcp_server.call_count, 2) # Tried 8484 and 8485
|
||||||
|
|
||||||
def test_generate_pkce_codes(self):
|
def test_generate_pkce_codes(self):
|
||||||
"""Test PKCE code generation."""
|
"""Test PKCE code generation."""
|
||||||
verifier, challenge = generate_pkce_codes()
|
verifier, challenge = generate_pkce_codes()
|
||||||
self.assertIsInstance(verifier, str)
|
self.assertIsInstance(verifier, str)
|
||||||
self.assertIsInstance(challenge, str)
|
self.assertIsInstance(challenge, str)
|
||||||
self.assertGreater(len(verifier), 40) # Check reasonable length
|
self.assertGreater(len(verifier), 40) # Check reasonable length
|
||||||
self.assertGreater(len(challenge), 40)
|
self.assertGreater(len(challenge), 40)
|
||||||
# Verify the challenge is the SHA256 hash of the verifier, base64 encoded
|
# Verify the challenge is the SHA256 hash of the verifier, base64 encoded
|
||||||
hasher = hashlib.sha256()
|
hasher = hashlib.sha256()
|
||||||
|
@ -184,7 +201,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
expected_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8")
|
expected_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8")
|
||||||
self.assertEqual(challenge, expected_challenge)
|
self.assertEqual(challenge, expected_challenge)
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch("requests.post")
|
||||||
def test_exchange_code_for_key_success(self, mock_post):
|
def test_exchange_code_for_key_success(self, mock_post):
|
||||||
"""Test successful code exchange for API key."""
|
"""Test successful code exchange for API key."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
@ -207,16 +224,16 @@ class TestOnboarding(unittest.TestCase):
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch("requests.post")
|
||||||
def test_exchange_code_for_key_missing_key(self, mock_post):
|
def test_exchange_code_for_key_missing_key(self, mock_post):
|
||||||
"""Test code exchange when 'key' is missing in response."""
|
"""Test code exchange when 'key' is missing in response."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {"other_data": "value"} # Missing 'key'
|
mock_response.json.return_value = {"other_data": "value"} # Missing 'key'
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_response.text = '{"other_data": "value"}'
|
mock_response.text = '{"other_data": "value"}'
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
io_mock.tool_error = MagicMock() # Track error output
|
io_mock.tool_error = MagicMock() # Track error output
|
||||||
|
|
||||||
api_key = exchange_code_for_key("auth_code", "verifier", io_mock)
|
api_key = exchange_code_for_key("auth_code", "verifier", io_mock)
|
||||||
|
|
||||||
|
@ -224,7 +241,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
io_mock.tool_error.assert_any_call("Error: 'key' not found in OpenRouter response.")
|
io_mock.tool_error.assert_any_call("Error: 'key' not found in OpenRouter response.")
|
||||||
io_mock.tool_error.assert_any_call('Response: {"other_data": "value"}')
|
io_mock.tool_error.assert_any_call('Response: {"other_data": "value"}')
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch("requests.post")
|
||||||
def test_exchange_code_for_key_http_error(self, mock_post):
|
def test_exchange_code_for_key_http_error(self, mock_post):
|
||||||
"""Test code exchange with HTTP error."""
|
"""Test code exchange with HTTP error."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
@ -244,7 +261,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
)
|
)
|
||||||
io_mock.tool_error.assert_any_call('Response: {"error": "invalid_code"}')
|
io_mock.tool_error.assert_any_call('Response: {"error": "invalid_code"}')
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch("requests.post")
|
||||||
def test_exchange_code_for_key_timeout(self, mock_post):
|
def test_exchange_code_for_key_timeout(self, mock_post):
|
||||||
"""Test code exchange with timeout."""
|
"""Test code exchange with timeout."""
|
||||||
mock_post.side_effect = requests.exceptions.Timeout("Timeout")
|
mock_post.side_effect = requests.exceptions.Timeout("Timeout")
|
||||||
|
@ -258,7 +275,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
"Error: Request to OpenRouter timed out during code exchange."
|
"Error: Request to OpenRouter timed out during code exchange."
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch("requests.post")
|
||||||
def test_exchange_code_for_key_request_exception(self, mock_post):
|
def test_exchange_code_for_key_request_exception(self, mock_post):
|
||||||
"""Test code exchange with general request exception."""
|
"""Test code exchange with general request exception."""
|
||||||
req_exception = requests.exceptions.RequestException("Network Error")
|
req_exception = requests.exceptions.RequestException("Network Error")
|
||||||
|
@ -275,8 +292,8 @@ class TestOnboarding(unittest.TestCase):
|
||||||
|
|
||||||
# --- Tests for select_default_model ---
|
# --- Tests for select_default_model ---
|
||||||
|
|
||||||
@patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o")
|
@patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o")
|
||||||
@patch('aider.onboarding.offer_openrouter_oauth')
|
@patch("aider.onboarding.offer_openrouter_oauth")
|
||||||
def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select):
|
def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select):
|
||||||
"""Test select_default_model returns args.model if provided."""
|
"""Test select_default_model returns args.model if provided."""
|
||||||
args = argparse.Namespace(model="specific-model")
|
args = argparse.Namespace(model="specific-model")
|
||||||
|
@ -287,15 +304,15 @@ class TestOnboarding(unittest.TestCase):
|
||||||
mock_try_select.assert_not_called()
|
mock_try_select.assert_not_called()
|
||||||
mock_offer_oauth.assert_not_called()
|
mock_offer_oauth.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.try_to_select_default_model', return_value="gpt-4o")
|
@patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o")
|
||||||
@patch('aider.onboarding.offer_openrouter_oauth')
|
@patch("aider.onboarding.offer_openrouter_oauth")
|
||||||
def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select):
|
def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select):
|
||||||
"""Test select_default_model returns model found by try_to_select."""
|
"""Test select_default_model returns model found by try_to_select."""
|
||||||
args = argparse.Namespace(model=None) # No model specified
|
args = argparse.Namespace(model=None) # No model specified
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
io_mock.tool_warning = MagicMock() # Track warnings
|
io_mock.tool_warning = MagicMock() # Track warnings
|
||||||
analytics_mock = DummyAnalytics()
|
analytics_mock = DummyAnalytics()
|
||||||
analytics_mock.event = MagicMock() # Track events
|
analytics_mock.event = MagicMock() # Track events
|
||||||
|
|
||||||
selected_model = select_default_model(args, io_mock, analytics_mock)
|
selected_model = select_default_model(args, io_mock, analytics_mock)
|
||||||
|
|
||||||
|
@ -307,8 +324,12 @@ class TestOnboarding(unittest.TestCase):
|
||||||
analytics_mock.event.assert_called_once_with("auto_model_selection", model="gpt-4o")
|
analytics_mock.event.assert_called_once_with("auto_model_selection", model="gpt-4o")
|
||||||
mock_offer_oauth.assert_not_called()
|
mock_offer_oauth.assert_not_called()
|
||||||
|
|
||||||
@patch('aider.onboarding.try_to_select_default_model', side_effect=[None, None]) # Fails first, fails after oauth attempt
|
@patch(
|
||||||
@patch('aider.onboarding.offer_openrouter_oauth', return_value=False) # OAuth offered but fails/declined
|
"aider.onboarding.try_to_select_default_model", side_effect=[None, None]
|
||||||
|
) # Fails first, fails after oauth attempt
|
||||||
|
@patch(
|
||||||
|
"aider.onboarding.offer_openrouter_oauth", return_value=False
|
||||||
|
) # OAuth offered but fails/declined
|
||||||
def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select):
|
def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select):
|
||||||
"""Test select_default_model offers OAuth when no keys, but OAuth fails."""
|
"""Test select_default_model offers OAuth when no keys, but OAuth fails."""
|
||||||
args = argparse.Namespace(model=None)
|
args = argparse.Namespace(model=None)
|
||||||
|
@ -320,15 +341,20 @@ class TestOnboarding(unittest.TestCase):
|
||||||
selected_model = select_default_model(args, io_mock, analytics_mock)
|
selected_model = select_default_model(args, io_mock, analytics_mock)
|
||||||
|
|
||||||
self.assertIsNone(selected_model)
|
self.assertIsNone(selected_model)
|
||||||
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt
|
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt
|
||||||
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
|
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
|
||||||
io_mock.tool_warning.assert_called_once_with(
|
io_mock.tool_warning.assert_called_once_with(
|
||||||
"No LLM model was specified and no API keys were provided."
|
"No LLM model was specified and no API keys were provided."
|
||||||
)
|
)
|
||||||
io_mock.offer_url.assert_called_once() # Should offer docs URL
|
io_mock.offer_url.assert_called_once() # Should offer docs URL
|
||||||
|
|
||||||
@patch('aider.onboarding.try_to_select_default_model', side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"]) # Fails first, succeeds after oauth
|
@patch(
|
||||||
@patch('aider.onboarding.offer_openrouter_oauth', return_value=True) # OAuth offered and succeeds
|
"aider.onboarding.try_to_select_default_model",
|
||||||
|
side_effect=[None, "openrouter/google/gemini-2.5-pro-exp-03-25:free"],
|
||||||
|
) # Fails first, succeeds after oauth
|
||||||
|
@patch(
|
||||||
|
"aider.onboarding.offer_openrouter_oauth", return_value=True
|
||||||
|
) # OAuth offered and succeeds
|
||||||
def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock_try_select):
|
def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock_try_select):
|
||||||
"""Test select_default_model offers OAuth, which succeeds."""
|
"""Test select_default_model offers OAuth, which succeeds."""
|
||||||
args = argparse.Namespace(model=None)
|
args = argparse.Namespace(model=None)
|
||||||
|
@ -339,7 +365,7 @@ class TestOnboarding(unittest.TestCase):
|
||||||
selected_model = select_default_model(args, io_mock, analytics_mock)
|
selected_model = select_default_model(args, io_mock, analytics_mock)
|
||||||
|
|
||||||
self.assertEqual(selected_model, "openrouter/google/gemini-2.5-pro-exp-03-25:free")
|
self.assertEqual(selected_model, "openrouter/google/gemini-2.5-pro-exp-03-25:free")
|
||||||
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth
|
self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth
|
||||||
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
|
mock_offer_oauth.assert_called_once_with(io_mock, analytics_mock)
|
||||||
# First warning about no keys, second about using the model found after OAuth
|
# First warning about no keys, second about using the model found after OAuth
|
||||||
self.assertEqual(io_mock.tool_warning.call_count, 2)
|
self.assertEqual(io_mock.tool_warning.call_count, 2)
|
||||||
|
@ -353,12 +379,12 @@ class TestOnboarding(unittest.TestCase):
|
||||||
# Instead, let's verify the final state and model returned.
|
# Instead, let's verify the final state and model returned.
|
||||||
|
|
||||||
# --- Tests for offer_openrouter_oauth ---
|
# --- Tests for offer_openrouter_oauth ---
|
||||||
@patch('aider.onboarding.start_openrouter_oauth_flow', return_value="new_or_key")
|
@patch("aider.onboarding.start_openrouter_oauth_flow", return_value="new_or_key")
|
||||||
@patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially
|
@patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially
|
||||||
def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth):
|
def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth):
|
||||||
"""Test offer_openrouter_oauth when user confirms and OAuth succeeds."""
|
"""Test offer_openrouter_oauth when user confirms and OAuth succeeds."""
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
|
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
|
||||||
analytics_mock = DummyAnalytics()
|
analytics_mock = DummyAnalytics()
|
||||||
analytics_mock.event = MagicMock()
|
analytics_mock.event = MagicMock()
|
||||||
|
|
||||||
|
@ -373,12 +399,12 @@ class TestOnboarding(unittest.TestCase):
|
||||||
# Clean up env var
|
# Clean up env var
|
||||||
del os.environ["OPENROUTER_API_KEY"]
|
del os.environ["OPENROUTER_API_KEY"]
|
||||||
|
|
||||||
@patch('aider.onboarding.start_openrouter_oauth_flow', return_value=None) # OAuth fails
|
@patch("aider.onboarding.start_openrouter_oauth_flow", return_value=None) # OAuth fails
|
||||||
@patch.dict(os.environ, {}, clear=True)
|
@patch.dict(os.environ, {}, clear=True)
|
||||||
def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth):
|
def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth):
|
||||||
"""Test offer_openrouter_oauth when user confirms but OAuth fails."""
|
"""Test offer_openrouter_oauth when user confirms but OAuth fails."""
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
|
io_mock.confirm_ask = MagicMock(return_value=True) # User says yes
|
||||||
io_mock.tool_error = MagicMock()
|
io_mock.tool_error = MagicMock()
|
||||||
analytics_mock = DummyAnalytics()
|
analytics_mock = DummyAnalytics()
|
||||||
analytics_mock.event = MagicMock()
|
analytics_mock.event = MagicMock()
|
||||||
|
@ -389,16 +415,17 @@ class TestOnboarding(unittest.TestCase):
|
||||||
io_mock.confirm_ask.assert_called_once()
|
io_mock.confirm_ask.assert_called_once()
|
||||||
mock_start_oauth.assert_called_once_with(io_mock, analytics_mock)
|
mock_start_oauth.assert_called_once_with(io_mock, analytics_mock)
|
||||||
self.assertNotIn("OPENROUTER_API_KEY", os.environ)
|
self.assertNotIn("OPENROUTER_API_KEY", os.environ)
|
||||||
io_mock.tool_error.assert_called_once_with("OpenRouter authentication did not complete successfully.")
|
io_mock.tool_error.assert_called_once_with(
|
||||||
|
"OpenRouter authentication did not complete successfully."
|
||||||
|
)
|
||||||
analytics_mock.event.assert_any_call("oauth_flow_initiated", provider="openrouter")
|
analytics_mock.event.assert_any_call("oauth_flow_initiated", provider="openrouter")
|
||||||
analytics_mock.event.assert_any_call("oauth_flow_failure")
|
analytics_mock.event.assert_any_call("oauth_flow_failure")
|
||||||
|
|
||||||
|
@patch("aider.onboarding.start_openrouter_oauth_flow")
|
||||||
@patch('aider.onboarding.start_openrouter_oauth_flow')
|
|
||||||
def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth):
|
def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth):
|
||||||
"""Test offer_openrouter_oauth when user declines."""
|
"""Test offer_openrouter_oauth when user declines."""
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
io_mock.confirm_ask = MagicMock(return_value=False) # User says no
|
io_mock.confirm_ask = MagicMock(return_value=False) # User says no
|
||||||
analytics_mock = DummyAnalytics()
|
analytics_mock = DummyAnalytics()
|
||||||
analytics_mock.event = MagicMock()
|
analytics_mock.event = MagicMock()
|
||||||
|
|
||||||
|
@ -407,23 +434,32 @@ class TestOnboarding(unittest.TestCase):
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
io_mock.confirm_ask.assert_called_once()
|
io_mock.confirm_ask.assert_called_once()
|
||||||
mock_start_oauth.assert_not_called()
|
mock_start_oauth.assert_not_called()
|
||||||
analytics_mock.event.assert_not_called() # No OAuth events if declined
|
analytics_mock.event.assert_not_called() # No OAuth events if declined
|
||||||
|
|
||||||
|
|
||||||
# --- More complex test for start_openrouter_oauth_flow (simplified) ---
|
# --- More complex test for start_openrouter_oauth_flow (simplified) ---
|
||||||
# This test focuses on the successful path, mocking heavily
|
# This test focuses on the successful path, mocking heavily
|
||||||
@patch('aider.onboarding.check_pip_install_extra', return_value=True) # Assume requests is installed
|
@patch(
|
||||||
@patch('aider.onboarding.find_available_port', return_value=8484)
|
"aider.onboarding.check_pip_install_extra", return_value=True
|
||||||
@patch('threading.Thread')
|
) # Assume requests is installed
|
||||||
@patch('threading.Event')
|
@patch("aider.onboarding.find_available_port", return_value=8484)
|
||||||
@patch('webbrowser.open')
|
@patch("threading.Thread")
|
||||||
@patch('aider.onboarding.exchange_code_for_key', return_value="oauth_api_key")
|
@patch("threading.Event")
|
||||||
@patch('os.makedirs')
|
@patch("webbrowser.open")
|
||||||
@patch('builtins.open', new_callable=mock_open)
|
@patch("aider.onboarding.exchange_code_for_key", return_value="oauth_api_key")
|
||||||
@patch.dict(os.environ, {}, clear=True) # Start with clean env
|
@patch("os.makedirs")
|
||||||
|
@patch("builtins.open", new_callable=mock_open)
|
||||||
|
@patch.dict(os.environ, {}, clear=True) # Start with clean env
|
||||||
def test_start_openrouter_oauth_flow_success_path(
|
def test_start_openrouter_oauth_flow_success_path(
|
||||||
self, mock_env, mock_open_file, mock_makedirs, mock_exchange, mock_webbrowser,
|
self,
|
||||||
mock_event_cls, mock_thread_cls, mock_find_port, mock_check_pip
|
mock_env,
|
||||||
|
mock_open_file,
|
||||||
|
mock_makedirs,
|
||||||
|
mock_exchange,
|
||||||
|
mock_webbrowser,
|
||||||
|
mock_event_cls,
|
||||||
|
mock_thread_cls,
|
||||||
|
mock_find_port,
|
||||||
|
mock_check_pip,
|
||||||
):
|
):
|
||||||
"""Test the successful path of start_openrouter_oauth_flow."""
|
"""Test the successful path of start_openrouter_oauth_flow."""
|
||||||
io_mock = DummyIO()
|
io_mock = DummyIO()
|
||||||
|
@ -432,10 +468,10 @@ class TestOnboarding(unittest.TestCase):
|
||||||
|
|
||||||
# Mock threading Events: pretend server starts and callback happens quickly
|
# Mock threading Events: pretend server starts and callback happens quickly
|
||||||
mock_server_started_event = MagicMock()
|
mock_server_started_event = MagicMock()
|
||||||
mock_server_started_event.wait.return_value = True # Server started
|
mock_server_started_event.wait.return_value = True # Server started
|
||||||
mock_shutdown_event = MagicMock()
|
mock_shutdown_event = MagicMock()
|
||||||
mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown
|
mock_shutdown_event.is_set.side_effect = [False, True] # Loop once, then shutdown
|
||||||
mock_shutdown_event.wait.return_value = True # Callback received before timeout
|
mock_shutdown_event.wait.return_value = True # Callback received before timeout
|
||||||
|
|
||||||
# Need to simulate the callback setting the auth_code *within* the flow
|
# Need to simulate the callback setting the auth_code *within* the flow
|
||||||
# This is tricky because it happens in a separate thread in reality.
|
# This is tricky because it happens in a separate thread in reality.
|
||||||
|
@ -459,17 +495,19 @@ class TestOnboarding(unittest.TestCase):
|
||||||
self.assertEqual(api_key, "oauth_api_key")
|
self.assertEqual(api_key, "oauth_api_key")
|
||||||
mock_check_pip.assert_called_once()
|
mock_check_pip.assert_called_once()
|
||||||
mock_find_port.assert_called_once()
|
mock_find_port.assert_called_once()
|
||||||
mock_thread_cls.assert_called_once() # Server thread created
|
mock_thread_cls.assert_called_once() # Server thread created
|
||||||
mock_server_thread.start.assert_called_once() # Server thread started
|
mock_server_thread.start.assert_called_once() # Server thread started
|
||||||
mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start
|
mock_server_started_event.wait.assert_called_once_with(timeout=5) # Wait for server start
|
||||||
mock_webbrowser.assert_called_once() # Browser should be opened
|
mock_webbrowser.assert_called_once() # Browser should be opened
|
||||||
mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout
|
mock_shutdown_event.wait.assert_called_once_with(timeout=300) # Wait for callback/timeout
|
||||||
mock_exchange.assert_called_once() # Code exchange attempted
|
mock_exchange.assert_called_once() # Code exchange attempted
|
||||||
# Check that the key was set in the environment
|
# Check that the key was set in the environment
|
||||||
self.assertEqual(os.environ.get("OPENROUTER_API_KEY"), "oauth_api_key")
|
self.assertEqual(os.environ.get("OPENROUTER_API_KEY"), "oauth_api_key")
|
||||||
# Check that saving the key was attempted
|
# Check that saving the key was attempted
|
||||||
mock_makedirs.assert_called_once()
|
mock_makedirs.assert_called_once()
|
||||||
mock_open_file.assert_called_once_with(os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8")
|
mock_open_file.assert_called_once_with(
|
||||||
|
os.path.expanduser("~/.aider/oauth-keys.env"), "a", encoding="utf-8"
|
||||||
|
)
|
||||||
mock_open_file().write.assert_called_once_with('OPENROUTER_API_KEY="oauth_api_key"\n')
|
mock_open_file().write.assert_called_once_with('OPENROUTER_API_KEY="oauth_api_key"\n')
|
||||||
# Check analytics events
|
# Check analytics events
|
||||||
analytics_mock.event.assert_any_call("oauth_flow_code_received", provider="openrouter")
|
analytics_mock.event.assert_any_call("oauth_flow_code_received", provider="openrouter")
|
||||||
|
@ -478,5 +516,5 @@ class TestOnboarding(unittest.TestCase):
|
||||||
del os.environ["OPENROUTER_API_KEY"]
|
del os.environ["OPENROUTER_API_KEY"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue