diff --git a/aider/models/openrouter.py b/aider/models/openrouter.py index 4f6c9e1dc..59a099876 100644 --- a/aider/models/openrouter.py +++ b/aider/models/openrouter.py @@ -28,9 +28,9 @@ class OpenRouterModel(Model): found = next((details for details in cached_model_details if details.get('id') == name), None) if found: - self.max_context_tokens = int(found.context_length) - self.prompt_price = float(found.get('pricing').get('prompt')) * 1000 - self.completion_price = float(found.get('pricing').get('completion')) * 1000 + self.max_context_tokens = int(found.get('context_length')) + self.prompt_price = round(float(found.get('pricing').get('prompt')) * 1000,6) + self.completion_price = round(float(found.get('pricing').get('completion')) * 1000,6) else: raise ValueError(f'invalid openrouter model: {name}') diff --git a/tests/test_models.py b/tests/test_models.py index 1b826b29b..04f8bce81 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ import unittest +from unittest.mock import patch -from aider.models import Model +from aider.models import Model, OpenRouterModel class TestModels(unittest.TestCase): @@ -23,12 +24,31 @@ class TestModels(unittest.TestCase): model = Model.create("gpt-4-32k-2123") self.assertEqual(model.max_context_tokens, 32 * 1024) - def test_openrouter_models(self): + + @patch('openai.Model.list') + def test_openrouter_model_properties(self, mock_model_list): import openai openai.api_base = 'https://openrouter.ai/api/v1' - model = Model.create("gpt-3.5-turbo") - self.assertEqual(model.name, 'openai/gpt-3.5-turbo') + mock_model_list.return_value = { + 'data': [ + { + 'id': 'openai/gpt-4', + 'object': 'model', + 'context_length': '8192', + 'pricing': { + 'prompt': '0.00006', + 'completion': '0.00012' + } + } + ] + } + mock_model_list.return_value = type('', (), {'data': mock_model_list.return_value['data']})() + model = OpenRouterModel("gpt-4") + self.assertEqual(model.name, 'openai/gpt-4') + self.assertEqual(model.max_context_tokens, 8192) + self.assertEqual(model.prompt_price, 0.06) + self.assertEqual(model.completion_price, 0.12) if __name__ == "__main__": unittest.main()