mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-28 16:25:00 +00:00
feat(coder): Handle LLM response interruption
Introduce a mechanism to cleanly interrupt the LLM's streaming response, particularly useful for GUI or non-terminal interfaces. - Add `AiderAbortException` to signal interruption from the IO layer. - Implement `io.is_interrupted()` as a check within the streaming loop (`show_send_output_stream`). - Update `send_message` to catch `AiderAbortException` and `KeyboardInterrupt`. - Ensure partial response content is captured and marked as "(interrupted)" in the message history upon interruption. - Skip post-processing steps (applying edits, running tests/lint, etc.) if the response was interrupted. - Add tests for the new interruption handling logic.
This commit is contained in:
parent
3caab85931
commit
40b75bc57f
4 changed files with 352 additions and 45 deletions
|
@ -69,6 +69,10 @@ class MissingAPIKeyError(ValueError):
|
|||
class FinishReasonLength(Exception):
|
||||
pass
|
||||
|
||||
# Add this new exception
|
||||
class AiderAbortException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def wrap_fence(name):
|
||||
return f"<{name}>", f"</{name}>"
|
||||
|
@ -864,6 +868,7 @@ class Coder:
|
|||
self.test_outcome = None
|
||||
self.shell_commands = []
|
||||
self.message_cost = 0
|
||||
self.partial_response_function_call = None # Initialize function call attribute
|
||||
|
||||
if self.repo:
|
||||
self.commit_before_message.append(self.repo.get_head_commit_sha())
|
||||
|
@ -1417,6 +1422,8 @@ class Coder:
|
|||
# Notify IO that LLM processing is starting
|
||||
self.io.llm_started()
|
||||
|
||||
self.partial_response_function_call = None # Initialize function call attribute here
|
||||
|
||||
self.cur_messages += [
|
||||
dict(role="user", content=inp),
|
||||
]
|
||||
|
@ -1447,10 +1454,11 @@ class Coder:
|
|||
|
||||
self.usage_report = None
|
||||
exhausted = False
|
||||
interrupted = False
|
||||
try:
|
||||
interrupted = False # Flag to track if AiderAbortException was caught
|
||||
try: # This try block handles retries and different exception types
|
||||
while True:
|
||||
try:
|
||||
# This is where AiderAbortException can be raised by show_send_output_stream
|
||||
yield from self.send(messages, functions=self.functions)
|
||||
break
|
||||
except litellm_ex.exceptions_tuple() as err:
|
||||
|
@ -1482,8 +1490,14 @@ class Coder:
|
|||
time.sleep(retry_delay)
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
interrupted = True
|
||||
break
|
||||
# This handles Ctrl+C from the terminal, which calls io.interrupt_input()
|
||||
interrupted = True # Mark as interrupted
|
||||
self.keyboard_interrupt() # Handle the Ctrl+C signal
|
||||
break # Exit the retry loop
|
||||
except AiderAbortException: # This is the exception raised by show_send_output_stream
|
||||
# This handles abort signals from the IO layer (e.g., Connector)
|
||||
interrupted = True # Mark as interrupted
|
||||
break # Exit the retry loop
|
||||
except FinishReasonLength:
|
||||
# We hit the output limit!
|
||||
if not self.main_model.info.get("supports_assistant_prefill"):
|
||||
|
@ -1506,6 +1520,7 @@ class Coder:
|
|||
self.event("message_send_exception", exception=str(err))
|
||||
return
|
||||
finally:
|
||||
# This block always runs, even if an exception occurred or the loop broke (e.g., due to interrupt)
|
||||
if self.mdstream:
|
||||
self.live_incremental_response(True)
|
||||
self.mdstream = None
|
||||
|
@ -1513,34 +1528,36 @@ class Coder:
|
|||
# Ensure any waiting spinner is stopped
|
||||
self._stop_waiting_spinner()
|
||||
|
||||
# Capture the final partial content regardless of how it ended
|
||||
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
||||
self.remove_reasoning_content()
|
||||
self.multi_response_content = ""
|
||||
|
||||
###
|
||||
# print()
|
||||
# print("=" * 20)
|
||||
# dump(self.partial_response_content)
|
||||
# Add the assistant's reply (even if partial/interrupted) to cur_messages
|
||||
self.add_assistant_reply_to_cur_messages(interrupted=interrupted) # Pass interrupted status
|
||||
|
||||
self.io.tool_output()
|
||||
# Check the interrupted flag *after* the finally block
|
||||
if interrupted:
|
||||
# If interrupted, we stop here and don't proceed with post-processing
|
||||
self.io.tool_output("Response interrupted.")
|
||||
return # Exit send_message
|
||||
|
||||
# If we reach here, the stream completed without interruption or context exhaustion
|
||||
|
||||
self.io.tool_output() # Add a newline after the assistant output
|
||||
|
||||
self.show_usage_report()
|
||||
|
||||
self.add_assistant_reply_to_cur_messages()
|
||||
# The assistant reply was already added to cur_messages in the finally block
|
||||
|
||||
if exhausted:
|
||||
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
||||
self.cur_messages += [
|
||||
dict(
|
||||
role="assistant",
|
||||
content="FinishReasonLength exception: you sent too many tokens",
|
||||
),
|
||||
]
|
||||
|
||||
# This case should now be handled before the main post-processing
|
||||
# but keeping the check here for safety/clarity.
|
||||
self.show_exhausted_error()
|
||||
self.num_exhausted_context_windows += 1
|
||||
return
|
||||
|
||||
# Proceed with post-processing only if not interrupted
|
||||
if self.partial_response_function_call:
|
||||
args = self.parse_partial_args()
|
||||
if args:
|
||||
|
@ -1552,31 +1569,35 @@ class Coder:
|
|||
else:
|
||||
content = ""
|
||||
|
||||
if not interrupted:
|
||||
add_rel_files_message = self.check_for_file_mentions(content)
|
||||
if add_rel_files_message:
|
||||
if self.reflected_message:
|
||||
self.reflected_message += "\n\n" + add_rel_files_message
|
||||
else:
|
||||
self.reflected_message = add_rel_files_message
|
||||
return
|
||||
|
||||
try:
|
||||
if self.reply_completed():
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
interrupted = True
|
||||
|
||||
if interrupted:
|
||||
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
||||
self.cur_messages[-1]["content"] += "\n^C KeyboardInterrupt"
|
||||
# Check for file mentions in the final content
|
||||
add_rel_files_message = self.check_for_file_mentions(content)
|
||||
if add_rel_files_message:
|
||||
# If files were mentioned and added, set reflected_message
|
||||
if self.reflected_message:
|
||||
self.reflected_message += "\n\n" + add_rel_files_message
|
||||
else:
|
||||
self.cur_messages += [dict(role="user", content="^C KeyboardInterrupt")]
|
||||
self.cur_messages += [
|
||||
dict(role="assistant", content="I see that you interrupted my previous reply.")
|
||||
]
|
||||
self.reflected_message = add_rel_files_message
|
||||
# Return here to trigger reflection in run_one
|
||||
return
|
||||
|
||||
try:
|
||||
# Check if the reply completed successfully (subclass specific)
|
||||
if self.reply_completed():
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
# Handle Ctrl+C during reply_completed (less likely but possible)
|
||||
interrupted = True
|
||||
self.keyboard_interrupt()
|
||||
# Fall through to the interrupted handling below
|
||||
|
||||
# If interrupted during reply_completed or file mention check
|
||||
if interrupted:
|
||||
# The message history already has the partial content + "(interrupted)"
|
||||
# No need to add another message here.
|
||||
self.io.tool_output("Response interrupted.")
|
||||
return # Exit send_message
|
||||
|
||||
# Apply edits only if not interrupted and reply completed successfully
|
||||
edited = self.apply_updates()
|
||||
|
||||
if edited:
|
||||
|
@ -1586,11 +1607,14 @@ class Coder:
|
|||
if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
|
||||
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
|
||||
|
||||
# Move messages back only after successful edits and commit
|
||||
self.move_back_cur_messages(saved_message)
|
||||
|
||||
# Check for reflections again after applying edits
|
||||
if self.reflected_message:
|
||||
return
|
||||
|
||||
# Run lint/test only if edits were applied and not interrupted
|
||||
if edited and self.auto_lint:
|
||||
lint_errors = self.lint_edited(edited)
|
||||
self.auto_commit(edited, context="Ran the linter")
|
||||
|
@ -1601,13 +1625,16 @@ class Coder:
|
|||
self.reflected_message = lint_errors
|
||||
return
|
||||
|
||||
# Run shell commands only if not interrupted
|
||||
shared_output = self.run_shell_commands()
|
||||
if shared_output:
|
||||
# Add shell command output to chat history
|
||||
self.cur_messages += [
|
||||
dict(role="user", content=shared_output),
|
||||
dict(role="assistant", content="Ok"),
|
||||
]
|
||||
|
||||
# Run tests only if edits were applied and not interrupted
|
||||
if edited and self.auto_test:
|
||||
test_errors = self.commands.cmd_test(self.test_cmd)
|
||||
self.test_outcome = not test_errors
|
||||
|
@ -1694,10 +1721,21 @@ class Coder:
|
|||
"""Cleanup when the Coder object is destroyed."""
|
||||
self.ok_to_warm_cache = False
|
||||
|
||||
def add_assistant_reply_to_cur_messages(self):
|
||||
if self.partial_response_content:
|
||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||
if self.partial_response_function_call:
|
||||
def add_assistant_reply_to_cur_messages(self, interrupted=False):
|
||||
# Add the partial or complete response content to the current messages
|
||||
# This happens in the finally block, so it includes interrupted content
|
||||
content_to_add = self.partial_response_content
|
||||
if interrupted and content_to_add: # Use the passed interrupted flag
|
||||
content_to_add += " (interrupted)" # Indicate interruption in history
|
||||
|
||||
if content_to_add:
|
||||
self.cur_messages += [dict(role="assistant", content=content_to_add)]
|
||||
# Only add function call if not interrupted, as it implies completion
|
||||
if self.partial_response_function_call and not interrupted:
|
||||
# Note: Function calls are not currently interrupted by the IO layer check
|
||||
# as they don't stream content chunk by chunk in the same way.
|
||||
# If function calls needed interruption, the LiteLLM call itself would need
|
||||
# to be cancellable, which is provider/LiteLLM dependent.
|
||||
self.cur_messages += [
|
||||
dict(
|
||||
role="assistant",
|
||||
|
@ -1706,6 +1744,7 @@ class Coder:
|
|||
)
|
||||
]
|
||||
|
||||
|
||||
def get_file_mentions(self, content, ignore_current=False):
|
||||
words = set(word for word in content.split())
|
||||
|
||||
|
@ -1896,6 +1935,11 @@ class Coder:
|
|||
received_content = False
|
||||
|
||||
for chunk in completion:
|
||||
# Check for interruption at the start of processing each chunk
|
||||
if self.io.is_interrupted():
|
||||
self._stop_waiting_spinner() # Stop spinner on abort
|
||||
raise AiderAbortException("Stream interrupted by user.")
|
||||
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
|
|
|
@ -514,6 +514,13 @@ class InputOutput:
|
|||
self.interrupted = True
|
||||
self.prompt_session.app.exit()
|
||||
|
||||
def is_interrupted(self):
|
||||
"""
|
||||
Check if the current operation should be interrupted.
|
||||
Implementations should override this based on their environment.
|
||||
"""
|
||||
return False # Default implementation
|
||||
|
||||
def get_input(
|
||||
self,
|
||||
root,
|
||||
|
|
|
@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
|||
import git
|
||||
|
||||
from aider.coders import Coder
|
||||
from aider.coders.base_coder import FinishReasonLength, UnknownEditFormat
|
||||
from aider.coders.base_coder import AiderAbortException, FinishReasonLength, UnknownEditFormat
|
||||
from aider.dump import dump # noqa: F401
|
||||
from aider.io import InputOutput
|
||||
from aider.models import Model
|
||||
|
@ -1433,6 +1433,257 @@ This command will print 'Hello, World!' to the console."""
|
|||
# (because user rejected the changes)
|
||||
mock_editor.run.assert_not_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
def test_send_message_handles_abort_exception(
|
||||
self,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that send_message handles AiderAbortException correctly."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Mock the send method to raise AiderAbortException
|
||||
def mock_send(*args, **kwargs):
|
||||
# Simulate some partial content before abort
|
||||
coder.partial_response_content = "Partial content received."
|
||||
raise AiderAbortException("Test abort")
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
list(coder.send_message("Test message"))
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_not_called() # show_usage_report should be skipped on interruption
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
@patch("aider.io.InputOutput.is_interrupted", return_value=True) # Simulate IO signal
|
||||
def test_streaming_interruption_raises_exception(
|
||||
self,
|
||||
mock_is_interrupted,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that streaming is interrupted when io.is_interrupted() is True."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io, stream=True)
|
||||
|
||||
# Mock the underlying LiteLLM completion stream
|
||||
mock_completion_stream = MagicMock()
|
||||
# Simulate a few chunks before interruption
|
||||
mock_completion_stream.__iter__.return_value = [
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk1"))]),
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk2"))]),
|
||||
# The third iteration will check io.is_interrupted() and raise
|
||||
]
|
||||
|
||||
# Mock the send method to return the mock stream
|
||||
coder.send = MagicMock(return_value=mock_completion_stream)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
# Expect AiderAbortException to be raised by show_send_output_stream
|
||||
# which is called by send, which is yielded from by send_message.
|
||||
# The exception should be caught within send_message.
|
||||
list(coder.send_message("Test streaming interruption"))
|
||||
|
||||
# Assert that the interrupted flag was set within send_message's exception handler
|
||||
# We can't directly check the local 'interrupted' flag in send_message,
|
||||
# but we can check that the io object's flag was used (which we mocked)
|
||||
# and that post-processing was skipped.
|
||||
self.assertTrue(mock_is_interrupted.called) # Verify the mock was checked
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
# The content should include chunks received before the interrupt
|
||||
self.assertIn("chunk1chunk2 (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||
@patch("aider.coders.base_coder.Coder.keyboard_interrupt") # Mock the handler
|
||||
def test_send_message_handles_keyboard_interrupt(
|
||||
self,
|
||||
mock_keyboard_interrupt,
|
||||
mock_stop_spinner,
|
||||
mock_show_usage,
|
||||
mock_reply_completed,
|
||||
mock_check_mentions,
|
||||
mock_run_shell,
|
||||
mock_apply_updates,
|
||||
):
|
||||
"""Test that send_message handles KeyboardInterrupt correctly."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Mock the send method to raise KeyboardInterrupt
|
||||
def mock_send(*args, **kwargs):
|
||||
# Simulate some partial content before interrupt
|
||||
coder.partial_response_content = "Partial content received."
|
||||
raise KeyboardInterrupt("Test interrupt")
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Mock post-processing methods to ensure they are NOT called
|
||||
mock_apply_updates.return_value = set()
|
||||
mock_run_shell.return_value = ""
|
||||
mock_check_mentions.return_value = None
|
||||
mock_reply_completed.return_value = False
|
||||
mock_show_usage.return_value = None
|
||||
|
||||
# Run send_message (it's a generator, so iterate through it)
|
||||
list(coder.send_message("Test message"))
|
||||
|
||||
# Assert that keyboard_interrupt handler was called
|
||||
mock_keyboard_interrupt.assert_called_once()
|
||||
|
||||
# Assert that the partial content was captured and added to history
|
||||
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||
|
||||
# Assert that post-processing methods were NOT called
|
||||
mock_apply_updates.assert_not_called()
|
||||
mock_run_shell.assert_not_called()
|
||||
mock_check_mentions.assert_not_called()
|
||||
mock_reply_completed.assert_not_called()
|
||||
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||
|
||||
# Assert spinner was stopped
|
||||
mock_stop_spinner.assert_called()
|
||||
|
||||
def test_add_assistant_reply_to_cur_messages_interrupted(self):
|
||||
"""Test that add_assistant_reply_to_cur_messages marks content as interrupted."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Simulate partial content and interrupted state
|
||||
coder.partial_response_content = "This is partial content."
|
||||
coder.partial_response_function_call = None # No function call
|
||||
|
||||
# Add the reply, simulating interruption
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is partial content. (interrupted)")
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
# Simulate partial content with function call and interrupted state
|
||||
coder.cur_messages = [] # Reset messages
|
||||
coder.partial_response_content = "This is partial content before func call."
|
||||
coder.partial_response_function_call = {"name": "test_func", "arguments": "{}"}
|
||||
|
||||
# Add the reply, simulating interruption
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
# Content should be marked as interrupted
|
||||
self.assertEqual(last_msg["content"], "This is partial content before func call. (interrupted)")
|
||||
# Function call should NOT be added if interrupted
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
def test_add_assistant_reply_to_cur_messages_not_interrupted(self):
|
||||
"""Test that add_assistant_reply_to_cur_messages works normally when not interrupted."""
|
||||
with GitTemporaryDirectory():
|
||||
io = InputOutput(yes=True)
|
||||
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||
|
||||
# Simulate complete content and no interruption
|
||||
coder.partial_response_content = "This is the final content."
|
||||
coder.partial_response_function_call = None # No function call
|
||||
|
||||
# Add the reply, not interrupted
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is the final content.")
|
||||
self.assertNotIn("function_call", last_msg)
|
||||
|
||||
# Simulate complete content with function call and no interruption
|
||||
coder.cur_messages = [] # Reset messages
|
||||
coder.partial_response_content = "This is content before func call."
|
||||
coder.partial_response_function_call = {"name": "test_func", "arguments": '{"arg": "val"}'}
|
||||
|
||||
# Add the reply, not interrupted
|
||||
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||
|
||||
# Check the last message in cur_messages
|
||||
self.assertEqual(len(coder.cur_messages), 1)
|
||||
last_msg = coder.cur_messages[-1]
|
||||
self.assertEqual(last_msg["role"], "assistant")
|
||||
self.assertEqual(last_msg["content"], "This is content before func call.")
|
||||
# Function call SHOULD be added if not interrupted
|
||||
self.assertIn("function_call", last_msg)
|
||||
self.assertEqual(last_msg["function_call"], {"name": "test_func", "arguments": '{"arg": "val"}'})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -341,6 +341,11 @@ class TestInputOutput(unittest.TestCase):
|
|||
self.assertEqual(mock_input.call_count, 2)
|
||||
self.assertNotIn(("Do you want to proceed?", None), io.never_prompts)
|
||||
|
||||
def test_is_interrupted_default(self):
|
||||
"""Test that is_interrupted returns False by default."""
|
||||
io = InputOutput()
|
||||
self.assertFalse(io.is_interrupted())
|
||||
|
||||
|
||||
class TestInputOutputMultilineMode(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue