diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index 294cbe4ee..731b821cd 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -747,6 +747,40 @@ two self.assertEqual(len(coder1.abs_fnames), 1) self.assertEqual(len(coder2.abs_fnames), 1) + def test_suggest_shell_commands(self): + with GitTemporaryDirectory(): + io = InputOutput(yes=True) + coder = Coder.create(self.GPT35, None, io=io) + + def mock_send(*args, **kwargs): + coder.partial_response_content = """Here's a shell command to run: + +```bash +echo "Hello, World!" +``` + +This command will print 'Hello, World!' to the console.""" + coder.partial_response_function_call = dict() + return [] + + coder.send = mock_send + + # Run the coder with a message + coder.run(with_message="Suggest a shell command") + + # Check if the shell command was added to the list + self.assertEqual(len(coder.shell_commands), 1) + self.assertEqual(coder.shell_commands[0], 'echo "Hello, World!"') + + # Mock the handle_shell_commands method to check if it's called + coder.handle_shell_commands = MagicMock() + + # Run the coder again to trigger the shell command execution + coder.run(with_message="Run the suggested command") + + # Check if handle_shell_commands was called with the correct argument + coder.handle_shell_commands.assert_called_once_with('echo "Hello, World!"', ANY) + if __name__ == "__main__": unittest.main()