mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-01 10:14:59 +00:00
feat(coder): Handle LLM response interruption
Introduce a mechanism to cleanly interrupt the LLM's streaming response, particularly useful for GUI or non-terminal interfaces. - Add `AiderAbortException` to signal interruption from the IO layer. - Implement `io.is_interrupted()` as a check within the streaming loop (`show_send_output_stream`). - Update `send_message` to catch `AiderAbortException` and `KeyboardInterrupt`. - Ensure partial response content is captured and marked as "(interrupted)" in the message history upon interruption. - Skip post-processing steps (applying edits, running tests/lint, etc.) if the response was interrupted. - Add tests for the new interruption handling logic.
This commit is contained in:
parent
3caab85931
commit
40b75bc57f
4 changed files with 352 additions and 45 deletions
|
@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
|||
import git
|
||||
|
||||
from aider.coders import Coder
|
||||
from aider.coders.base_coder import FinishReasonLength, UnknownEditFormat
|
||||
from aider.coders.base_coder import AiderAbortException, FinishReasonLength, UnknownEditFormat
|
||||
from aider.dump import dump # noqa: F401
|
||||
from aider.io import InputOutput
|
||||
from aider.models import Model
|
||||
|
@ -1433,6 +1433,257 @@ This command will print 'Hello, World!' to the console."""
|
|||
# (because user rejected the changes)
|
||||
mock_editor.run.assert_not_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
def test_send_message_handles_abort_exception(
|
||||
self,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that send_message handles AiderAbortException correctly."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Mock the send method to raise AiderAbortException
|
||||
def mock_send(*args, **kwargs):
|
||||
# Simulate some partial content before abort
|
||||
coder.partial_response_content = "Partial content received."
|
||||
raise AiderAbortException("Test abort")
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
list(coder.send_message("Test message"))
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_not_called() # show_usage_report should be skipped on interruption
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
@patch("aider.io.InputOutput.is_interrupted", return_value=True) # Simulate IO signal
|
||||
def test_streaming_interruption_raises_exception(
|
||||
self,
|
||||
mock_is_interrupted,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that streaming is interrupted when io.is_interrupted() is True."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io, stream=True)
|
||||
|
||||
# Mock the underlying LiteLLM completion stream
|
||||
mock_completion_stream = MagicMock()
|
||||
# Simulate a few chunks before interruption
|
||||
mock_completion_stream.__iter__.return_value = [
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk1"))]),
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk2"))]),
|
||||
# The third iteration will check io.is_interrupted() and raise
|
||||
]
|
||||
|
||||
# Mock the send method to return the mock stream
|
||||
coder.send = MagicMock(return_value=mock_completion_stream)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
# Expect AiderAbortException to be raised by show_send_output_stream
|
||||
# which is called by send, which is yielded from by send_message.
|
||||
# The exception should be caught within send_message.
|
||||
list(coder.send_message("Test streaming interruption"))
|
||||
|
||||
# Assert that the interrupted flag was set within send_message's exception handler
|
||||
# We can't directly check the local 'interrupted' flag in send_message,
|
||||
# but we can check that the io object's flag was used (which we mocked)
|
||||
# and that post-processing was skipped.
|
||||
self.assertTrue(mock_is_interrupted.called) # Verify the mock was checked
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
# The content should include chunks received before the interrupt
|
||||
self.assertIn("chunk1chunk2 (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
@patch("aider.coders.base_coder.Coder.keyboard_interrupt") # Mock the handler
|
||||
def test_send_message_handles_keyboard_interrupt(
|
||||
self,
|
||||
mock_keyboard_interrupt,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that send_message handles KeyboardInterrupt correctly."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Mock the send method to raise KeyboardInterrupt
|
||||
def mock_send(*args, **kwargs):
|
||||
# Simulate some partial content before interrupt
|
||||
coder.partial_response_content = "Partial content received."
|
||||
raise KeyboardInterrupt("Test interrupt")
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
list(coder.send_message("Test message"))
|
||||
|
||||
# Assert that keyboard_interrupt handler was called
|
||||
mock_keyboard_interrupt.assert_called_once()
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
def test_add_assistant_reply_to_cur_messages_interrupted(self):
|
||||
"""Test that add_assistant_reply_to_cur_messages marks content as interrupted."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Simulate partial content and interrupted state
|
||||
coder.partial_response_content = "This is partial content."
|
||||
coder.partial_response_function_call = None # No function call
|
||||
|
||||
# Add the reply, simulating interruption
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is partial content. (interrupted)")
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
# Simulate partial content with function call and interrupted state
|
||||
coder.cur_messages = [] # Reset messages
|
||||
coder.partial_response_content = "This is partial content before func call."
|
||||
coder.partial_response_function_call = {"name": "test_func", "arguments": "{}"}
|
||||
|
||||
# Add the reply, simulating interruption
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
# Content should be marked as interrupted
|
||||
self.assertEqual(last_msg["content"], "This is partial content before func call. (interrupted)")
|
||||
# Function call should NOT be added if interrupted
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
def test_add_assistant_reply_to_cur_messages_not_interrupted(self):
|
||||
"""Test that add_assistant_reply_to_cur_messages works normally when not interrupted."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Simulate complete content and no interruption
|
||||
coder.partial_response_content = "This is the final content."
|
||||
coder.partial_response_function_call = None # No function call
|
||||
|
||||
# Add the reply, not interrupted
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is the final content.")
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
# Simulate complete content with function call and no interruption
|
||||
coder.cur_messages = [] # Reset messages
|
||||
coder.partial_response_content = "This is content before func call."
|
||||
coder.partial_response_function_call = {"name": "test_func", "arguments": '{"arg": "val"}'}
|
||||
|
||||
# Add the reply, not interrupted
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is content before func call.")
|
||||
# Function call SHOULD be added if not interrupted
|
||||
self.assertIn("function_call", last_msg)
|
||||
self.assertEqual(last_msg["function_call"], {"name": "test_func", "arguments": '{"arg": "val"}'})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue