diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 1d172bbce..028d29280 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -1478,25 +1478,14 @@ class Coder: self.reflected_message = add_rel_files_message return + # Process any tools using MCP servers tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call) - - if tool_call_response and self.num_tool_calls < self.max_tool_calls: + if self.process_tool_calls(tool_call_response): self.num_tool_calls += 1 - tool_responses = self.execute_tool_calls(tool_call_response) - - # Add the assistant message with tool calls - # Converting to a dict so it can be safely dumped to json - self.cur_messages.append(tool_call_response.choices[0].message.to_dict()) - - # Add all tool responses - for tool_response in tool_responses: - self.cur_messages.append(tool_response) - return self.run(with_message="Continue", preproc=False) - elif self.num_tool_calls >= self.max_tool_calls: - self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.") - self.tool_call_limit = 0 + self.num_tool_calls = 0 + try: if self.reply_completed(): return @@ -1553,39 +1542,80 @@ class Coder: self.reflected_message = test_errors return - def execute_tool_calls(self, tool_call_response): + def process_tool_calls(self, tool_call_response): + if tool_call_response is None: + return False + + tool_calls = tool_call_response.choices[0].message.tool_calls + # Collect all tool calls grouped by server + server_tool_calls = self._gather_server_tool_calls(tool_calls) + + if server_tool_calls and self.num_tool_calls < self.max_tool_calls: + self._print_tool_call_info(server_tool_calls) + + if self.io.confirm_ask("Run tools?"): + tool_responses = self._execute_tool_calls(server_tool_calls) + + # Add the assistant message with tool calls + # Converting to a dict so it can be safely dumped to json + self.cur_messages.append(tool_call_response.choices[0].message.to_dict()) + + # Add all tool responses + for tool_response in tool_responses: + self.cur_messages.append(tool_response) + + return True + elif self.num_tool_calls >= self.max_tool_calls: + self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.") + return False + + def _print_tool_call_info(self, server_tool_calls): + """Print information about an MCP tool call.""" + + for server, tool_calls in server_tool_calls.items(): + for tool_call in tool_calls: + self.io.tool_output( + f"Running MCP tool: {tool_call.function.name} from server {server.name}" + ) + self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}") + + if self.verbose: + self.io.tool_output(f"Tool ID: {tool_call.id}") + self.io.tool_output(f"Tool type: {tool_call.type}") + + self.io.tool_output("\n") + + def _gather_server_tool_calls(self, tool_calls): + """Collect all tool calls grouped by server. + Args: + tool_calls: List of tool calls from the LLM response + + Returns: + dict: Dictionary mapping servers to their respective tool calls + """ + if not self.mcp_tools or len(self.mcp_tools) == 0: + return None + + server_tool_calls = {} + for tool_call in tool_calls: + # Check if this tool_call matches any MCP tool + for server_name, server_tools in self.mcp_tools: + for tool in server_tools: + if tool.get("function", {}).get("name") == tool_call.function.name: + # Find the McpServer instance that will be used for communication + for server in self.mcp_servers: + if server.name == server_name: + if server not in server_tool_calls: + server_tool_calls[server] = [] + server_tool_calls[server].append(tool_call) + break + + return server_tool_calls + + def _execute_tool_calls(self, tool_calls): """Process tool calls from the response and execute them if they match MCP tools. Returns a list of tool response messages.""" tool_responses = [] - tool_calls = tool_call_response.choices[0].message.tool_calls - - # First, collect all tool calls grouped by server - server_tool_calls = {} - - for tool_call in tool_calls: - # Check if this tool_call matches any MCP tool - if self.mcp_tools: - for server_name, server_tools in self.mcp_tools: - for tool in server_tools: - if tool.get("function", {}).get("name") == tool_call.function.name: - self.io.tool_output( - f"Running MCP tool: {tool_call.function.name} from server" - f" {server_name}" - ) - self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}") - - if self.verbose: - self.io.tool_output(f"Tool ID: {tool_call.id}") - self.io.tool_output(f"Tool type: {tool_call.type}") - - self.io.tool_output("\n") - # Find the corresponding server - for server in self.mcp_servers: - if server.name == server_name: - if server not in server_tool_calls: - server_tool_calls[server] = [] - server_tool_calls[server].append(tool_call) - break # Define the coroutine to execute all tool calls for a single server async def _exec_server_tools(server, tool_calls_list): @@ -1610,14 +1640,14 @@ class Coder: # Execute all tool calls concurrently async def _execute_all_tool_calls(): tasks = [] - for server, tool_calls_list in server_tool_calls.items(): + for server, tool_calls_list in tool_calls.items(): tasks.append(_exec_server_tools(server, tool_calls_list)) # Wait for all tasks to complete results = await asyncio.gather(*tasks) return results # Run the async execution and collect results - if server_tool_calls: + if tool_calls: all_results = asyncio.run(_execute_all_tool_calls()) # Flatten the results from all servers for server_results in all_results: