diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py index 63692e098..b577507ef 100644 --- a/tests/basic/test_io.py +++ b/tests/basic/test_io.py @@ -353,6 +353,46 @@ class TestInputOutputMultilineMode(unittest.TestCase): # The invalid Unicode should be replaced with '?' self.assertEqual(converted_message, "Hello ?World") + def test_multiline_mode_restored_after_interrupt(self): + """Test that multiline mode is restored after KeyboardInterrupt""" + io = InputOutput(fancy_input=True) + io.prompt_session = MagicMock() + + # Start in multiline mode + io.multiline_mode = True + + # Mock prompt() to raise KeyboardInterrupt + io.prompt_session.prompt.side_effect = KeyboardInterrupt + + # Test confirm_ask() + with self.assertRaises(KeyboardInterrupt): + io.confirm_ask("Test question?") + self.assertTrue(io.multiline_mode) # Should be restored + + # Test prompt_ask() + with self.assertRaises(KeyboardInterrupt): + io.prompt_ask("Test prompt?") + self.assertTrue(io.multiline_mode) # Should be restored + + def test_multiline_mode_restored_after_normal_exit(self): + """Test that multiline mode is restored after normal exit""" + io = InputOutput(fancy_input=True) + io.prompt_session = MagicMock() + + # Start in multiline mode + io.multiline_mode = True + + # Mock prompt() to return normally + io.prompt_session.prompt.return_value = "y" + + # Test confirm_ask() + io.confirm_ask("Test question?") + self.assertTrue(io.multiline_mode) # Should be restored + + # Test prompt_ask() + io.prompt_ask("Test prompt?") + self.assertTrue(io.multiline_mode) # Should be restored + if __name__ == "__main__": unittest.main()