diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index 88026cfb0..146ddcb6c 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -974,6 +974,74 @@ This command will print 'Hello, World!' to the console.""" self.assertIn("Output tokens:", error_message) self.assertIn("Total tokens:", error_message) + def test_keyboard_interrupt_handling(self): + with GitTemporaryDirectory(): + io = InputOutput(yes=True) + coder = Coder.create(self.GPT35, "diff", io=io) + + # Simulate keyboard interrupt during message processing + def mock_send(*args, **kwargs): + raise KeyboardInterrupt() + + coder.send = mock_send + + # Initial valid state + coder.cur_messages = [{"role": "user", "content": "Initial question"}] + self.assertTrue(sanity_check_messages(coder.cur_messages)) + + # Process message that will trigger interrupt + with self.assertRaises(KeyboardInterrupt): + list(coder.send_message("Test message")) + + # Verify messages are still in valid state + self.assertTrue(sanity_check_messages(coder.cur_messages)) + self.assertEqual(len(coder.cur_messages), 2) + self.assertEqual(coder.cur_messages[-1]["role"], "user") + + def test_token_limit_error_handling(self): + with GitTemporaryDirectory(): + io = InputOutput(yes=True) + coder = Coder.create(self.GPT35, "diff", io=io) + + # Simulate token limit error + def mock_send(*args, **kwargs): + raise FinishReasonLength() + + coder.send = mock_send + + # Initial valid state + coder.cur_messages = [{"role": "user", "content": "Initial question"}] + self.assertTrue(sanity_check_messages(coder.cur_messages)) + + # Process message that hits token limit + list(coder.send_message("Long message")) + + # Verify messages are still in valid state + self.assertTrue(sanity_check_messages(coder.cur_messages)) + self.assertEqual(coder.cur_messages[-1]["role"], "user") + + def test_message_sanity_after_partial_response(self): + with GitTemporaryDirectory(): + io = InputOutput(yes=True) + coder = Coder.create(self.GPT35, "diff", io=io) + + # Simulate partial response then interrupt + def mock_send(*args, **kwargs): + coder.partial_response_content = "Partial response" + raise KeyboardInterrupt() + + coder.send = mock_send + + coder.cur_messages = [{"role": "user", "content": "Question"}] + with self.assertRaises(KeyboardInterrupt): + list(coder.send_message("Test")) + + # Verify message structure remains valid + self.assertTrue(sanity_check_messages(coder.cur_messages)) + self.assertEqual(len(coder.cur_messages), 2) + self.assertEqual(coder.cur_messages[-1]["role"], "user") + self.assertIn("Partial response", coder.partial_response_content) + if __name__ == "__main__": unittest.main()