feat(coder): Handle LLM response interruption

Introduce a mechanism to cleanly interrupt the LLM's streaming response, particularly useful for GUI or non-terminal interfaces.

- Add `AiderAbortException` to signal interruption from the IO layer.
- Implement `io.is_interrupted()` as a check within the streaming loop (`show_send_output_stream`).
- Update `send_message` to catch `AiderAbortException` and `KeyboardInterrupt`.
- Ensure partial response content is captured and marked as "(interrupted)" in the message history upon interruption.
- Skip post-processing steps (applying edits, running tests/lint, etc.) if the response was interrupted.
- Add tests for the new interruption handling logic.
This commit is contained in:
Derrick Hammer (aider) 2025-05-19 04:19:14 -04:00 committed by Derrick Hammer
parent 3caab85931
commit 40b75bc57f
4 changed files with 352 additions and 45 deletions

View file

@@ -69,6 +69,10 @@ class MissingAPIKeyError(ValueError):
class FinishReasonLength(Exception):
    """Raised when the LLM stops because it hit its output-token limit."""

    pass
# Signals an abort requested by the IO layer (e.g. a GUI/connector);
# raised from the streaming loop when io.is_interrupted() returns True.
class AiderAbortException(Exception):
    """Raised to cleanly interrupt a streaming LLM response."""

    pass
def wrap_fence(name):
    """Return the opening and closing XML-style fence tags for *name*."""
    open_tag = "<{}>".format(name)
    close_tag = "</{}>".format(name)
    return open_tag, close_tag
@ -864,6 +868,7 @@ class Coder:
self.test_outcome = None
self.shell_commands = []
self.message_cost = 0
self.partial_response_function_call = None # Initialize function call attribute
if self.repo:
self.commit_before_message.append(self.repo.get_head_commit_sha())
@ -1417,6 +1422,8 @@ class Coder:
# Notify IO that LLM processing is starting
self.io.llm_started()
self.partial_response_function_call = None # Initialize function call attribute here
self.cur_messages += [
dict(role="user", content=inp),
]
@ -1447,10 +1454,11 @@ class Coder:
self.usage_report = None
exhausted = False
interrupted = False
try:
interrupted = False # Flag to track if AiderAbortException was caught
try: # This try block handles retries and different exception types
while True:
try:
# This is where AiderAbortException can be raised by show_send_output_stream
yield from self.send(messages, functions=self.functions)
break
except litellm_ex.exceptions_tuple() as err:
@ -1482,8 +1490,14 @@ class Coder:
time.sleep(retry_delay)
continue
except KeyboardInterrupt:
interrupted = True
break
# This handles Ctrl+C from the terminal, which calls io.interrupt_input()
interrupted = True # Mark as interrupted
self.keyboard_interrupt() # Handle the Ctrl+C signal
break # Exit the retry loop
except AiderAbortException: # This is the exception raised by show_send_output_stream
# This handles abort signals from the IO layer (e.g., Connector)
interrupted = True # Mark as interrupted
break # Exit the retry loop
except FinishReasonLength:
# We hit the output limit!
if not self.main_model.info.get("supports_assistant_prefill"):
@ -1506,6 +1520,7 @@ class Coder:
self.event("message_send_exception", exception=str(err))
return
finally:
# This block always runs, even if an exception occurred or the loop broke (e.g., due to interrupt)
if self.mdstream:
self.live_incremental_response(True)
self.mdstream = None
@ -1513,34 +1528,36 @@ class Coder:
# Ensure any waiting spinner is stopped
self._stop_waiting_spinner()
# Capture the final partial content regardless of how it ended
self.partial_response_content = self.get_multi_response_content_in_progress(True)
self.remove_reasoning_content()
self.multi_response_content = ""
###
# print()
# print("=" * 20)
# dump(self.partial_response_content)
# Add the assistant's reply (even if partial/interrupted) to cur_messages
self.add_assistant_reply_to_cur_messages(interrupted=interrupted) # Pass interrupted status
self.io.tool_output()
# Check the interrupted flag *after* the finally block
if interrupted:
# If interrupted, we stop here and don't proceed with post-processing
self.io.tool_output("Response interrupted.")
return # Exit send_message
# If we reach here, the stream completed without interruption or context exhaustion
self.io.tool_output() # Add a newline after the assistant output
self.show_usage_report()
self.add_assistant_reply_to_cur_messages()
# The assistant reply was already added to cur_messages in the finally block
if exhausted:
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
self.cur_messages += [
dict(
role="assistant",
content="FinishReasonLength exception: you sent too many tokens",
),
]
# This case should now be handled before the main post-processing
# but keeping the check here for safety/clarity.
self.show_exhausted_error()
self.num_exhausted_context_windows += 1
return
# Proceed with post-processing only if not interrupted
if self.partial_response_function_call:
args = self.parse_partial_args()
if args:
@ -1552,31 +1569,35 @@ class Coder:
else:
content = ""
if not interrupted:
add_rel_files_message = self.check_for_file_mentions(content)
if add_rel_files_message:
if self.reflected_message:
self.reflected_message += "\n\n" + add_rel_files_message
else:
self.reflected_message = add_rel_files_message
return
try:
if self.reply_completed():
return
except KeyboardInterrupt:
interrupted = True
if interrupted:
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
self.cur_messages[-1]["content"] += "\n^C KeyboardInterrupt"
# Check for file mentions in the final content
add_rel_files_message = self.check_for_file_mentions(content)
if add_rel_files_message:
# If files were mentioned and added, set reflected_message
if self.reflected_message:
self.reflected_message += "\n\n" + add_rel_files_message
else:
self.cur_messages += [dict(role="user", content="^C KeyboardInterrupt")]
self.cur_messages += [
dict(role="assistant", content="I see that you interrupted my previous reply.")
]
self.reflected_message = add_rel_files_message
# Return here to trigger reflection in run_one
return
try:
# Check if the reply completed successfully (subclass specific)
if self.reply_completed():
return
except KeyboardInterrupt:
# Handle Ctrl+C during reply_completed (less likely but possible)
interrupted = True
self.keyboard_interrupt()
# Fall through to the interrupted handling below
# If interrupted during reply_completed or file mention check
if interrupted:
# The message history already has the partial content + "(interrupted)"
# No need to add another message here.
self.io.tool_output("Response interrupted.")
return # Exit send_message
# Apply edits only if not interrupted and reply completed successfully
edited = self.apply_updates()
if edited:
@ -1586,11 +1607,14 @@ class Coder:
if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
# Move messages back only after successful edits and commit
self.move_back_cur_messages(saved_message)
# Check for reflections again after applying edits
if self.reflected_message:
return
# Run lint/test only if edits were applied and not interrupted
if edited and self.auto_lint:
lint_errors = self.lint_edited(edited)
self.auto_commit(edited, context="Ran the linter")
@ -1601,13 +1625,16 @@ class Coder:
self.reflected_message = lint_errors
return
# Run shell commands only if not interrupted
shared_output = self.run_shell_commands()
if shared_output:
# Add shell command output to chat history
self.cur_messages += [
dict(role="user", content=shared_output),
dict(role="assistant", content="Ok"),
]
# Run tests only if edits were applied and not interrupted
if edited and self.auto_test:
test_errors = self.commands.cmd_test(self.test_cmd)
self.test_outcome = not test_errors
@ -1694,10 +1721,21 @@ class Coder:
"""Cleanup when the Coder object is destroyed."""
self.ok_to_warm_cache = False
def add_assistant_reply_to_cur_messages(self):
if self.partial_response_content:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
if self.partial_response_function_call:
def add_assistant_reply_to_cur_messages(self, interrupted=False):
# Add the partial or complete response content to the current messages
# This happens in the finally block, so it includes interrupted content
content_to_add = self.partial_response_content
if interrupted and content_to_add: # Use the passed interrupted flag
content_to_add += " (interrupted)" # Indicate interruption in history
if content_to_add:
self.cur_messages += [dict(role="assistant", content=content_to_add)]
# Only add function call if not interrupted, as it implies completion
if self.partial_response_function_call and not interrupted:
# Note: Function calls are not currently interrupted by the IO layer check
# as they don't stream content chunk by chunk in the same way.
# If function calls needed interruption, the LiteLLM call itself would need
# to be cancellable, which is provider/LiteLLM dependent.
self.cur_messages += [
dict(
role="assistant",
@ -1706,6 +1744,7 @@ class Coder:
)
]
def get_file_mentions(self, content, ignore_current=False):
words = set(word for word in content.split())
@ -1896,6 +1935,11 @@ class Coder:
received_content = False
for chunk in completion:
# Check for interruption at the start of processing each chunk
if self.io.is_interrupted():
self._stop_waiting_spinner() # Stop spinner on abort
raise AiderAbortException("Stream interrupted by user.")
if len(chunk.choices) == 0:
continue

