mirror of https://github.com/Aider-AI/aider.git
synced 2025-06-05 12:14:59 +00:00
Respect Aider confirmation settings
parent 282b349080
commit 097620026f
1 changed file with 77 additions and 47 deletions
@@ -1478,25 +1478,14 @@ class Coder:
                 self.reflected_message = add_rel_files_message
             return

+        # Process any tools using MCP servers
         tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call)
-        if tool_call_response and self.num_tool_calls < self.max_tool_calls:
+        if self.process_tool_calls(tool_call_response):
             self.num_tool_calls += 1
-            tool_responses = self.execute_tool_calls(tool_call_response)
-
-            # Add the assistant message with tool calls
-            # Converting to a dict so it can be safely dumped to json
-            self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
-
-            # Add all tool responses
-            for tool_response in tool_responses:
-                self.cur_messages.append(tool_response)
-
             return self.run(with_message="Continue", preproc=False)
-        elif self.num_tool_calls >= self.max_tool_calls:
-            self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")

-        self.tool_call_limit = 0
+        self.num_tool_calls = 0

         try:
             if self.reply_completed():
                 return
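The hunk above moves the tool-call bookkeeping out of send_message: process_tool_calls() now owns the limit check and the confirmation prompt and returns a boolean, so the caller only bumps the counter and re-enters run() when tools actually ran. It also fixes the counter reset: the old code cleared self.tool_call_limit, while the checks here read self.num_tool_calls. A minimal, self-contained sketch of the resulting control flow; fake_send and the inline confirm lambda are stand-ins, not aider code:

# Illustrative only: mirrors the shape of the new send_message tail.
def fake_send(turn):
    # Pretend the model requests a tool on the first turn only.
    return {"tool_calls": ["list_files"] if turn == 0 else []}

def process_tool_calls(response, confirm, num_calls, max_calls):
    # The new gating: nothing to run, over the limit, or user declined -> False.
    if not response["tool_calls"] or num_calls >= max_calls:
        return False
    return confirm("Run tools?")

def run(turn=0, num_calls=0, max_calls=25):
    response = fake_send(turn)
    if process_tool_calls(response, lambda q: True, num_calls, max_calls):
        # Mirrors self.num_tool_calls += 1 and run(with_message="Continue")
        return run(turn + 1, num_calls + 1, max_calls)
    return response

print(run())  # one confirmed tool round, then a plain reply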
@@ -1553,39 +1542,80 @@ class Coder:
                 self.reflected_message = test_errors
                 return

-    def execute_tool_calls(self, tool_call_response):
+    def process_tool_calls(self, tool_call_response):
+        if tool_call_response is None:
+            return False
+
+        tool_calls = tool_call_response.choices[0].message.tool_calls
+        # Collect all tool calls grouped by server
+        server_tool_calls = self._gather_server_tool_calls(tool_calls)
+
+        if server_tool_calls and self.num_tool_calls < self.max_tool_calls:
+            self._print_tool_call_info(server_tool_calls)
+
+            if self.io.confirm_ask("Run tools?"):
+                tool_responses = self._execute_tool_calls(server_tool_calls)
+
+                # Add the assistant message with tool calls
+                # Converting to a dict so it can be safely dumped to json
+                self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
+
+                # Add all tool responses
+                for tool_response in tool_responses:
+                    self.cur_messages.append(tool_response)
+
+                return True
+        elif self.num_tool_calls >= self.max_tool_calls:
+            self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
+        return False
+
+    def _print_tool_call_info(self, server_tool_calls):
+        """Print information about an MCP tool call."""
+
+        for server, tool_calls in server_tool_calls.items():
+            for tool_call in tool_calls:
+                self.io.tool_output(
+                    f"Running MCP tool: {tool_call.function.name} from server {server.name}"
+                )
+                self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
+
+                if self.verbose:
+                    self.io.tool_output(f"Tool ID: {tool_call.id}")
+                    self.io.tool_output(f"Tool type: {tool_call.type}")
+
+                self.io.tool_output("\n")
+
+    def _gather_server_tool_calls(self, tool_calls):
+        """Collect all tool calls grouped by server.
+        Args:
+            tool_calls: List of tool calls from the LLM response
+
+        Returns:
+            dict: Dictionary mapping servers to their respective tool calls
+        """
+        if not self.mcp_tools or len(self.mcp_tools) == 0:
+            return None
+
+        server_tool_calls = {}
+        for tool_call in tool_calls:
+            # Check if this tool_call matches any MCP tool
+            for server_name, server_tools in self.mcp_tools:
+                for tool in server_tools:
+                    if tool.get("function", {}).get("name") == tool_call.function.name:
+                        # Find the McpServer instance that will be used for communication
+                        for server in self.mcp_servers:
+                            if server.name == server_name:
+                                if server not in server_tool_calls:
+                                    server_tool_calls[server] = []
+                                server_tool_calls[server].append(tool_call)
+                                break
+
+        return server_tool_calls
+
+    def _execute_tool_calls(self, tool_calls):
         """Process tool calls from the response and execute them if they match MCP tools.
         Returns a list of tool response messages."""
         tool_responses = []
-        tool_calls = tool_call_response.choices[0].message.tool_calls
-
-        # First, collect all tool calls grouped by server
-        server_tool_calls = {}
-
-        for tool_call in tool_calls:
-            # Check if this tool_call matches any MCP tool
-            if self.mcp_tools:
-                for server_name, server_tools in self.mcp_tools:
-                    for tool in server_tools:
-                        if tool.get("function", {}).get("name") == tool_call.function.name:
-                            self.io.tool_output(
-                                f"Running MCP tool: {tool_call.function.name} from server"
-                                f" {server_name}"
-                            )
-                            self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
-
-                            if self.verbose:
-                                self.io.tool_output(f"Tool ID: {tool_call.id}")
-                                self.io.tool_output(f"Tool type: {tool_call.type}")
-
-                            self.io.tool_output("\n")
-                            # Find the corresponding server
-                            for server in self.mcp_servers:
-                                if server.name == server_name:
-                                    if server not in server_tool_calls:
-                                        server_tool_calls[server] = []
-                                    server_tool_calls[server].append(tool_call)
-                                    break

         # Define the coroutine to execute all tool calls for a single server
         async def _exec_server_tools(server, tool_calls_list):
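The new code splits the old execute_tool_calls into three steps: match (_gather_server_tool_calls), announce (_print_tool_call_info), and execute (_execute_tool_calls), with io.confirm_ask("Run tools?") between the last two. Routing the prompt through confirm_ask is what makes the commit respect aider's confirmation settings, since that helper applies the user's global yes/no preferences. The grouping step can be exercised on its own; FakeServer and the SimpleNamespace tool call below are stand-ins for the MCP server and litellm objects:

# Standalone rendition of the _gather_server_tool_calls matching logic.
from types import SimpleNamespace

class FakeServer:
    # Stand-in for an MCP server object; only .name matters for grouping.
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"FakeServer({self.name!r})"

def gather_server_tool_calls(tool_calls, mcp_tools, mcp_servers):
    # mcp_tools is iterated as (server_name, tool_schemas) pairs, as in the diff.
    if not mcp_tools:
        return None
    server_tool_calls = {}
    for tool_call in tool_calls:
        for server_name, server_tools in mcp_tools:
            for tool in server_tools:
                if tool.get("function", {}).get("name") == tool_call.function.name:
                    for server in mcp_servers:
                        if server.name == server_name:
                            server_tool_calls.setdefault(server, []).append(tool_call)
                            break
    return server_tool_calls

files = FakeServer("files")
schemas = [("files", [{"function": {"name": "list_files"}}])]
call = SimpleNamespace(function=SimpleNamespace(name="list_files"))
print(gather_server_tool_calls([call], schemas, [files]))
# {FakeServer('files'): [namespace(function=namespace(name='list_files'))]}

dict.setdefault collapses the if-not-in/append pair from the diff into a single call; the behavior is the same.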
@@ -1610,14 +1640,14 @@ class Coder:
         # Execute all tool calls concurrently
         async def _execute_all_tool_calls():
             tasks = []
-            for server, tool_calls_list in server_tool_calls.items():
+            for server, tool_calls_list in tool_calls.items():
                 tasks.append(_exec_server_tools(server, tool_calls_list))
             # Wait for all tasks to complete
             results = await asyncio.gather(*tasks)
             return results

         # Run the async execution and collect results
-        if server_tool_calls:
+        if tool_calls:
             all_results = asyncio.run(_execute_all_tool_calls())
             # Flatten the results from all servers
             for server_results in all_results:
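This last hunk only renames the dict being iterated, since the grouped calls now arrive as the tool_calls parameter of _execute_tool_calls, but the execution shape is worth seeing in isolation: one coroutine per server, all awaited together under a single asyncio.run. A runnable sketch of that pattern; fake_call_tool stands in for the real MCP session call:

# Per-server fan-out, as in _execute_all_tool_calls above.
import asyncio

async def fake_call_tool(server_name, tool_name):
    await asyncio.sleep(0.01)  # simulate a network round trip to the server
    return f"{server_name}:{tool_name} -> ok"

async def _exec_server_tools(server_name, tool_names):
    # Calls to one server run sequentially on its single session...
    return [await fake_call_tool(server_name, t) for t in tool_names]

async def _execute_all_tool_calls(server_tool_calls):
    # ...while distinct servers are queried concurrently.
    tasks = [
        _exec_server_tools(server, calls)
        for server, calls in server_tool_calls.items()
    ]
    return await asyncio.gather(*tasks)

results = asyncio.run(_execute_all_tool_calls(
    {"files": ["list_files", "read_file"], "web": ["fetch"]}
))
print([r for server_results in results for r in server_results])
# ['files:list_files -> ok', 'files:read_file -> ok', 'web:fetch -> ok']

The flattening at the end matches what the truncated tail of the hunk begins to do with all_results.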