diff --git a/tests/basic/test_models.py b/tests/basic/test_models.py index 968481319..f5b6de510 100644 --- a/tests/basic/test_models.py +++ b/tests/basic/test_models.py @@ -159,6 +159,51 @@ class TestModels(unittest.TestCase): model = Model("github/o1-preview") self.assertEqual(model.name, "github/o1-preview") self.assertEqual(model.use_temperature, False) + + def test_parse_token_value(self): + # Create a model instance to test the parse_token_value method + model = Model("gpt-4") + + # Test integer inputs + self.assertEqual(model.parse_token_value(8096), 8096) + self.assertEqual(model.parse_token_value(1000), 1000) + + # Test string inputs + self.assertEqual(model.parse_token_value("8096"), 8096) + + # Test k/K suffix (kilobytes) + self.assertEqual(model.parse_token_value("8k"), 8 * 1024) + self.assertEqual(model.parse_token_value("8K"), 8 * 1024) + self.assertEqual(model.parse_token_value("10.5k"), 10.5 * 1024) + self.assertEqual(model.parse_token_value("0.5K"), 0.5 * 1024) + + # Test m/M suffix (megabytes) + self.assertEqual(model.parse_token_value("1m"), 1 * 1024 * 1024) + self.assertEqual(model.parse_token_value("1M"), 1 * 1024 * 1024) + self.assertEqual(model.parse_token_value("0.5M"), 0.5 * 1024 * 1024) + + # Test with spaces + self.assertEqual(model.parse_token_value(" 8k "), 8 * 1024) + + # Test conversion from other types + self.assertEqual(model.parse_token_value(8.0), 8) + + def test_set_thinking_tokens(self): + # Test that set_thinking_tokens correctly sets the tokens with different formats + model = Model("gpt-4") + + # Test with integer + model.set_thinking_tokens(8096) + self.assertEqual(model.extra_params["thinking"]["budget_tokens"], 8096) + self.assertFalse(model.use_temperature) + + # Test with string + model.set_thinking_tokens("10k") + self.assertEqual(model.extra_params["thinking"]["budget_tokens"], 10 * 1024) + + # Test with decimal value + model.set_thinking_tokens("0.5M") + self.assertEqual(model.extra_params["thinking"]["budget_tokens"], 0.5 * 1024 * 1024) @patch("aider.models.check_pip_install_extra") def test_check_for_dependencies_bedrock(self, mock_check_pip):