test: Improve summary test with list-aware token counting mock

This commit is contained in:
Paul Gauthier 2025-06-25 12:57:59 -07:00 committed by Paul Gauthier (aider)
parent 19a7864168
commit 323910be11

View file

@ -4,11 +4,17 @@ from aider.history import ChatSummary
from aider.models import Model from aider.models import Model
def count(msg):
if isinstance(msg, list):
return sum(count(m) for m in msg)
return len(msg["content"].split())
class TestChatSummary(TestCase): class TestChatSummary(TestCase):
def setUp(self): def setUp(self):
self.mock_model = mock.Mock(spec=Model) self.mock_model = mock.Mock(spec=Model)
self.mock_model.name = "gpt-3.5-turbo" self.mock_model.name = "gpt-3.5-turbo"
self.mock_model.token_count = lambda msg: len(msg["content"].split()) self.mock_model.token_count = count
self.mock_model.info = {"max_input_tokens": 4096} self.mock_model.info = {"max_input_tokens": 4096}
self.mock_model.simple_send_with_retries = mock.Mock() self.mock_model.simple_send_with_retries = mock.Mock()
self.chat_summary = ChatSummary(self.mock_model, max_tokens=100) self.chat_summary = ChatSummary(self.mock_model, max_tokens=100)
@ -55,8 +61,11 @@ class TestChatSummary(TestCase):
) )
def test_summarize(self): def test_summarize(self):
messages = [{"role": "user", "content": f"Message {i}"} for i in range(10)] N = 100
messages.extend([{"role": "assistant", "content": f"Response {i}"} for i in range(10)]) messages = [None] * (2 * N)
for i in range(N):
messages[2 * i] = {"role": "user", "content": f"Message {i}"}
messages[2 * i + 1] = {"role": "assistant", "content": f"Response {i}"}
with mock.patch.object( with mock.patch.object(
self.chat_summary, self.chat_summary,
@ -65,9 +74,11 @@ class TestChatSummary(TestCase):
): ):
result = self.chat_summary.summarize(messages) result = self.chat_summary.summarize(messages)
print(result)
self.assertIsInstance(result, list) self.assertIsInstance(result, list)
self.assertGreater(len(result), 0) self.assertGreater(len(result), 0)
self.assertLessEqual(len(result), len(messages)) self.assertLess(len(result), len(messages))
self.assertEqual(result[0]["content"], "Summary")
def test_fallback_to_second_model(self): def test_fallback_to_second_model(self):
mock_model1 = mock.Mock(spec=Model) mock_model1 = mock.Mock(spec=Model)