Respect Aider confirmation settings

This commit is contained in:
Quinlan Jager 2025-05-05 22:59:31 -07:00
parent 282b349080
commit 097620026f

View file

@ -1478,25 +1478,14 @@ class Coder:
self.reflected_message = add_rel_files_message
return
# Process any tools using MCP servers
tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call)
if tool_call_response and self.num_tool_calls < self.max_tool_calls:
if self.process_tool_calls(tool_call_response):
self.num_tool_calls += 1
tool_responses = self.execute_tool_calls(tool_call_response)
# Add the assistant message with tool calls
# Converting to a dict so it can be safely dumped to json
self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
# Add all tool responses
for tool_response in tool_responses:
self.cur_messages.append(tool_response)
return self.run(with_message="Continue", preproc=False)
elif self.num_tool_calls >= self.max_tool_calls:
self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
self.tool_call_limit = 0
self.num_tool_calls = 0
try:
if self.reply_completed():
return
@ -1553,39 +1542,80 @@ class Coder:
self.reflected_message = test_errors
return
def execute_tool_calls(self, tool_call_response):
def process_tool_calls(self, tool_call_response):
if tool_call_response is None:
return False
tool_calls = tool_call_response.choices[0].message.tool_calls
# Collect all tool calls grouped by server
server_tool_calls = self._gather_server_tool_calls(tool_calls)
if server_tool_calls and self.num_tool_calls < self.max_tool_calls:
self._print_tool_call_info(server_tool_calls)
if self.io.confirm_ask("Run tools?"):
tool_responses = self._execute_tool_calls(server_tool_calls)
# Add the assistant message with tool calls
# Converting to a dict so it can be safely dumped to json
self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
# Add all tool responses
for tool_response in tool_responses:
self.cur_messages.append(tool_response)
return True
elif self.num_tool_calls >= self.max_tool_calls:
self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
return False
def _print_tool_call_info(self, server_tool_calls):
"""Print information about an MCP tool call."""
for server, tool_calls in server_tool_calls.items():
for tool_call in tool_calls:
self.io.tool_output(
f"Running MCP tool: {tool_call.function.name} from server {server.name}"
)
self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
if self.verbose:
self.io.tool_output(f"Tool ID: {tool_call.id}")
self.io.tool_output(f"Tool type: {tool_call.type}")
self.io.tool_output("\n")
def _gather_server_tool_calls(self, tool_calls):
"""Collect all tool calls grouped by server.
Args:
tool_calls: List of tool calls from the LLM response
Returns:
dict: Dictionary mapping servers to their respective tool calls
"""
if not self.mcp_tools or len(self.mcp_tools) == 0:
return None
server_tool_calls = {}
for tool_call in tool_calls:
# Check if this tool_call matches any MCP tool
for server_name, server_tools in self.mcp_tools:
for tool in server_tools:
if tool.get("function", {}).get("name") == tool_call.function.name:
# Find the McpServer instance that will be used for communication
for server in self.mcp_servers:
if server.name == server_name:
if server not in server_tool_calls:
server_tool_calls[server] = []
server_tool_calls[server].append(tool_call)
break
return server_tool_calls
def _execute_tool_calls(self, tool_calls):
"""Process tool calls from the response and execute them if they match MCP tools.
Returns a list of tool response messages."""
tool_responses = []
tool_calls = tool_call_response.choices[0].message.tool_calls
# First, collect all tool calls grouped by server
server_tool_calls = {}
for tool_call in tool_calls:
# Check if this tool_call matches any MCP tool
if self.mcp_tools:
for server_name, server_tools in self.mcp_tools:
for tool in server_tools:
if tool.get("function", {}).get("name") == tool_call.function.name:
self.io.tool_output(
f"Running MCP tool: {tool_call.function.name} from server"
f" {server_name}"
)
self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
if self.verbose:
self.io.tool_output(f"Tool ID: {tool_call.id}")
self.io.tool_output(f"Tool type: {tool_call.type}")
self.io.tool_output("\n")
# Find the corresponding server
for server in self.mcp_servers:
if server.name == server_name:
if server not in server_tool_calls:
server_tool_calls[server] = []
server_tool_calls[server].append(tool_call)
break
# Define the coroutine to execute all tool calls for a single server
async def _exec_server_tools(server, tool_calls_list):
@ -1610,14 +1640,14 @@ class Coder:
# Execute all tool calls concurrently
async def _execute_all_tool_calls():
tasks = []
for server, tool_calls_list in server_tool_calls.items():
for server, tool_calls_list in tool_calls.items():
tasks.append(_exec_server_tools(server, tool_calls_list))
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
return results
# Run the async execution and collect results
if server_tool_calls:
if tool_calls:
all_results = asyncio.run(_execute_all_tool_calls())
# Flatten the results from all servers
for server_results in all_results: