diff --git a/aider/models.py b/aider/models.py index 304232259..df28ae0e9 100644 --- a/aider/models.py +++ b/aider/models.py @@ -1,19 +1,31 @@ import re +known_tokens = { + "gpt-3.5-turbo": 4, + "gpt-4": 8, +} + class Model: always_available = False use_repo_map = False send_undo_reply = False - def __init__(self, name, tokens=None): + def __init__(self, name): self.name = name + + tokens = None + + match = re.search(r"-([0-9]+)k", name) + if match: + tokens = int(match.group(1)) + else: + for m, t in known_tokens.items(): + if name.startswith(m): + tokens = t + if tokens is None: - match = re.search(r"-([0-9]+)k", name) - - default_tokens = 8 - - tokens = int(match.group(1)) if match else default_tokens + raise ValueError(f"Unknown context window size for model: {name}") self.max_context_tokens = tokens * 1024 @@ -40,6 +52,6 @@ class Model: return self.name -GPT4 = Model("gpt-4", 8) +GPT4 = Model("gpt-4") GPT35 = Model("gpt-3.5-turbo") GPT35_16k = Model("gpt-3.5-turbo-16k") diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 000000000..af2a6f8d7 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,28 @@ +import unittest + +from aider.models import Model + + +class TestModels(unittest.TestCase): + def test_max_context_tokens(self): + model = Model("gpt-3.5-turbo") + self.assertEqual(model.max_context_tokens, 4 * 1024) + + model = Model("gpt-3.5-turbo-16k") + self.assertEqual(model.max_context_tokens, 16 * 1024) + + model = Model("gpt-4") + self.assertEqual(model.max_context_tokens, 8 * 1024) + + model = Model("gpt-4-32k") + self.assertEqual(model.max_context_tokens, 32 * 1024) + + model = Model("gpt-4-0101") + self.assertEqual(model.max_context_tokens, 8 * 1024) + + model = Model("gpt-4-32k-2123") + self.assertEqual(model.max_context_tokens, 32 * 1024) + + +if __name__ == "__main__": + unittest.main()