View file

@@ -514,6 +514,13 @@ class InputOutput:
self.interrupted = True
self.prompt_session.app.exit()
def is_interrupted(self):
    """Report whether the current operation should be aborted.

    The base implementation never signals an interrupt; IO subclasses
    wired to a GUI or connector override this to expose their own
    abort flag.
    """
    return False
def get_input(
self,
root,

View file

@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
import git
from aider.coders import Coder
from aider.coders.base_coder import FinishReasonLength, UnknownEditFormat
from aider.coders.base_coder import AiderAbortException, FinishReasonLength, UnknownEditFormat
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
from aider.models import Model
@ -1433,6 +1433,257 @@ This command will print 'Hello, World!' to the console."""
# (because user rejected the changes)
mock_editor.run.assert_not_called()
@patch("aider.coders.base_coder.Coder.apply_updates")
@patch("aider.coders.base_coder.Coder.run_shell_commands")
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
@patch("aider.coders.base_coder.Coder.reply_completed")
@patch("aider.coders.base_coder.Coder.show_usage_report")
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
def test_send_message_handles_abort_exception(
    self,
    # unittest.mock injects patches bottom-up: the last decorator is the
    # first mock parameter after self.
    mock_stop_spinner,
    mock_show_usage,
    mock_reply_completed,
    mock_check_mentions,
    mock_run_shell,
    mock_apply_updates,
):
    """Test that send_message handles AiderAbortException correctly."""
    with GitTemporaryDirectory():
        io = InputOutput(yes=True)
        coder = Coder.create(self.GPT35, "diff", io=io)

        # Mock the send method to raise AiderAbortException mid-response.
        def mock_send(*args, **kwargs):
            # Simulate some partial content being received before the abort.
            coder.partial_response_content = "Partial content received."
            raise AiderAbortException("Test abort")

        coder.send = MagicMock(side_effect=mock_send)

        # Mock post-processing methods to ensure they are NOT called.
        mock_apply_updates.return_value = set()
        mock_run_shell.return_value = ""
        mock_check_mentions.return_value = None
        mock_reply_completed.return_value = False
        mock_show_usage.return_value = None

        # Run send_message (it's a generator, so iterate through it).
        list(coder.send_message("Test message"))

        # The partial content must be captured and marked as interrupted
        # in the assistant entry of the message history.
        self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
        self.assertEqual(coder.cur_messages[-1]["role"], "assistant")

        # All post-processing steps (edits, shell, mentions, reflection)
        # must be skipped after an abort.
        mock_apply_updates.assert_not_called()
        mock_run_shell.assert_not_called()
        mock_check_mentions.assert_not_called()
        mock_reply_completed.assert_not_called()
        mock_show_usage.assert_not_called()  # show_usage_report should be skipped on interruption

        # The waiting spinner must be stopped on the abort path.
        mock_stop_spinner.assert_called()
@patch("aider.coders.base_coder.Coder.apply_updates")
@patch("aider.coders.base_coder.Coder.run_shell_commands")
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
@patch("aider.coders.base_coder.Coder.reply_completed")
@patch("aider.coders.base_coder.Coder.show_usage_report")
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
@patch("aider.io.InputOutput.is_interrupted", return_value=True)  # Simulate IO signal
def test_streaming_interruption_raises_exception(
    self,
    # unittest.mock injects patches bottom-up: the last decorator is the
    # first mock parameter after self.
    mock_is_interrupted,
    mock_stop_spinner,
    mock_show_usage,
    mock_reply_completed,
    mock_check_mentions,
    mock_run_shell,
    mock_apply_updates,
):
    """Test that streaming is interrupted when io.is_interrupted() is True."""
    with GitTemporaryDirectory():
        io = InputOutput(yes=True)
        coder = Coder.create(self.GPT35, "diff", io=io, stream=True)

        # Mock the underlying LiteLLM completion stream.
        mock_completion_stream = MagicMock()
        # Simulate a few chunks before interruption.
        mock_completion_stream.__iter__.return_value = [
            MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk1"))]),
            MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk2"))]),
            # The third iteration will check io.is_interrupted() and raise
        ]

        # Mock the send method to return the mock stream.
        # NOTE(review): replacing coder.send wholesale means
        # show_send_output_stream — where the is_interrupted() check and
        # the AiderAbortException raise live — may never execute; confirm
        # this test actually exercises the intended interruption path.
        coder.send = MagicMock(return_value=mock_completion_stream)

        # Mock post-processing methods to ensure they are NOT called.
        mock_apply_updates.return_value = set()
        mock_run_shell.return_value = ""
        mock_check_mentions.return_value = None
        mock_reply_completed.return_value = False
        mock_show_usage.return_value = None

        # Run send_message (it's a generator, so iterate through it).
        # Expect AiderAbortException to be raised by show_send_output_stream
        # which is called by send, which is yielded from by send_message.
        # The exception should be caught within send_message.
        list(coder.send_message("Test streaming interruption"))

        # Assert that the interrupted flag was set within send_message's exception handler.
        # We can't directly check the local 'interrupted' flag in send_message,
        # but we can check that the io object's flag was used (which we mocked)
        # and that post-processing was skipped.
        self.assertTrue(mock_is_interrupted.called)  # Verify the mock was checked

        # Assert that the partial content was captured and added to history.
        # The content should include chunks received before the interrupt.
        self.assertIn("chunk1chunk2 (interrupted)", coder.cur_messages[-1]["content"])
        self.assertEqual(coder.cur_messages[-1]["role"], "assistant")

        # Assert that post-processing methods were NOT called.
        mock_apply_updates.assert_not_called()
        mock_run_shell.assert_not_called()
        mock_check_mentions.assert_not_called()
        mock_reply_completed.assert_not_called()
        # NOTE(review): test_send_message_handles_abort_exception expects
        # show_usage_report NOT to be called on interruption, while this
        # test expects one call — confirm the intended difference.
        mock_show_usage.assert_called_once()  # show_usage_report is called after the finally block, but before the interrupted check

        # Assert spinner was stopped.
        mock_stop_spinner.assert_called()
@patch("aider.coders.base_coder.Coder.apply_updates")
@patch("aider.coders.base_coder.Coder.run_shell_commands")
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
@patch("aider.coders.base_coder.Coder.reply_completed")
@patch("aider.coders.base_coder.Coder.show_usage_report")
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
@patch("aider.coders.base_coder.Coder.keyboard_interrupt")  # Mock the handler
def test_send_message_handles_keyboard_interrupt(
    self,
    # unittest.mock injects patches bottom-up: the last decorator is the
    # first mock parameter after self.
    mock_keyboard_interrupt,
    mock_stop_spinner,
    mock_show_usage,
    mock_reply_completed,
    mock_check_mentions,
    mock_run_shell,
    mock_apply_updates,
):
    """Test that send_message handles KeyboardInterrupt correctly."""
    with GitTemporaryDirectory():
        io = InputOutput(yes=True)
        coder = Coder.create(self.GPT35, "diff", io=io)

        # Mock the send method to raise KeyboardInterrupt mid-response.
        def mock_send(*args, **kwargs):
            # Simulate some partial content being received before Ctrl+C.
            coder.partial_response_content = "Partial content received."
            raise KeyboardInterrupt("Test interrupt")

        coder.send = MagicMock(side_effect=mock_send)

        # Mock post-processing methods to ensure they are NOT called.
        mock_apply_updates.return_value = set()
        mock_run_shell.return_value = ""
        mock_check_mentions.return_value = None
        mock_reply_completed.return_value = False
        mock_show_usage.return_value = None

        # Run send_message (it's a generator, so iterate through it).
        list(coder.send_message("Test message"))

        # The Ctrl+C handler must be invoked exactly once.
        mock_keyboard_interrupt.assert_called_once()

        # The partial content must be captured and marked as interrupted.
        self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
        self.assertEqual(coder.cur_messages[-1]["role"], "assistant")

        # All post-processing steps must be skipped after a Ctrl+C.
        mock_apply_updates.assert_not_called()
        mock_run_shell.assert_not_called()
        mock_check_mentions.assert_not_called()
        mock_reply_completed.assert_not_called()
        # NOTE(review): test_send_message_handles_abort_exception expects
        # show_usage_report NOT to be called on interruption, while this
        # test expects one call — confirm the intended difference.
        mock_show_usage.assert_called_once()  # show_usage_report is called after the finally block, but before the interrupted check

        # The waiting spinner must be stopped on the interrupt path.
        mock_stop_spinner.assert_called()
def test_add_assistant_reply_to_cur_messages_interrupted(self):
    """Interrupted replies are stored with an '(interrupted)' marker and no function call."""
    with GitTemporaryDirectory():
        io = InputOutput(yes=True)
        coder = Coder.create(self.GPT35, "diff", io=io)

        # Case 1: plain text reply interrupted mid-stream.
        coder.partial_response_content = "This is partial content."
        coder.partial_response_function_call = None

        coder.add_assistant_reply_to_cur_messages(interrupted=True)

        self.assertEqual(len(coder.cur_messages), 1)
        message = coder.cur_messages[-1]
        self.assertEqual(message["role"], "assistant")
        self.assertEqual(message["content"], "This is partial content. (interrupted)")
        self.assertNotIn("function_call", message)

        # Case 2: a pending function call must be dropped when interrupted.
        coder.cur_messages = []
        coder.partial_response_content = "This is partial content before func call."
        coder.partial_response_function_call = {"name": "test_func", "arguments": "{}"}

        coder.add_assistant_reply_to_cur_messages(interrupted=True)

        self.assertEqual(len(coder.cur_messages), 1)
        message = coder.cur_messages[-1]
        self.assertEqual(message["role"], "assistant")
        # Content still carries the interruption marker.
        self.assertEqual(
            message["content"],
            "This is partial content before func call. (interrupted)",
        )
        # The function call must NOT be recorded for an interrupted reply.
        self.assertNotIn("function_call", message)
def test_add_assistant_reply_to_cur_messages_not_interrupted(self):
    """Completed replies are stored verbatim, including any function call."""
    with GitTemporaryDirectory():
        io = InputOutput(yes=True)
        coder = Coder.create(self.GPT35, "diff", io=io)

        # Case 1: a normal, complete text reply.
        coder.partial_response_content = "This is the final content."
        coder.partial_response_function_call = None

        coder.add_assistant_reply_to_cur_messages(interrupted=False)

        self.assertEqual(len(coder.cur_messages), 1)
        message = coder.cur_messages[-1]
        self.assertEqual(message["role"], "assistant")
        self.assertEqual(message["content"], "This is the final content.")
        self.assertNotIn("function_call", message)

        # Case 2: a complete reply that also carries a function call.
        coder.cur_messages = []
        coder.partial_response_content = "This is content before func call."
        coder.partial_response_function_call = {"name": "test_func", "arguments": '{"arg": "val"}'}

        coder.add_assistant_reply_to_cur_messages(interrupted=False)

        self.assertEqual(len(coder.cur_messages), 1)
        message = coder.cur_messages[-1]
        self.assertEqual(message["role"], "assistant")
        self.assertEqual(message["content"], "This is content before func call.")
        # The function call SHOULD be preserved for a completed reply.
        self.assertIn("function_call", message)
        self.assertEqual(message["function_call"], {"name": "test_func", "arguments": '{"arg": "val"}'})
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_file.py`).
    unittest.main()

View file

@ -341,6 +341,11 @@ class TestInputOutput(unittest.TestCase):
self.assertEqual(mock_input.call_count, 2)
self.assertNotIn(("Do you want to proceed?", None), io.never_prompts)
def test_is_interrupted_default(self):
    """A plain InputOutput never reports an interrupt by default."""
    default_io = InputOutput()
    self.assertFalse(default_io.is_interrupted())
class TestInputOutputMultilineMode(unittest.TestCase):
def setUp(self):