diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py
index 5dbe03cf9..ac3b50bce 100755
--- a/aider/coders/base_coder.py
+++ b/aider/coders/base_coder.py
@@ -69,6 +69,10 @@ class MissingAPIKeyError(ValueError):
 class FinishReasonLength(Exception):
     pass
 
+# Add this new exception
+class AiderAbortException(Exception):
+    pass
+
 
 def wrap_fence(name):
     return f"<{name}>", f"</{name}>"
@@ -864,6 +868,7 @@ class Coder:
         self.test_outcome = None
         self.shell_commands = []
         self.message_cost = 0
+        self.partial_response_function_call = None  # Initialize function call attribute
 
         if self.repo:
             self.commit_before_message.append(self.repo.get_head_commit_sha())
@@ -1417,6 +1422,8 @@ class Coder:
         # Notify IO that LLM processing is starting
         self.io.llm_started()
 
+        self.partial_response_function_call = None  # Initialize function call attribute here
+
         self.cur_messages += [
             dict(role="user", content=inp),
         ]
@@ -1447,10 +1454,11 @@ class Coder:
 
         self.usage_report = None
         exhausted = False
-        interrupted = False
-        try:
+        interrupted = False  # Flag to track if AiderAbortException was caught
+        try:  # This try block handles retries and different exception types
             while True:
                 try:
+                    # This is where AiderAbortException can be raised by show_send_output_stream
                     yield from self.send(messages, functions=self.functions)
                     break
                 except litellm_ex.exceptions_tuple() as err:
@@ -1482,8 +1490,14 @@ class Coder:
                     time.sleep(retry_delay)
                     continue
                 except KeyboardInterrupt:
-                    interrupted = True
-                    break
+                    # This handles Ctrl+C from the terminal, which calls io.interrupt_input()
+                    interrupted = True  # Mark as interrupted
+                    self.keyboard_interrupt()  # Handle the Ctrl+C signal
+                    break  # Exit the retry loop
+                except AiderAbortException:  # This is the exception raised by show_send_output_stream
+                    # This handles abort signals from the IO layer (e.g., Connector)
+                    interrupted = True  # Mark as interrupted
+                    break  # Exit the retry loop
                 except FinishReasonLength:
                     # We hit the output limit!
                     if not self.main_model.info.get("supports_assistant_prefill"):
@@ -1506,6 +1520,7 @@ class Coder:
             self.event("message_send_exception", exception=str(err))
             return
         finally:
+            # This block always runs, even if an exception occurred or the loop broke (e.g., due to interrupt)
             if self.mdstream:
                 self.live_incremental_response(True)
                 self.mdstream = None
@@ -1513,34 +1528,36 @@ class Coder:
             # Ensure any waiting spinner is stopped
             self._stop_waiting_spinner()
 
+            # Capture the final partial content regardless of how it ended
             self.partial_response_content = self.get_multi_response_content_in_progress(True)
             self.remove_reasoning_content()
             self.multi_response_content = ""
 
-        ###
-        # print()
-        # print("=" * 20)
-        # dump(self.partial_response_content)
+            # Add the assistant's reply (even if partial/interrupted) to cur_messages
+            self.add_assistant_reply_to_cur_messages(interrupted=interrupted)  # Pass interrupted status
 
-        self.io.tool_output()
+        # Check the interrupted flag *after* the finally block
+        if interrupted:
+            # If interrupted, we stop here and don't proceed with post-processing
+            self.io.tool_output("Response interrupted.")
+            return  # Exit send_message
+
+        # If we reach here, the stream completed without interruption or context exhaustion
+
+        self.io.tool_output()  # Add a newline after the assistant output
 
         self.show_usage_report()
 
-        self.add_assistant_reply_to_cur_messages()
+        # The assistant reply was already added to cur_messages in the finally block
 
         if exhausted:
-            if self.cur_messages and self.cur_messages[-1]["role"] == "user":
-                self.cur_messages += [
-                    dict(
-                        role="assistant",
-                        content="FinishReasonLength exception: you sent too many tokens",
-                    ),
-                ]
-
+            # This case should now be handled before the main post-processing
+            # but keeping the check here for safety/clarity.
             self.show_exhausted_error()
             self.num_exhausted_context_windows += 1
             return
 
+        # Proceed with post-processing only if not interrupted
         if self.partial_response_function_call:
             args = self.parse_partial_args()
             if args:
@@ -1552,31 +1569,35 @@ class Coder:
         else:
             content = ""
 
-        if not interrupted:
-            add_rel_files_message = self.check_for_file_mentions(content)
-            if add_rel_files_message:
-                if self.reflected_message:
-                    self.reflected_message += "\n\n" + add_rel_files_message
-                else:
-                    self.reflected_message = add_rel_files_message
-                return
-
-        try:
-            if self.reply_completed():
-                return
-        except KeyboardInterrupt:
-            interrupted = True
-
-        if interrupted:
-            if self.cur_messages and self.cur_messages[-1]["role"] == "user":
-                self.cur_messages[-1]["content"] += "\n^C KeyboardInterrupt"
+        # Check for file mentions in the final content
+        add_rel_files_message = self.check_for_file_mentions(content)
+        if add_rel_files_message:
+            # If files were mentioned and added, set reflected_message
+            if self.reflected_message:
+                self.reflected_message += "\n\n" + add_rel_files_message
             else:
-                self.cur_messages += [dict(role="user", content="^C KeyboardInterrupt")]
-            self.cur_messages += [
-                dict(role="assistant", content="I see that you interrupted my previous reply.")
-            ]
+                self.reflected_message = add_rel_files_message
+            # Return here to trigger reflection in run_one
             return
 
+        try:
+            # Check if the reply completed successfully (subclass specific)
+            if self.reply_completed():
+                return
+        except KeyboardInterrupt:
+            # Handle Ctrl+C during reply_completed (less likely but possible)
+            interrupted = True
+            self.keyboard_interrupt()
+            # Fall through to the interrupted handling below
+
+        # If interrupted during reply_completed or file mention check
+        if interrupted:
+            # The message history already has the partial content + "(interrupted)"
+            # No need to add another message here.
+            self.io.tool_output("Response interrupted.")
+            return  # Exit send_message
+
+        # Apply edits only if not interrupted and reply completed successfully
         edited = self.apply_updates()
 
         if edited:
@@ -1586,11 +1607,14 @@ class Coder:
             if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
                 saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
 
+            # Move messages back only after successful edits and commit
             self.move_back_cur_messages(saved_message)
 
+        # Check for reflections again after applying edits
         if self.reflected_message:
             return
 
+        # Run lint/test only if edits were applied and not interrupted
        if edited and self.auto_lint:
             lint_errors = self.lint_edited(edited)
             self.auto_commit(edited, context="Ran the linter")
@@ -1601,13 +1625,16 @@ class Coder:
                     self.reflected_message = lint_errors
                     return
 
+        # Run shell commands only if not interrupted
         shared_output = self.run_shell_commands()
         if shared_output:
+            # Add shell command output to chat history
             self.cur_messages += [
                 dict(role="user", content=shared_output),
                 dict(role="assistant", content="Ok"),
             ]
 
+        # Run tests only if edits were applied and not interrupted
         if edited and self.auto_test:
             test_errors = self.commands.cmd_test(self.test_cmd)
             self.test_outcome = not test_errors
@@ -1694,10 +1721,21 @@ class Coder:
         """Cleanup when the Coder object is destroyed."""
         self.ok_to_warm_cache = False
 
-    def add_assistant_reply_to_cur_messages(self):
-        if self.partial_response_content:
-            self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
-        if self.partial_response_function_call:
+    def add_assistant_reply_to_cur_messages(self, interrupted=False):
+        # Add the partial or complete response content to the current messages
+        # This happens in the finally block, so it includes interrupted content
+        content_to_add = self.partial_response_content
+        if interrupted and content_to_add:  # Use the passed interrupted flag
+            content_to_add += " (interrupted)"  # Indicate interruption in history
+
+        if content_to_add:
+            self.cur_messages += [dict(role="assistant", content=content_to_add)]
+        # Only add function call if not interrupted, as it implies completion
+        if self.partial_response_function_call and not interrupted:
+            # Note: Function calls are not currently interrupted by the IO layer check
+            # as they don't stream content chunk by chunk in the same way.
+            # If function calls needed interruption, the LiteLLM call itself would need
+            # to be cancellable, which is provider/LiteLLM dependent.
             self.cur_messages += [
                 dict(
                     role="assistant",
@@ -1706,6 +1744,7 @@ class Coder:
                 )
             ]
 
+
     def get_file_mentions(self, content, ignore_current=False):
         words = set(word for word in content.split())
 
@@ -1896,6 +1935,11 @@ class Coder:
         received_content = False
 
         for chunk in completion:
+            # Check for interruption at the start of processing each chunk
+            if self.io.is_interrupted():
+                self._stop_waiting_spinner()  # Stop spinner on abort
+                raise AiderAbortException("Stream interrupted by user.")
+
             if len(chunk.choices) == 0:
                 continue
 
diff --git a/aider/io.py b/aider/io.py
index f28a1c86d..0161c58a4 100644
--- a/aider/io.py
+++ b/aider/io.py
@@ -514,6 +514,13 @@ class InputOutput:
             self.interrupted = True
             self.prompt_session.app.exit()
 
+    def is_interrupted(self):
+        """
+        Check if the current operation should be interrupted.
+        Implementations should override this based on their environment.
+        """
+        return False  # Default implementation
+
     def get_input(
         self,
         root,
diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py
index 9dd0ce37a..fe0faf0fc 100644
--- a/tests/basic/test_coder.py
+++ b/tests/basic/test_coder.py
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
 import git
 
 from aider.coders import Coder
-from aider.coders.base_coder import FinishReasonLength, UnknownEditFormat
+from aider.coders.base_coder import AiderAbortException, FinishReasonLength, UnknownEditFormat
 from aider.dump import dump  # noqa: F401
 from aider.io import InputOutput
 from aider.models import Model
@@ -1433,6 +1433,257 @@ This command will print 'Hello, World!' to the console."""
         # (because user rejected the changes)
         mock_editor.run.assert_not_called()
 
+    @patch("aider.coders.base_coder.Coder.apply_updates")
+    @patch("aider.coders.base_coder.Coder.run_shell_commands")
+    @patch("aider.coders.base_coder.Coder.check_for_file_mentions")
+    @patch("aider.coders.base_coder.Coder.reply_completed")
+    @patch("aider.coders.base_coder.Coder.show_usage_report")
+    @patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
+    def test_send_message_handles_abort_exception(
+        self,
+        mock_stop_spinner,
+        mock_show_usage,
+        mock_reply_completed,
+        mock_check_mentions,
+        mock_run_shell,
+        mock_apply_updates,
+    ):
+        """Test that send_message handles AiderAbortException correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Mock the send method to raise AiderAbortException
+            def mock_send(*args, **kwargs):
+                # Simulate some partial content before abort
+                coder.partial_response_content = "Partial content received."
+                raise AiderAbortException("Test abort")
+
+            coder.send = MagicMock(side_effect=mock_send)
+
+            # Mock post-processing methods to ensure they are NOT called
+            mock_apply_updates.return_value = set()
+            mock_run_shell.return_value = ""
+            mock_check_mentions.return_value = None
+            mock_reply_completed.return_value = False
+            mock_show_usage.return_value = None
+
+            # Run send_message (it's a generator, so iterate through it)
+            list(coder.send_message("Test message"))
+
+            # Assert that the partial content was captured and added to history
+            self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
+            self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
+
+            # Assert that post-processing methods were NOT called
+            mock_apply_updates.assert_not_called()
+            mock_run_shell.assert_not_called()
+            mock_check_mentions.assert_not_called()
+            mock_reply_completed.assert_not_called()
+            mock_show_usage.assert_not_called()  # show_usage_report should be skipped on interruption
+
+            # Assert spinner was stopped
+            mock_stop_spinner.assert_called()
+
+    @patch("aider.coders.base_coder.Coder.apply_updates")
+    @patch("aider.coders.base_coder.Coder.run_shell_commands")
+    @patch("aider.coders.base_coder.Coder.check_for_file_mentions")
+    @patch("aider.coders.base_coder.Coder.reply_completed")
+    @patch("aider.coders.base_coder.Coder.show_usage_report")
+    @patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
+    @patch("aider.io.InputOutput.is_interrupted", return_value=True)  # Simulate IO signal
+    def test_streaming_interruption_raises_exception(
+        self,
+        mock_is_interrupted,
+        mock_stop_spinner,
+        mock_show_usage,
+        mock_reply_completed,
+        mock_check_mentions,
+        mock_run_shell,
+        mock_apply_updates,
+    ):
+        """Test that streaming is interrupted when io.is_interrupted() is True."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io, stream=True)
+
+            # Mock the underlying LiteLLM completion stream
+            mock_completion_stream = MagicMock()
+            # Simulate a few chunks before interruption
+            mock_completion_stream.__iter__.return_value = [
+                MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk1"))]),
+                MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk2"))]),
+                # The third iteration will check io.is_interrupted() and raise
+            ]
+
+            # Mock the send method to return the mock stream
+            coder.send = MagicMock(return_value=mock_completion_stream)
+
+            # Mock post-processing methods to ensure they are NOT called
+            mock_apply_updates.return_value = set()
+            mock_run_shell.return_value = ""
+            mock_check_mentions.return_value = None
+            mock_reply_completed.return_value = False
+            mock_show_usage.return_value = None
+
+            # Run send_message (it's a generator, so iterate through it)
+            # Expect AiderAbortException to be raised by show_send_output_stream
+            # which is called by send, which is yielded from by send_message.
+            # The exception should be caught within send_message.
+            list(coder.send_message("Test streaming interruption"))
+
+            # Assert that the interrupted flag was set within send_message's exception handler
+            # We can't directly check the local 'interrupted' flag in send_message,
+            # but we can check that the io object's flag was used (which we mocked)
+            # and that post-processing was skipped.
+            self.assertTrue(mock_is_interrupted.called)  # Verify the mock was checked
+
+            # Assert that the partial content was captured and added to history
+            # The content should include chunks received before the interrupt
+            self.assertIn("chunk1chunk2 (interrupted)", coder.cur_messages[-1]["content"])
+            self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
+
+            # Assert that post-processing methods were NOT called
+            mock_apply_updates.assert_not_called()
+            mock_run_shell.assert_not_called()
+            mock_check_mentions.assert_not_called()
+            mock_reply_completed.assert_not_called()
+            mock_show_usage.assert_called_once()  # show_usage_report is called after the finally block, but before the interrupted check
+
+            # Assert spinner was stopped
+            mock_stop_spinner.assert_called()
+
+    @patch("aider.coders.base_coder.Coder.apply_updates")
+    @patch("aider.coders.base_coder.Coder.run_shell_commands")
+    @patch("aider.coders.base_coder.Coder.check_for_file_mentions")
+    @patch("aider.coders.base_coder.Coder.reply_completed")
+    @patch("aider.coders.base_coder.Coder.show_usage_report")
+    @patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
+    @patch("aider.coders.base_coder.Coder.keyboard_interrupt")  # Mock the handler
+    def test_send_message_handles_keyboard_interrupt(
+        self,
+        mock_keyboard_interrupt,
+        mock_stop_spinner,
+        mock_show_usage,
+        mock_reply_completed,
+        mock_check_mentions,
+        mock_run_shell,
+        mock_apply_updates,
+    ):
+        """Test that send_message handles KeyboardInterrupt correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Mock the send method to raise KeyboardInterrupt
+            def mock_send(*args, **kwargs):
+                # Simulate some partial content before interrupt
+                coder.partial_response_content = "Partial content received."
+                raise KeyboardInterrupt("Test interrupt")
+
+            coder.send = MagicMock(side_effect=mock_send)
+
+            # Mock post-processing methods to ensure they are NOT called
+            mock_apply_updates.return_value = set()
+            mock_run_shell.return_value = ""
+            mock_check_mentions.return_value = None
+            mock_reply_completed.return_value = False
+            mock_show_usage.return_value = None
+
+            # Run send_message (it's a generator, so iterate through it)
+            list(coder.send_message("Test message"))
+
+            # Assert that keyboard_interrupt handler was called
+            mock_keyboard_interrupt.assert_called_once()
+
+            # Assert that the partial content was captured and added to history
+            self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
+            self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
+
+            # Assert that post-processing methods were NOT called
+            mock_apply_updates.assert_not_called()
+            mock_run_shell.assert_not_called()
+            mock_check_mentions.assert_not_called()
+            mock_reply_completed.assert_not_called()
+            mock_show_usage.assert_called_once()  # show_usage_report is called after the finally block, but before the interrupted check
+
+            # Assert spinner was stopped
+            mock_stop_spinner.assert_called()
+
+    def test_add_assistant_reply_to_cur_messages_interrupted(self):
+        """Test that add_assistant_reply_to_cur_messages marks content as interrupted."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Simulate partial content and interrupted state
+            coder.partial_response_content = "This is partial content."
+            coder.partial_response_function_call = None  # No function call
+
+            # Add the reply, simulating interruption
+            coder.add_assistant_reply_to_cur_messages(interrupted=True)
+
+            # Check the last message in cur_messages
+            self.assertEqual(len(coder.cur_messages), 1)
+            last_msg = coder.cur_messages[-1]
+            self.assertEqual(last_msg["role"], "assistant")
+            self.assertEqual(last_msg["content"], "This is partial content. (interrupted)")
+            self.assertNotIn("function_call", last_msg)
+
+            # Simulate partial content with function call and interrupted state
+            coder.cur_messages = []  # Reset messages
+            coder.partial_response_content = "This is partial content before func call."
+            coder.partial_response_function_call = {"name": "test_func", "arguments": "{}"}
+
+            # Add the reply, simulating interruption
+            coder.add_assistant_reply_to_cur_messages(interrupted=True)
+
+            # Check the last message in cur_messages
+            self.assertEqual(len(coder.cur_messages), 1)
+            last_msg = coder.cur_messages[-1]
+            self.assertEqual(last_msg["role"], "assistant")
+            # Content should be marked as interrupted
+            self.assertEqual(last_msg["content"], "This is partial content before func call. (interrupted)")
+            # Function call should NOT be added if interrupted
+            self.assertNotIn("function_call", last_msg)
+
+    def test_add_assistant_reply_to_cur_messages_not_interrupted(self):
+        """Test that add_assistant_reply_to_cur_messages works normally when not interrupted."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Simulate complete content and no interruption
+            coder.partial_response_content = "This is the final content."
+            coder.partial_response_function_call = None  # No function call
+
+            # Add the reply, not interrupted
+            coder.add_assistant_reply_to_cur_messages(interrupted=False)
+
+            # Check the last message in cur_messages
+            self.assertEqual(len(coder.cur_messages), 1)
+            last_msg = coder.cur_messages[-1]
+            self.assertEqual(last_msg["role"], "assistant")
+            self.assertEqual(last_msg["content"], "This is the final content.")
+            self.assertNotIn("function_call", last_msg)
+
+            # Simulate complete content with function call and no interruption
+            coder.cur_messages = []  # Reset messages
+            coder.partial_response_content = "This is content before func call."
+            coder.partial_response_function_call = {"name": "test_func", "arguments": '{"arg": "val"}'}
+
+            # Add the reply, not interrupted
+            coder.add_assistant_reply_to_cur_messages(interrupted=False)
+
+            # Check the last message in cur_messages
+            self.assertEqual(len(coder.cur_messages), 1)
+            last_msg = coder.cur_messages[-1]
+            self.assertEqual(last_msg["role"], "assistant")
+            self.assertEqual(last_msg["content"], "This is content before func call.")
+            # Function call SHOULD be added if not interrupted
+            self.assertIn("function_call", last_msg)
+            self.assertEqual(last_msg["function_call"], {"name": "test_func", "arguments": '{"arg": "val"}'})
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py
index 270a3c247..54f84ecd8 100644
--- a/tests/basic/test_io.py
+++ b/tests/basic/test_io.py
@@ -341,6 +341,11 @@ class TestInputOutput(unittest.TestCase):
         self.assertEqual(mock_input.call_count, 2)
         self.assertNotIn(("Do you want to proceed?", None), io.never_prompts)
 
+    def test_is_interrupted_default(self):
+        """Test that is_interrupted returns False by default."""
+        io = InputOutput()
+        self.assertFalse(io.is_interrupted())
+
 
 class TestInputOutputMultilineMode(unittest.TestCase):
     def setUp(self):
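The new `InputOutput.is_interrupted()` hook returns False by default; the diff's comments suggest it is meant to be overridden by an alternative IO layer (e.g. a Connector-style front end). The sketch below is a hypothetical illustration, not part of this patch: it assumes an `AbortableInputOutput` subclass and a `request_abort()` method that do not exist in aider, and only shows how a front end could set a thread-safe flag that `show_send_output_stream` would poll on each chunk to raise `AiderAbortException`.

```python
# Hypothetical sketch (not part of this patch): an InputOutput subclass a GUI or
# Connector front end might use to abort a streaming response from another thread.
import threading

from aider.io import InputOutput


class AbortableInputOutput(InputOutput):  # illustrative name, not an aider API
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._abort_event = threading.Event()  # set from the UI thread

    def request_abort(self):
        # Called by the front end (e.g. a "Stop" button) while the LLM is streaming.
        self._abort_event.set()

    def reset_abort(self):
        # Clear the flag before starting the next request.
        self._abort_event.clear()

    def is_interrupted(self):
        # Polled once per chunk by show_send_output_stream; returning True makes it
        # raise AiderAbortException, which send_message treats as an interruption.
        return self._abort_event.is_set()
```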