diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index e72baf1ab..25531d421 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -846,6 +846,44 @@ This command will print 'Hello, World!' to the console.""" # Check if handle_shell_commands was called with the correct argument coder.handle_shell_commands.assert_not_called() + def test_show_exhausted_error(self): + with GitTemporaryDirectory(): + io = InputOutput(yes=True) + coder = Coder.create(self.GPT35, "diff", io=io) + + # Set up some real done_messages and cur_messages + coder.done_messages = [ + {"role": "user", "content": "Hello, can you help me with a Python problem?"}, + {"role": "assistant", "content": "Of course! I'd be happy to help. What's the problem you're facing?"}, + {"role": "user", "content": "I need to write a function that calculates the factorial of a number."}, + {"role": "assistant", "content": "Sure, I can help you with that. Here's a simple Python function to calculate the factorial of a number:"}, + ] + + coder.cur_messages = [ + {"role": "user", "content": "Can you optimize this function for large numbers?"}, + ] + + # Mock the necessary methods and attributes + coder.main_model.token_count = MagicMock(return_value=1000) + coder.main_model.info = { + "max_input_tokens": 4000, + "max_output_tokens": 1000, + } + coder.partial_response_content = "Here's an optimized version of the factorial function:" + coder.io.tool_error = MagicMock() + + # Call the method + coder.show_exhausted_error() + + # Check if tool_error was called with the expected message + coder.io.tool_error.assert_called() + error_message = coder.io.tool_error.call_args[0][0] + + # Assert that the error message contains the expected information + self.assertIn("Model gpt-3.5-turbo has hit a token limit!", error_message) + self.assertIn("Input tokens: ~1,000 of 4,000", error_message) + self.assertIn("Output tokens: ~1,000 of 1,000", error_message) + self.assertIn("Total tokens: ~2,000 of 4,000", error_message) if __name__ == "__main__": unittest.main()