mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 16:54:59 +00:00
Merge 40b75bc57f
into 3caab85931
This commit is contained in:
commit
8fb5512e69
4 changed files with 352 additions and 45 deletions
|
@ -69,6 +69,10 @@ class MissingAPIKeyError(ValueError):
|
||||||
class FinishReasonLength(Exception):
|
class FinishReasonLength(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Add this new exception
|
||||||
|
class AiderAbortException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def wrap_fence(name):
|
def wrap_fence(name):
|
||||||
return f"<{name}>", f"</{name}>"
|
return f"<{name}>", f"</{name}>"
|
||||||
|
@ -864,6 +868,7 @@ class Coder:
|
||||||
self.test_outcome = None
|
self.test_outcome = None
|
||||||
self.shell_commands = []
|
self.shell_commands = []
|
||||||
self.message_cost = 0
|
self.message_cost = 0
|
||||||
|
self.partial_response_function_call = None # Initialize function call attribute
|
||||||
|
|
||||||
if self.repo:
|
if self.repo:
|
||||||
self.commit_before_message.append(self.repo.get_head_commit_sha())
|
self.commit_before_message.append(self.repo.get_head_commit_sha())
|
||||||
|
@ -1417,6 +1422,8 @@ class Coder:
|
||||||
# Notify IO that LLM processing is starting
|
# Notify IO that LLM processing is starting
|
||||||
self.io.llm_started()
|
self.io.llm_started()
|
||||||
|
|
||||||
|
self.partial_response_function_call = None # Initialize function call attribute here
|
||||||
|
|
||||||
self.cur_messages += [
|
self.cur_messages += [
|
||||||
dict(role="user", content=inp),
|
dict(role="user", content=inp),
|
||||||
]
|
]
|
||||||
|
@ -1447,10 +1454,11 @@ class Coder:
|
||||||
|
|
||||||
self.usage_report = None
|
self.usage_report = None
|
||||||
exhausted = False
|
exhausted = False
|
||||||
interrupted = False
|
interrupted = False # Flag to track if AiderAbortException was caught
|
||||||
try:
|
try: # This try block handles retries and different exception types
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
# This is where AiderAbortException can be raised by show_send_output_stream
|
||||||
yield from self.send(messages, functions=self.functions)
|
yield from self.send(messages, functions=self.functions)
|
||||||
break
|
break
|
||||||
except litellm_ex.exceptions_tuple() as err:
|
except litellm_ex.exceptions_tuple() as err:
|
||||||
|
@ -1482,8 +1490,14 @@ class Coder:
|
||||||
time.sleep(retry_delay)
|
time.sleep(retry_delay)
|
||||||
continue
|
continue
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
interrupted = True
|
# This handles Ctrl+C from the terminal, which calls io.interrupt_input()
|
||||||
break
|
interrupted = True # Mark as interrupted
|
||||||
|
self.keyboard_interrupt() # Handle the Ctrl+C signal
|
||||||
|
break # Exit the retry loop
|
||||||
|
except AiderAbortException: # This is the exception raised by show_send_output_stream
|
||||||
|
# This handles abort signals from the IO layer (e.g., Connector)
|
||||||
|
interrupted = True # Mark as interrupted
|
||||||
|
break # Exit the retry loop
|
||||||
except FinishReasonLength:
|
except FinishReasonLength:
|
||||||
# We hit the output limit!
|
# We hit the output limit!
|
||||||
if not self.main_model.info.get("supports_assistant_prefill"):
|
if not self.main_model.info.get("supports_assistant_prefill"):
|
||||||
|
@ -1506,6 +1520,7 @@ class Coder:
|
||||||
self.event("message_send_exception", exception=str(err))
|
self.event("message_send_exception", exception=str(err))
|
||||||
return
|
return
|
||||||
finally:
|
finally:
|
||||||
|
# This block always runs, even if an exception occurred or the loop broke (e.g., due to interrupt)
|
||||||
if self.mdstream:
|
if self.mdstream:
|
||||||
self.live_incremental_response(True)
|
self.live_incremental_response(True)
|
||||||
self.mdstream = None
|
self.mdstream = None
|
||||||
|
@ -1513,34 +1528,36 @@ class Coder:
|
||||||
# Ensure any waiting spinner is stopped
|
# Ensure any waiting spinner is stopped
|
||||||
self._stop_waiting_spinner()
|
self._stop_waiting_spinner()
|
||||||
|
|
||||||
|
# Capture the final partial content regardless of how it ended
|
||||||
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
||||||
self.remove_reasoning_content()
|
self.remove_reasoning_content()
|
||||||
self.multi_response_content = ""
|
self.multi_response_content = ""
|
||||||
|
|
||||||
###
|
# Add the assistant's reply (even if partial/interrupted) to cur_messages
|
||||||
# print()
|
self.add_assistant_reply_to_cur_messages(interrupted=interrupted) # Pass interrupted status
|
||||||
# print("=" * 20)
|
|
||||||
# dump(self.partial_response_content)
|
|
||||||
|
|
||||||
self.io.tool_output()
|
# Check the interrupted flag *after* the finally block
|
||||||
|
if interrupted:
|
||||||
|
# If interrupted, we stop here and don't proceed with post-processing
|
||||||
|
self.io.tool_output("Response interrupted.")
|
||||||
|
return # Exit send_message
|
||||||
|
|
||||||
|
# If we reach here, the stream completed without interruption or context exhaustion
|
||||||
|
|
||||||
|
self.io.tool_output() # Add a newline after the assistant output
|
||||||
|
|
||||||
self.show_usage_report()
|
self.show_usage_report()
|
||||||
|
|
||||||
self.add_assistant_reply_to_cur_messages()
|
# The assistant reply was already added to cur_messages in the finally block
|
||||||
|
|
||||||
if exhausted:
|
if exhausted:
|
||||||
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
# This case should now be handled before the main post-processing
|
||||||
self.cur_messages += [
|
# but keeping the check here for safety/clarity.
|
||||||
dict(
|
|
||||||
role="assistant",
|
|
||||||
content="FinishReasonLength exception: you sent too many tokens",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
self.show_exhausted_error()
|
self.show_exhausted_error()
|
||||||
self.num_exhausted_context_windows += 1
|
self.num_exhausted_context_windows += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Proceed with post-processing only if not interrupted
|
||||||
if self.partial_response_function_call:
|
if self.partial_response_function_call:
|
||||||
args = self.parse_partial_args()
|
args = self.parse_partial_args()
|
||||||
if args:
|
if args:
|
||||||
|
@ -1552,31 +1569,35 @@ class Coder:
|
||||||
else:
|
else:
|
||||||
content = ""
|
content = ""
|
||||||
|
|
||||||
if not interrupted:
|
# Check for file mentions in the final content
|
||||||
add_rel_files_message = self.check_for_file_mentions(content)
|
add_rel_files_message = self.check_for_file_mentions(content)
|
||||||
if add_rel_files_message:
|
if add_rel_files_message:
|
||||||
if self.reflected_message:
|
# If files were mentioned and added, set reflected_message
|
||||||
self.reflected_message += "\n\n" + add_rel_files_message
|
if self.reflected_message:
|
||||||
else:
|
self.reflected_message += "\n\n" + add_rel_files_message
|
||||||
self.reflected_message = add_rel_files_message
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self.reply_completed():
|
|
||||||
return
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
interrupted = True
|
|
||||||
|
|
||||||
if interrupted:
|
|
||||||
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
|
||||||
self.cur_messages[-1]["content"] += "\n^C KeyboardInterrupt"
|
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="user", content="^C KeyboardInterrupt")]
|
self.reflected_message = add_rel_files_message
|
||||||
self.cur_messages += [
|
# Return here to trigger reflection in run_one
|
||||||
dict(role="assistant", content="I see that you interrupted my previous reply.")
|
|
||||||
]
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if the reply completed successfully (subclass specific)
|
||||||
|
if self.reply_completed():
|
||||||
|
return
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# Handle Ctrl+C during reply_completed (less likely but possible)
|
||||||
|
interrupted = True
|
||||||
|
self.keyboard_interrupt()
|
||||||
|
# Fall through to the interrupted handling below
|
||||||
|
|
||||||
|
# If interrupted during reply_completed or file mention check
|
||||||
|
if interrupted:
|
||||||
|
# The message history already has the partial content + "(interrupted)"
|
||||||
|
# No need to add another message here.
|
||||||
|
self.io.tool_output("Response interrupted.")
|
||||||
|
return # Exit send_message
|
||||||
|
|
||||||
|
# Apply edits only if not interrupted and reply completed successfully
|
||||||
edited = self.apply_updates()
|
edited = self.apply_updates()
|
||||||
|
|
||||||
if edited:
|
if edited:
|
||||||
|
@ -1586,11 +1607,14 @@ class Coder:
|
||||||
if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
|
if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
|
||||||
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
|
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
|
||||||
|
|
||||||
|
# Move messages back only after successful edits and commit
|
||||||
self.move_back_cur_messages(saved_message)
|
self.move_back_cur_messages(saved_message)
|
||||||
|
|
||||||
|
# Check for reflections again after applying edits
|
||||||
if self.reflected_message:
|
if self.reflected_message:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Run lint/test only if edits were applied and not interrupted
|
||||||
if edited and self.auto_lint:
|
if edited and self.auto_lint:
|
||||||
lint_errors = self.lint_edited(edited)
|
lint_errors = self.lint_edited(edited)
|
||||||
self.auto_commit(edited, context="Ran the linter")
|
self.auto_commit(edited, context="Ran the linter")
|
||||||
|
@ -1601,13 +1625,16 @@ class Coder:
|
||||||
self.reflected_message = lint_errors
|
self.reflected_message = lint_errors
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Run shell commands only if not interrupted
|
||||||
shared_output = self.run_shell_commands()
|
shared_output = self.run_shell_commands()
|
||||||
if shared_output:
|
if shared_output:
|
||||||
|
# Add shell command output to chat history
|
||||||
self.cur_messages += [
|
self.cur_messages += [
|
||||||
dict(role="user", content=shared_output),
|
dict(role="user", content=shared_output),
|
||||||
dict(role="assistant", content="Ok"),
|
dict(role="assistant", content="Ok"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Run tests only if edits were applied and not interrupted
|
||||||
if edited and self.auto_test:
|
if edited and self.auto_test:
|
||||||
test_errors = self.commands.cmd_test(self.test_cmd)
|
test_errors = self.commands.cmd_test(self.test_cmd)
|
||||||
self.test_outcome = not test_errors
|
self.test_outcome = not test_errors
|
||||||
|
@ -1694,10 +1721,21 @@ class Coder:
|
||||||
"""Cleanup when the Coder object is destroyed."""
|
"""Cleanup when the Coder object is destroyed."""
|
||||||
self.ok_to_warm_cache = False
|
self.ok_to_warm_cache = False
|
||||||
|
|
||||||
def add_assistant_reply_to_cur_messages(self):
|
def add_assistant_reply_to_cur_messages(self, interrupted=False):
|
||||||
if self.partial_response_content:
|
# Add the partial or complete response content to the current messages
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
# This happens in the finally block, so it includes interrupted content
|
||||||
if self.partial_response_function_call:
|
content_to_add = self.partial_response_content
|
||||||
|
if interrupted and content_to_add: # Use the passed interrupted flag
|
||||||
|
content_to_add += " (interrupted)" # Indicate interruption in history
|
||||||
|
|
||||||
|
if content_to_add:
|
||||||
|
self.cur_messages += [dict(role="assistant", content=content_to_add)]
|
||||||
|
# Only add function call if not interrupted, as it implies completion
|
||||||
|
if self.partial_response_function_call and not interrupted:
|
||||||
|
# Note: Function calls are not currently interrupted by the IO layer check
|
||||||
|
# as they don't stream content chunk by chunk in the same way.
|
||||||
|
# If function calls needed interruption, the LiteLLM call itself would need
|
||||||
|
# to be cancellable, which is provider/LiteLLM dependent.
|
||||||
self.cur_messages += [
|
self.cur_messages += [
|
||||||
dict(
|
dict(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
@ -1706,6 +1744,7 @@ class Coder:
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_file_mentions(self, content, ignore_current=False):
|
def get_file_mentions(self, content, ignore_current=False):
|
||||||
words = set(word for word in content.split())
|
words = set(word for word in content.split())
|
||||||
|
|
||||||
|
@ -1896,6 +1935,11 @@ class Coder:
|
||||||
received_content = False
|
received_content = False
|
||||||
|
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
|
# Check for interruption at the start of processing each chunk
|
||||||
|
if self.io.is_interrupted():
|
||||||
|
self._stop_waiting_spinner() # Stop spinner on abort
|
||||||
|
raise AiderAbortException("Stream interrupted by user.")
|
||||||
|
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -514,6 +514,13 @@ class InputOutput:
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
self.prompt_session.app.exit()
|
self.prompt_session.app.exit()
|
||||||
|
|
||||||
|
def is_interrupted(self):
|
||||||
|
"""
|
||||||
|
Check if the current operation should be interrupted.
|
||||||
|
Implementations should override this based on their environment.
|
||||||
|
"""
|
||||||
|
return False # Default implementation
|
||||||
|
|
||||||
def get_input(
|
def get_input(
|
||||||
self,
|
self,
|
||||||
root,
|
root,
|
||||||
|
|
|
@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from aider.coders import Coder
|
from aider.coders import Coder
|
||||||
from aider.coders.base_coder import FinishReasonLength, UnknownEditFormat
|
from aider.coders.base_coder import AiderAbortException, FinishReasonLength, UnknownEditFormat
|
||||||
from aider.dump import dump # noqa: F401
|
from aider.dump import dump # noqa: F401
|
||||||
from aider.io import InputOutput
|
from aider.io import InputOutput
|
||||||
from aider.models import Model
|
from aider.models import Model
|
||||||
|
@ -1433,6 +1433,257 @@ This command will print 'Hello, World!' to the console."""
|
||||||
# (because user rejected the changes)
|
# (because user rejected the changes)
|
||||||
mock_editor.run.assert_not_called()
|
mock_editor.run.assert_not_called()
|
||||||
|
|
||||||
|
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||||
|
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||||
|
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||||
|
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||||
|
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||||
|
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||||
|
def test_send_message_handles_abort_exception(
|
||||||
|
self,
|
||||||
|
mock_stop_spinner,
|
||||||
|
mock_show_usage,
|
||||||
|
mock_reply_completed,
|
||||||
|
mock_check_mentions,
|
||||||
|
mock_run_shell,
|
||||||
|
mock_apply_updates,
|
||||||
|
):
|
||||||
|
"""Test that send_message handles AiderAbortException correctly."""
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
io = InputOutput(yes=True)
|
||||||
|
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||||
|
|
||||||
|
# Mock the send method to raise AiderAbortException
|
||||||
|
def mock_send(*args, **kwargs):
|
||||||
|
# Simulate some partial content before abort
|
||||||
|
coder.partial_response_content = "Partial content received."
|
||||||
|
raise AiderAbortException("Test abort")
|
||||||
|
|
||||||
|
coder.send = MagicMock(side_effect=mock_send)
|
||||||
|
|
||||||
|
# Mock post-processing methods to ensure they are NOT called
|
||||||
|
mock_apply_updates.return_value = set()
|
||||||
|
mock_run_shell.return_value = ""
|
||||||
|
mock_check_mentions.return_value = None
|
||||||
|
mock_reply_completed.return_value = False
|
||||||
|
mock_show_usage.return_value = None
|
||||||
|
|
||||||
|
# Run send_message (it's a generator, so iterate through it)
|
||||||
|
list(coder.send_message("Test message"))
|
||||||
|
|
||||||
|
# Assert that the partial content was captured and added to history
|
||||||
|
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||||
|
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||||
|
|
||||||
|
# Assert that post-processing methods were NOT called
|
||||||
|
mock_apply_updates.assert_not_called()
|
||||||
|
mock_run_shell.assert_not_called()
|
||||||
|
mock_check_mentions.assert_not_called()
|
||||||
|
mock_reply_completed.assert_not_called()
|
||||||
|
mock_show_usage.assert_not_called() # show_usage_report should be skipped on interruption
|
||||||
|
|
||||||
|
# Assert spinner was stopped
|
||||||
|
mock_stop_spinner.assert_called()
|
||||||
|
|
||||||
|
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||||
|
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||||
|
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||||
|
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||||
|
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||||
|
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||||
|
@patch("aider.io.InputOutput.is_interrupted", return_value=True) # Simulate IO signal
|
||||||
|
def test_streaming_interruption_raises_exception(
|
||||||
|
self,
|
||||||
|
mock_is_interrupted,
|
||||||
|
mock_stop_spinner,
|
||||||
|
mock_show_usage,
|
||||||
|
mock_reply_completed,
|
||||||
|
mock_check_mentions,
|
||||||
|
mock_run_shell,
|
||||||
|
mock_apply_updates,
|
||||||
|
):
|
||||||
|
"""Test that streaming is interrupted when io.is_interrupted() is True."""
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
io = InputOutput(yes=True)
|
||||||
|
coder = Coder.create(self.GPT35, "diff", io=io, stream=True)
|
||||||
|
|
||||||
|
# Mock the underlying LiteLLM completion stream
|
||||||
|
mock_completion_stream = MagicMock()
|
||||||
|
# Simulate a few chunks before interruption
|
||||||
|
mock_completion_stream.__iter__.return_value = [
|
||||||
|
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk1"))]),
|
||||||
|
MagicMock(choices=[MagicMock(delta=MagicMock(content="chunk2"))]),
|
||||||
|
# The third iteration will check io.is_interrupted() and raise
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock the send method to return the mock stream
|
||||||
|
coder.send = MagicMock(return_value=mock_completion_stream)
|
||||||
|
|
||||||
|
# Mock post-processing methods to ensure they are NOT called
|
||||||
|
mock_apply_updates.return_value = set()
|
||||||
|
mock_run_shell.return_value = ""
|
||||||
|
mock_check_mentions.return_value = None
|
||||||
|
mock_reply_completed.return_value = False
|
||||||
|
mock_show_usage.return_value = None
|
||||||
|
|
||||||
|
# Run send_message (it's a generator, so iterate through it)
|
||||||
|
# Expect AiderAbortException to be raised by show_send_output_stream
|
||||||
|
# which is called by send, which is yielded from by send_message.
|
||||||
|
# The exception should be caught within send_message.
|
||||||
|
list(coder.send_message("Test streaming interruption"))
|
||||||
|
|
||||||
|
# Assert that the interrupted flag was set within send_message's exception handler
|
||||||
|
# We can't directly check the local 'interrupted' flag in send_message,
|
||||||
|
# but we can check that the io object's flag was used (which we mocked)
|
||||||
|
# and that post-processing was skipped.
|
||||||
|
self.assertTrue(mock_is_interrupted.called) # Verify the mock was checked
|
||||||
|
|
||||||
|
# Assert that the partial content was captured and added to history
|
||||||
|
# The content should include chunks received before the interrupt
|
||||||
|
self.assertIn("chunk1chunk2 (interrupted)", coder.cur_messages[-1]["content"])
|
||||||
|
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||||
|
|
||||||
|
# Assert that post-processing methods were NOT called
|
||||||
|
mock_apply_updates.assert_not_called()
|
||||||
|
mock_run_shell.assert_not_called()
|
||||||
|
mock_check_mentions.assert_not_called()
|
||||||
|
mock_reply_completed.assert_not_called()
|
||||||
|
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||||
|
|
||||||
|
# Assert spinner was stopped
|
||||||
|
mock_stop_spinner.assert_called()
|
||||||
|
|
||||||
|
@patch("aider.coders.base_coder.Coder.apply_updates")
|
||||||
|
@patch("aider.coders.base_coder.Coder.run_shell_commands")
|
||||||
|
@patch("aider.coders.base_coder.Coder.check_for_file_mentions")
|
||||||
|
@patch("aider.coders.base_coder.Coder.reply_completed")
|
||||||
|
@patch("aider.coders.base_coder.Coder.show_usage_report")
|
||||||
|
@patch("aider.coders.base_coder.Coder._stop_waiting_spinner")
|
||||||
|
@patch("aider.coders.base_coder.Coder.keyboard_interrupt") # Mock the handler
|
||||||
|
def test_send_message_handles_keyboard_interrupt(
|
||||||
|
self,
|
||||||
|
mock_keyboard_interrupt,
|
||||||
|
mock_stop_spinner,
|
||||||
|
mock_show_usage,
|
||||||
|
mock_reply_completed,
|
||||||
|
mock_check_mentions,
|
||||||
|
mock_run_shell,
|
||||||
|
mock_apply_updates,
|
||||||
|
):
|
||||||
|
"""Test that send_message handles KeyboardInterrupt correctly."""
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
io = InputOutput(yes=True)
|
||||||
|
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||||
|
|
||||||
|
# Mock the send method to raise KeyboardInterrupt
|
||||||
|
def mock_send(*args, **kwargs):
|
||||||
|
# Simulate some partial content before interrupt
|
||||||
|
coder.partial_response_content = "Partial content received."
|
||||||
|
raise KeyboardInterrupt("Test interrupt")
|
||||||
|
|
||||||
|
coder.send = MagicMock(side_effect=mock_send)
|
||||||
|
|
||||||
|
# Mock post-processing methods to ensure they are NOT called
|
||||||
|
mock_apply_updates.return_value = set()
|
||||||
|
mock_run_shell.return_value = ""
|
||||||
|
mock_check_mentions.return_value = None
|
||||||
|
mock_reply_completed.return_value = False
|
||||||
|
mock_show_usage.return_value = None
|
||||||
|
|
||||||
|
# Run send_message (it's a generator, so iterate through it)
|
||||||
|
list(coder.send_message("Test message"))
|
||||||
|
|
||||||
|
# Assert that keyboard_interrupt handler was called
|
||||||
|
mock_keyboard_interrupt.assert_called_once()
|
||||||
|
|
||||||
|
# Assert that the partial content was captured and added to history
|
||||||
|
self.assertIn("Partial content received. (interrupted)", coder.cur_messages[-1]["content"])
|
||||||
|
self.assertEqual(coder.cur_messages[-1]["role"], "assistant")
|
||||||
|
|
||||||
|
# Assert that post-processing methods were NOT called
|
||||||
|
mock_apply_updates.assert_not_called()
|
||||||
|
mock_run_shell.assert_not_called()
|
||||||
|
mock_check_mentions.assert_not_called()
|
||||||
|
mock_reply_completed.assert_not_called()
|
||||||
|
mock_show_usage.assert_called_once() # show_usage_report is called after the finally block, but before the interrupted check
|
||||||
|
|
||||||
|
# Assert spinner was stopped
|
||||||
|
mock_stop_spinner.assert_called()
|
||||||
|
|
||||||
|
def test_add_assistant_reply_to_cur_messages_interrupted(self):
|
||||||
|
"""Test that add_assistant_reply_to_cur_messages marks content as interrupted."""
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
io = InputOutput(yes=True)
|
||||||
|
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||||
|
|
||||||
|
# Simulate partial content and interrupted state
|
||||||
|
coder.partial_response_content = "This is partial content."
|
||||||
|
coder.partial_response_function_call = None # No function call
|
||||||
|
|
||||||
|
# Add the reply, simulating interruption
|
||||||
|
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||||
|
|
||||||
|
# Check the last message in cur_messages
|
||||||
|
self.assertEqual(len(coder.cur_messages), 1)
|
||||||
|
last_msg = coder.cur_messages[-1]
|
||||||
|
self.assertEqual(last_msg["role"], "assistant")
|
||||||
|
self.assertEqual(last_msg["content"], "This is partial content. (interrupted)")
|
||||||
|
self.assertNotIn("function_call", last_msg)
|
||||||
|
|
||||||
|
# Simulate partial content with function call and interrupted state
|
||||||
|
coder.cur_messages = [] # Reset messages
|
||||||
|
coder.partial_response_content = "This is partial content before func call."
|
||||||
|
coder.partial_response_function_call = {"name": "test_func", "arguments": "{}"}
|
||||||
|
|
||||||
|
# Add the reply, simulating interruption
|
||||||
|
coder.add_assistant_reply_to_cur_messages(interrupted=True)
|
||||||
|
|
||||||
|
# Check the last message in cur_messages
|
||||||
|
self.assertEqual(len(coder.cur_messages), 1)
|
||||||
|
last_msg = coder.cur_messages[-1]
|
||||||
|
self.assertEqual(last_msg["role"], "assistant")
|
||||||
|
# Content should be marked as interrupted
|
||||||
|
self.assertEqual(last_msg["content"], "This is partial content before func call. (interrupted)")
|
||||||
|
# Function call should NOT be added if interrupted
|
||||||
|
self.assertNotIn("function_call", last_msg)
|
||||||
|
|
||||||
|
def test_add_assistant_reply_to_cur_messages_not_interrupted(self):
|
||||||
|
"""Test that add_assistant_reply_to_cur_messages works normally when not interrupted."""
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
io = InputOutput(yes=True)
|
||||||
|
coder = Coder.create(self.GPT35, "diff", io=io)
|
||||||
|
|
||||||
|
# Simulate complete content and no interruption
|
||||||
|
coder.partial_response_content = "This is the final content."
|
||||||
|
coder.partial_response_function_call = None # No function call
|
||||||
|
|
||||||
|
# Add the reply, not interrupted
|
||||||
|
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||||
|
|
||||||
|
# Check the last message in cur_messages
|
||||||
|
self.assertEqual(len(coder.cur_messages), 1)
|
||||||
|
last_msg = coder.cur_messages[-1]
|
||||||
|
self.assertEqual(last_msg["role"], "assistant")
|
||||||
|
self.assertEqual(last_msg["content"], "This is the final content.")
|
||||||
|
self.assertNotIn("function_call", last_msg)
|
||||||
|
|
||||||
|
# Simulate complete content with function call and no interruption
|
||||||
|
coder.cur_messages = [] # Reset messages
|
||||||
|
coder.partial_response_content = "This is content before func call."
|
||||||
|
coder.partial_response_function_call = {"name": "test_func", "arguments": '{"arg": "val"}'}
|
||||||
|
|
||||||
|
# Add the reply, not interrupted
|
||||||
|
coder.add_assistant_reply_to_cur_messages(interrupted=False)
|
||||||
|
|
||||||
|
# Check the last message in cur_messages
|
||||||
|
self.assertEqual(len(coder.cur_messages), 1)
|
||||||
|
last_msg = coder.cur_messages[-1]
|
||||||
|
self.assertEqual(last_msg["role"], "assistant")
|
||||||
|
self.assertEqual(last_msg["content"], "This is content before func call.")
|
||||||
|
# Function call SHOULD be added if not interrupted
|
||||||
|
self.assertIn("function_call", last_msg)
|
||||||
|
self.assertEqual(last_msg["function_call"], {"name": "test_func", "arguments": '{"arg": "val"}'})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -341,6 +341,11 @@ class TestInputOutput(unittest.TestCase):
|
||||||
self.assertEqual(mock_input.call_count, 2)
|
self.assertEqual(mock_input.call_count, 2)
|
||||||
self.assertNotIn(("Do you want to proceed?", None), io.never_prompts)
|
self.assertNotIn(("Do you want to proceed?", None), io.never_prompts)
|
||||||
|
|
||||||
|
def test_is_interrupted_default(self):
|
||||||
|
"""Test that is_interrupted returns False by default."""
|
||||||
|
io = InputOutput()
|
||||||
|
self.assertFalse(io.is_interrupted())
|
||||||
|
|
||||||
|
|
||||||
class TestInputOutputMultilineMode(unittest.TestCase):
|
class TestInputOutputMultilineMode(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue