refactor: Simplify test assertions and remove redundant checks

This commit is contained in:
Paul Gauthier 2025-01-22 10:17:21 -08:00 committed by Paul Gauthier (aider)
parent 40ee3b1b45
commit 30b150dbfc

View file

@ -982,22 +982,21 @@ This command will print 'Hello, World!' to the console."""
# Simulate keyboard interrupt during message processing # Simulate keyboard interrupt during message processing
def mock_send(*args, **kwargs): def mock_send(*args, **kwargs):
coder.partial_response_content = "Partial response"
coder.partial_response_function_call = dict()
raise KeyboardInterrupt() raise KeyboardInterrupt()
coder.send = mock_send coder.send = mock_send
# Initial valid state # Initial valid state
coder.cur_messages = [{"role": "user", "content": "Initial question"}] sanity_check_messages(coder.cur_messages)
self.assertTrue(sanity_check_messages(coder.cur_messages))
# Process message that will trigger interrupt # Process message that will trigger interrupt
with self.assertRaises(KeyboardInterrupt): list(coder.send_message("Test message"))
list(coder.send_message("Test message"))
# Verify messages are still in valid state # Verify messages are still in valid state
self.assertTrue(sanity_check_messages(coder.cur_messages)) sanity_check_messages(coder.cur_messages)
self.assertEqual(len(coder.cur_messages), 2) self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
self.assertEqual(coder.cur_messages[-1]["role"], "user")
def test_token_limit_error_handling(self): def test_token_limit_error_handling(self):
with GitTemporaryDirectory(): with GitTemporaryDirectory():
@ -1006,20 +1005,21 @@ This command will print 'Hello, World!' to the console."""
# Simulate token limit error # Simulate token limit error
def mock_send(*args, **kwargs): def mock_send(*args, **kwargs):
coder.partial_response_content = "Partial response"
coder.partial_response_function_call = dict()
raise FinishReasonLength() raise FinishReasonLength()
coder.send = mock_send coder.send = mock_send
# Initial valid state # Initial valid state
coder.cur_messages = [{"role": "user", "content": "Initial question"}] sanity_check_messages(coder.cur_messages)
self.assertTrue(sanity_check_messages(coder.cur_messages))
# Process message that hits token limit # Process message that hits token limit
list(coder.send_message("Long message")) list(coder.send_message("Long message"))
# Verify messages are still in valid state # Verify messages are still in valid state
self.assertTrue(sanity_check_messages(coder.cur_messages)) sanity_check_messages(coder.cur_messages)
self.assertEqual(coder.cur_messages[-1]["role"], "user") self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
def test_message_sanity_after_partial_response(self): def test_message_sanity_after_partial_response(self):
with GitTemporaryDirectory(): with GitTemporaryDirectory():
@ -1029,19 +1029,16 @@ This command will print 'Hello, World!' to the console."""
# Simulate partial response then interrupt # Simulate partial response then interrupt
def mock_send(*args, **kwargs): def mock_send(*args, **kwargs):
coder.partial_response_content = "Partial response" coder.partial_response_content = "Partial response"
coder.partial_response_function_call = dict()
raise KeyboardInterrupt() raise KeyboardInterrupt()
coder.send = mock_send coder.send = mock_send
coder.cur_messages = [{"role": "user", "content": "Question"}] list(coder.send_message("Test"))
with self.assertRaises(KeyboardInterrupt):
list(coder.send_message("Test"))
# Verify message structure remains valid # Verify message structure remains valid
self.assertTrue(sanity_check_messages(coder.cur_messages)) sanity_check_messages(coder.cur_messages)
self.assertEqual(len(coder.cur_messages), 2) self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
self.assertEqual(coder.cur_messages[-1]["role"], "user")
self.assertIn("Partial response", coder.partial_response_content)
if __name__ == "__main__": if __name__ == "__main__":