diff --git a/tests/test_models.py b/tests/test_models.py index 48e797b12..fe8b681dc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch +from unittest.mock import MagicMock from aider.models import Model, OpenRouterModel @@ -27,14 +27,9 @@ class TestModels(unittest.TestCase): model = Model.create("gpt-4-32k-2123") self.assertEqual(model.max_context_tokens, 32 * 1024) - @patch("openai.resources.Models.list") - def test_openrouter_model_properties(self, mock_model_list): - # import openai - - # old_base = openai.api_base - # TODO: fixme - # openai.api_base = "https://openrouter.ai/api/v1" - mock_model_list.return_value = { + def test_openrouter_model_properties(self): + client = MagicMock() + client.models.list.return_value = { "data": [ { "id": "openai/gpt-4", @@ -44,17 +39,15 @@ class TestModels(unittest.TestCase): } ] } - mock_model_list.return_value = type( - "", (), {"data": mock_model_list.return_value["data"]} + client.models.list.return_value = type( + "", (), {"data": client.models.list.return_value["data"]} )() - model = OpenRouterModel("gpt-4") + model = OpenRouterModel(client, "gpt-4") self.assertEqual(model.name, "openai/gpt-4") self.assertEqual(model.max_context_tokens, 8192) self.assertEqual(model.prompt_price, 0.06) self.assertEqual(model.completion_price, 0.12) - # TODO: fixme - # openai.api_base = old_base if __name__ == "__main__":