mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-03 19:24:59 +00:00
Respect Aider confirmation settings
This commit is contained in:
parent
282b349080
commit
097620026f
1 changed files with 77 additions and 47 deletions
|
@ -1478,25 +1478,14 @@ class Coder:
|
|||
self.reflected_message = add_rel_files_message
|
||||
return
|
||||
|
||||
# Process any tools using MCP servers
|
||||
tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call)
|
||||
|
||||
if tool_call_response and self.num_tool_calls < self.max_tool_calls:
|
||||
if self.process_tool_calls(tool_call_response):
|
||||
self.num_tool_calls += 1
|
||||
tool_responses = self.execute_tool_calls(tool_call_response)
|
||||
|
||||
# Add the assistant message with tool calls
|
||||
# Converting to a dict so it can be safely dumped to json
|
||||
self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
|
||||
|
||||
# Add all tool responses
|
||||
for tool_response in tool_responses:
|
||||
self.cur_messages.append(tool_response)
|
||||
|
||||
return self.run(with_message="Continue", preproc=False)
|
||||
elif self.num_tool_calls >= self.max_tool_calls:
|
||||
self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
|
||||
|
||||
self.tool_call_limit = 0
|
||||
self.num_tool_calls = 0
|
||||
|
||||
try:
|
||||
if self.reply_completed():
|
||||
return
|
||||
|
@ -1553,39 +1542,80 @@ class Coder:
|
|||
self.reflected_message = test_errors
|
||||
return
|
||||
|
||||
def execute_tool_calls(self, tool_call_response):
|
||||
def process_tool_calls(self, tool_call_response):
|
||||
if tool_call_response is None:
|
||||
return False
|
||||
|
||||
tool_calls = tool_call_response.choices[0].message.tool_calls
|
||||
# Collect all tool calls grouped by server
|
||||
server_tool_calls = self._gather_server_tool_calls(tool_calls)
|
||||
|
||||
if server_tool_calls and self.num_tool_calls < self.max_tool_calls:
|
||||
self._print_tool_call_info(server_tool_calls)
|
||||
|
||||
if self.io.confirm_ask("Run tools?"):
|
||||
tool_responses = self._execute_tool_calls(server_tool_calls)
|
||||
|
||||
# Add the assistant message with tool calls
|
||||
# Converting to a dict so it can be safely dumped to json
|
||||
self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
|
||||
|
||||
# Add all tool responses
|
||||
for tool_response in tool_responses:
|
||||
self.cur_messages.append(tool_response)
|
||||
|
||||
return True
|
||||
elif self.num_tool_calls >= self.max_tool_calls:
|
||||
self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
|
||||
return False
|
||||
|
||||
def _print_tool_call_info(self, server_tool_calls):
|
||||
"""Print information about an MCP tool call."""
|
||||
|
||||
for server, tool_calls in server_tool_calls.items():
|
||||
for tool_call in tool_calls:
|
||||
self.io.tool_output(
|
||||
f"Running MCP tool: {tool_call.function.name} from server {server.name}"
|
||||
)
|
||||
self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
|
||||
|
||||
if self.verbose:
|
||||
self.io.tool_output(f"Tool ID: {tool_call.id}")
|
||||
self.io.tool_output(f"Tool type: {tool_call.type}")
|
||||
|
||||
self.io.tool_output("\n")
|
||||
|
||||
def _gather_server_tool_calls(self, tool_calls):
|
||||
"""Collect all tool calls grouped by server.
|
||||
Args:
|
||||
tool_calls: List of tool calls from the LLM response
|
||||
|
||||
Returns:
|
||||
dict: Dictionary mapping servers to their respective tool calls
|
||||
"""
|
||||
if not self.mcp_tools or len(self.mcp_tools) == 0:
|
||||
return None
|
||||
|
||||
server_tool_calls = {}
|
||||
for tool_call in tool_calls:
|
||||
# Check if this tool_call matches any MCP tool
|
||||
for server_name, server_tools in self.mcp_tools:
|
||||
for tool in server_tools:
|
||||
if tool.get("function", {}).get("name") == tool_call.function.name:
|
||||
# Find the McpServer instance that will be used for communication
|
||||
for server in self.mcp_servers:
|
||||
if server.name == server_name:
|
||||
if server not in server_tool_calls:
|
||||
server_tool_calls[server] = []
|
||||
server_tool_calls[server].append(tool_call)
|
||||
break
|
||||
|
||||
return server_tool_calls
|
||||
|
||||
def _execute_tool_calls(self, tool_calls):
|
||||
"""Process tool calls from the response and execute them if they match MCP tools.
|
||||
Returns a list of tool response messages."""
|
||||
tool_responses = []
|
||||
tool_calls = tool_call_response.choices[0].message.tool_calls
|
||||
|
||||
# First, collect all tool calls grouped by server
|
||||
server_tool_calls = {}
|
||||
|
||||
for tool_call in tool_calls:
|
||||
# Check if this tool_call matches any MCP tool
|
||||
if self.mcp_tools:
|
||||
for server_name, server_tools in self.mcp_tools:
|
||||
for tool in server_tools:
|
||||
if tool.get("function", {}).get("name") == tool_call.function.name:
|
||||
self.io.tool_output(
|
||||
f"Running MCP tool: {tool_call.function.name} from server"
|
||||
f" {server_name}"
|
||||
)
|
||||
self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
|
||||
|
||||
if self.verbose:
|
||||
self.io.tool_output(f"Tool ID: {tool_call.id}")
|
||||
self.io.tool_output(f"Tool type: {tool_call.type}")
|
||||
|
||||
self.io.tool_output("\n")
|
||||
# Find the corresponding server
|
||||
for server in self.mcp_servers:
|
||||
if server.name == server_name:
|
||||
if server not in server_tool_calls:
|
||||
server_tool_calls[server] = []
|
||||
server_tool_calls[server].append(tool_call)
|
||||
break
|
||||
|
||||
# Define the coroutine to execute all tool calls for a single server
|
||||
async def _exec_server_tools(server, tool_calls_list):
|
||||
|
@ -1610,14 +1640,14 @@ class Coder:
|
|||
# Execute all tool calls concurrently
|
||||
async def _execute_all_tool_calls():
|
||||
tasks = []
|
||||
for server, tool_calls_list in server_tool_calls.items():
|
||||
for server, tool_calls_list in tool_calls.items():
|
||||
tasks.append(_exec_server_tools(server, tool_calls_list))
|
||||
# Wait for all tasks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
|
||||
# Run the async execution and collect results
|
||||
if server_tool_calls:
|
||||
if tool_calls:
|
||||
all_results = asyncio.run(_execute_all_tool_calls())
|
||||
# Flatten the results from all servers
|
||||
for server_results in all_results:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue