Mirror of https://github.com/Aider-AI/aider.git, synced 2025-05-30 17:24:59 +00:00

commit 87ee7cceee: Merge c1a5e8d0d5 into 3caab85931

13 changed files with 930 additions and 13 deletions
@@ -760,6 +760,18 @@ def get_parser(default_config_files, git_root):
         default="platform",
         help="Line endings to use when writing files (default: platform)",
     )
+    group.add_argument(
+        "--mcp-servers",
+        metavar="MCP_CONFIG_JSON",
+        help="Specify MCP server configurations as a JSON string",
+        default=None,
+    )
+    group.add_argument(
+        "--mcp-servers-file",
+        metavar="MCP_CONFIG_FILE",
+        help="Specify a file path with MCP server configurations",
+        default=None,
+    )
     group.add_argument(
         "-c",
         "--config",
@@ -1,5 +1,6 @@
 #!/usr/bin/env python

+import asyncio
 import base64
 import hashlib
 import json

@@ -26,6 +27,7 @@ from json.decoder import JSONDecodeError
 from pathlib import Path
 from typing import List

+from litellm import experimental_mcp_client
 from rich.console import Console

 from aider import __version__, models, prompts, urls, utils
@@ -99,6 +101,8 @@ class Coder:
     last_keyboard_interrupt = None
     num_reflections = 0
     max_reflections = 3
+    num_tool_calls = 0
+    max_tool_calls = 25
     edit_format = None
     yield_stream = False
     temperature = None

@@ -109,6 +113,7 @@ class Coder:
     test_outcome = None
     multi_response_content = ""
     partial_response_content = ""
+    partial_response_tool_call = []
     commit_before_message = []
     message_cost = 0.0
     add_cache_headers = False

@@ -119,6 +124,8 @@ class Coder:
     ignore_mentions = None
     chat_language = None
     file_watcher = None
+    mcp_servers = None
+    mcp_tools = None

     @classmethod
     def create(

@@ -335,6 +342,7 @@ class Coder:
         file_watcher=None,
         auto_copy_context=False,
         auto_accept_architect=True,
+        mcp_servers=None,
     ):
         # Fill in a dummy Analytics if needed, but it is never .enable()'d
         self.analytics = analytics if analytics is not None else Analytics()

@@ -361,6 +369,7 @@ class Coder:
         self.detect_urls = detect_urls

         self.num_cache_warming_pings = num_cache_warming_pings
+        self.mcp_servers = mcp_servers

         if not fnames:
             fnames = []

@@ -525,6 +534,9 @@ class Coder:
         self.auto_test = auto_test
         self.test_cmd = test_cmd

+        # Instantiate MCP tools
+        if self.mcp_servers:
+            self.initialize_mcp_tools()
         # validate the functions jsonschema
         if self.functions:
             from jsonschema import Draft7Validator

@@ -667,7 +679,10 @@ class Coder:
     def get_cur_message_text(self):
         text = ""
         for msg in self.cur_messages:
-            text += msg["content"] + "\n"
+            # For some models the content is None if the message
+            # contains tool calls.
+            content = msg["content"] or ""
+            text += content + "\n"
         return text

     def get_ident_mentions(self, text):
@@ -1168,6 +1183,7 @@ class Coder:

     def fmt_system_prompt(self, prompt):
         final_reminders = []
+
         if self.main_model.lazy:
             final_reminders.append(self.gpt_prompts.lazy_prompt)
         if self.main_model.overeager:

@@ -1202,6 +1218,9 @@ class Coder:
         else:
             quad_backtick_reminder = ""

+        if self.mcp_tools and len(self.mcp_tools) > 0:
+            final_reminders.append(self.gpt_prompts.tool_prompt)
+
         final_reminders = "\n\n".join(final_reminders)

         prompt = prompt.format(

@@ -1423,6 +1442,7 @@ class Coder:

         chunks = self.format_messages()
         messages = chunks.all_messages()
+
         if not self.check_tokens(messages):
             return
         self.warm_cache(chunks)

@@ -1561,6 +1581,14 @@ class Coder:
             self.reflected_message = add_rel_files_message
             return

+        # Process any tools using MCP servers
+        tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call)
+        if self.process_tool_calls(tool_call_response):
+            self.num_tool_calls += 1
+            return self.run(with_message="Continue with tool call response", preproc=False)
+
+        self.num_tool_calls = 0
+
         try:
             if self.reply_completed():
                 return
@@ -1617,6 +1645,168 @@ class Coder:
             self.reflected_message = test_errors
             return

+    def process_tool_calls(self, tool_call_response):
+        if tool_call_response is None:
+            return False
+
+        tool_calls = tool_call_response.choices[0].message.tool_calls
+        # Collect all tool calls grouped by server
+        server_tool_calls = self._gather_server_tool_calls(tool_calls)
+
+        if server_tool_calls and self.num_tool_calls < self.max_tool_calls:
+            self._print_tool_call_info(server_tool_calls)
+
+            if self.io.confirm_ask("Run tools?"):
+                tool_responses = self._execute_tool_calls(server_tool_calls)
+
+                # Add the assistant message with tool calls
+                # Converting to a dict so it can be safely dumped to json
+                self.cur_messages.append(tool_call_response.choices[0].message.to_dict())
+
+                # Add all tool responses
+                for tool_response in tool_responses:
+                    self.cur_messages.append(tool_response)
+
+                return True
+        elif self.num_tool_calls >= self.max_tool_calls:
+            self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
+
+        return False
+
+    def _print_tool_call_info(self, server_tool_calls):
+        """Print information about an MCP tool call."""
+        self.io.tool_output("Preparing to run MCP tools", bold=True)
+
+        for server, tool_calls in server_tool_calls.items():
+            for tool_call in tool_calls:
+                self.io.tool_output(f"Tool Call: {tool_call.function.name}")
+                self.io.tool_output(f"Arguments: {tool_call.function.arguments}")
+                self.io.tool_output(f"MCP Server: {server.name}")
+
+                if self.verbose:
+                    self.io.tool_output(f"Tool ID: {tool_call.id}")
+                    self.io.tool_output(f"Tool type: {tool_call.type}")
+
+                self.io.tool_output("\n")
+
+    def _gather_server_tool_calls(self, tool_calls):
+        """Collect all tool calls grouped by server.
+
+        Args:
+            tool_calls: List of tool calls from the LLM response
+
+        Returns:
+            dict: Dictionary mapping servers to their respective tool calls
+        """
+        if not self.mcp_tools or len(self.mcp_tools) == 0:
+            return None
+
+        server_tool_calls = {}
+        for tool_call in tool_calls:
+            # Check if this tool_call matches any MCP tool
+            for server_name, server_tools in self.mcp_tools:
+                for tool in server_tools:
+                    if tool.get("function", {}).get("name") == tool_call.function.name:
+                        # Find the McpServer instance that will be used for communication
+                        for server in self.mcp_servers:
+                            if server.name == server_name:
+                                if server not in server_tool_calls:
+                                    server_tool_calls[server] = []
+                                server_tool_calls[server].append(tool_call)
+                                break
+
+        return server_tool_calls
+
+    def _execute_tool_calls(self, tool_calls):
+        """Process tool calls from the response and execute them if they match MCP tools.
+
+        Returns a list of tool response messages."""
+        tool_responses = []
+
+        # Define the coroutine to execute all tool calls for a single server
+        async def _exec_server_tools(server, tool_calls_list):
+            tool_responses = []
+            try:
+                # Connect to the server once
+                session = await server.connect()
+                # Execute all tool calls for this server
+                for tool_call in tool_calls_list:
+                    call_result = await experimental_mcp_client.call_openai_tool(
+                        session=session,
+                        openai_tool=tool_call,
+                    )
+                    result_text = str(call_result.content[0].text)
+                    tool_responses.append(
+                        {"role": "tool", "tool_call_id": tool_call.id, "content": result_text}
+                    )
+            finally:
+                await server.disconnect()
+            return tool_responses
+
+        # Execute all tool calls concurrently
+        async def _execute_all_tool_calls():
+            tasks = []
+            for server, tool_calls_list in tool_calls.items():
+                tasks.append(_exec_server_tools(server, tool_calls_list))
+            # Wait for all tasks to complete
+            results = await asyncio.gather(*tasks)
+            return results
+
+        # Run the async execution and collect results
+        if tool_calls:
+            all_results = asyncio.run(_execute_all_tool_calls())
+            # Flatten the results from all servers
+            for server_results in all_results:
+                tool_responses.extend(server_results)
+
+        return tool_responses
+
+    def initialize_mcp_tools(self):
+        """
+        Initialize tools from all configured MCP servers. MCP Servers that fail to be
+        initialized will not be available to the Coder instance.
+        """
+        tools = []
+
+        async def get_server_tools(server):
+            try:
+                session = await server.connect()
+                server_tools = await experimental_mcp_client.load_mcp_tools(
+                    session=session, format="openai"
+                )
+                return (server.name, server_tools)
+            except Exception as e:
+                self.io.tool_warning(f"Error initializing MCP server {server.name}:\n{e}")
+                return None
+            finally:
+                await server.disconnect()
+
+        async def get_all_server_tools():
+            tasks = [get_server_tools(server) for server in self.mcp_servers]
+            results = await asyncio.gather(*tasks)
+            return [result for result in results if result is not None]
+
+        if self.mcp_servers:
+            tools = asyncio.run(get_all_server_tools())
+
+        if len(tools) > 0:
+            self.io.tool_output("MCP servers configured:")
+            for server_name, server_tools in tools:
+                self.io.tool_output(f"  - {server_name}")
+
+                if self.verbose:
+                    for tool in server_tools:
+                        tool_name = tool.get("function", {}).get("name", "unknown")
+                        tool_desc = tool.get("function", {}).get("description", "").split("\n")[0]
+                        self.io.tool_output(f"    - {tool_name}: {tool_desc}")
+
+        self.mcp_tools = tools
+
+    def get_tool_list(self):
+        """Get a flattened list of all MCP tools."""
+        tool_list = []
+        if self.mcp_tools:
+            for _, server_tools in self.mcp_tools:
+                tool_list.extend(server_tools)
+        return tool_list
+
     def reply_completed(self):
         pass
@@ -1788,12 +1978,17 @@ class Coder:
         self.io.log_llm_history("TO LLM", format_messages(messages))

         completion = None
+
         try:
+            tool_list = self.get_tool_list()
+
             hash_object, completion = model.send_completion(
                 messages,
                 functions,
                 self.stream,
                 self.temperature,
+                # This could include any tools, but for now it is just MCP tools
+                tools=tool_list,
             )
             self.chat_completion_call_hashes.append(hash_object.hexdigest())

@@ -1894,6 +2089,7 @@ class Coder:

     def show_send_output_stream(self, completion):
         received_content = False
+        self.partial_response_tool_call = []

         for chunk in completion:
             if len(chunk.choices) == 0:

@@ -1905,6 +2101,9 @@ class Coder:
             ):
                 raise FinishReasonLength()

+            if chunk.choices[0].delta.tool_calls:
+                self.partial_response_tool_call.append(chunk)
+
             try:
                 func = chunk.choices[0].delta.function_call
                 # dump(func)

@@ -1913,6 +2112,7 @@ class Coder:
                     self.partial_response_function_call[k] += v
                 else:
                     self.partial_response_function_call[k] = v
+
                 received_content = True
             except AttributeError:
                 pass

@@ -1966,7 +2166,7 @@ class Coder:
                 sys.stdout.flush()
                 yield text

-        if not received_content:
+        if not received_content and len(self.partial_response_tool_call) == 0:
             self.io.tool_warning("Empty response received from LLM. Check your provider account?")

     def live_incremental_response(self, final):
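Taken together, the changes above add a bounded tool-call loop: streamed chunks that carry `tool_calls` are buffered, assembled into one response with `litellm.stream_chunk_builder`, handed to `process_tool_calls`, and the coder is re-run with "Continue with tool call response" until `max_tool_calls` is hit. A minimal sketch of the two message shapes this appends to `cur_messages`, following the OpenAI tool-calling format this diff relies on (the id and tool name are illustrative, not from a real session):

```python
# Assistant message recorded via message.to_dict(); note content can be None
# when the model only makes tool calls, which is why get_cur_message_text()
# above now guards against None content.
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_123",  # illustrative id
            "type": "function",
            "function": {"name": "git_status", "arguments": "{}"},
        }
    ],
}

# process_tool_calls() then appends one "tool" message per executed call,
# matched to the originating call by tool_call_id:
tool_message = {
    "role": "tool",
    "tool_call_id": "call_123",
    "content": "On branch main, nothing to commit",  # illustrative result
}
```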
@@ -56,5 +56,18 @@ Do not edit these files!
     no_shell_cmd_prompt = ""
     no_shell_cmd_reminder = ""

+    tool_prompt = """
+<tool_calling>
+When solving problems, you have special tools available. Please follow these rules:
+
+1. Always use the exact format required for each tool and include all needed information.
+2. Only use tools that are currently available in this conversation.
+3. Don't mention tool names when talking to people. Say "I'll check your code" instead
+of "I'll use the code_analyzer tool."
+4. Only use tools when necessary. If you know the answer, just respond directly.
+5. Before using any tool, briefly explain why you need to use it.
+</tool_calling>
+"""
+
     rename_with_shell = ""
     go_ahead_tip = ""
@@ -30,6 +30,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info
 from aider.history import ChatSummary
 from aider.io import InputOutput
 from aider.llm import litellm  # noqa: F401; properly init litellm on launch
+from aider.mcp import load_mcp_servers
 from aider.models import ModelSettings
 from aider.onboarding import offer_openrouter_oauth, select_default_model
 from aider.repo import ANY_GIT_ERROR, GitRepo

@@ -964,6 +965,12 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
     analytics.event("auto_commits", enabled=bool(args.auto_commits))

     try:
+        # Load MCP servers from config string or file
+        mcp_servers = load_mcp_servers(args.mcp_servers, args.mcp_servers_file, io, args.verbose)
+
+        if not mcp_servers:
+            mcp_servers = []
+
         coder = Coder.create(
             main_model=main_model,
             edit_format=args.edit_format,

@@ -996,6 +1003,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
             detect_urls=args.detect_urls,
             auto_copy_context=args.copy_paste,
             auto_accept_architect=args.auto_accept_architect,
+            mcp_servers=mcp_servers,
         )
     except UnknownEditFormat as err:
         io.tool_error(str(err))
aider/mcp/__init__.py (new file, 86 lines)

@@ -0,0 +1,86 @@
import json

from aider.mcp.server import McpServer


def _parse_mcp_servers_from_json_string(json_string, io, verbose=False):
    """Parse MCP servers from a JSON string."""
    servers = []

    try:
        config = json.loads(json_string)
        if verbose:
            io.tool_output("Loading MCP servers from provided JSON string")

        if "mcpServers" in config:
            for name, server_config in config["mcpServers"].items():
                if verbose:
                    io.tool_output(f"Loading MCP server: {name}")

                # Create a server config with name included
                server_config["name"] = name
                servers.append(McpServer(server_config))

            if verbose:
                io.tool_output(f"Loaded {len(servers)} MCP servers from JSON string")
            return servers
        else:
            io.tool_warning("No 'mcpServers' key found in MCP config JSON string")
    except json.JSONDecodeError:
        io.tool_error("Invalid JSON in MCP config string")
    except Exception as e:
        io.tool_error(f"Error loading MCP config from string: {e}")

    return servers


def _parse_mcp_servers_from_file(file_path, io, verbose=False):
    """Parse MCP servers from a JSON file."""
    servers = []

    try:
        with open(file_path, "r") as f:
            config = json.load(f)

        if verbose:
            io.tool_output(f"Loading MCP servers from file: {file_path}")

        if "mcpServers" in config:
            for name, server_config in config["mcpServers"].items():
                if verbose:
                    io.tool_output(f"Loading MCP server: {name}")

                # Create a server config with name included
                server_config["name"] = name
                servers.append(McpServer(server_config))

            if verbose:
                io.tool_output(f"Loaded {len(servers)} MCP servers from {file_path}")
            return servers
        else:
            io.tool_warning(f"No 'mcpServers' key found in MCP config file: {file_path}")
    except FileNotFoundError:
        io.tool_warning(f"MCP config file not found: {file_path}")
    except json.JSONDecodeError:
        io.tool_error(f"Invalid JSON in MCP config file: {file_path}")
    except Exception as e:
        io.tool_error(f"Error loading MCP config from file: {e}")

    return servers


def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False):
    """Load MCP servers from a JSON string or file."""
    servers = []

    # First try to load from the JSON string (preferred)
    if mcp_servers:
        servers = _parse_mcp_servers_from_json_string(mcp_servers, io, verbose)
        if servers:
            return servers

    # If JSON string failed or wasn't provided, try the file
    if mcp_servers_file:
        servers = _parse_mcp_servers_from_file(mcp_servers_file, io, verbose)

    return servers
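A hedged usage sketch of the loader above, as invoked from main(): the JSON string takes precedence, and the file is only consulted when the string is absent or yields no servers. The config and file name below are illustrative:

```python
from aider.io import InputOutput
from aider.mcp import load_mcp_servers

io = InputOutput()
config = '{"mcpServers": {"git": {"command": "uvx", "args": ["mcp-server-git"]}}}'

# The JSON string wins; "mcp.json" would only be read if the string
# were None or produced no servers.
servers = load_mcp_servers(config, "mcp.json", io, verbose=True)
print([server.name for server in servers])  # ["git"]
```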
aider/mcp/server.py (new file, 76 lines)

@@ -0,0 +1,76 @@
import asyncio
import logging
import os
from contextlib import AsyncExitStack

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


class McpServer:
    """
    A client for MCP servers that provides tools to Aider coders. An McpServer class
    is initialized per configured MCP server.

    Current usage:

    session = await server.connect()
    tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai")
    await server.disconnect()
    print(tools)
    """

    def __init__(self, server_config):
        """Initialize the MCP tool provider.

        Args:
            server_config: Configuration for the MCP server
        """
        self.config = server_config
        self.name = server_config.get("name", "unnamed-server")
        self.session = None
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.exit_stack = AsyncExitStack()

    async def connect(self):
        """Connect to the MCP server and return the session.

        If a session is already active, returns the existing session.
        Otherwise, establishes a new connection and initializes the session.

        Returns:
            ClientSession: The active session
        """
        if self.session is not None:
            logging.info(f"Using existing session for MCP server: {self.name}")
            return self.session

        logging.info(f"Establishing new connection to MCP server: {self.name}")
        command = self.config["command"]
        server_params = StdioServerParameters(
            command=command,
            args=self.config.get("args"),
            env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
        )

        try:
            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
            read, write = stdio_transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
            return session
        except Exception as e:
            logging.error(f"Error initializing server {self.name}: {e}")
            await self.disconnect()
            raise

    async def disconnect(self):
        """Disconnect from the MCP server and clean up resources."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
                self.stdio_context = None
            except Exception as e:
                logging.error(f"Error during cleanup of server {self.name}: {e}")
@@ -873,17 +873,14 @@ class Model(ModelSettings):
     def is_ollama(self):
         return self.name.startswith("ollama/") or self.name.startswith("ollama_chat/")

-    def send_completion(self, messages, functions, stream, temperature=None):
+    def send_completion(self, messages, functions, stream, temperature=None, tools=None):
         if os.environ.get("AIDER_SANITY_CHECK_TURNS"):
             sanity_check_messages(messages)

         if self.is_deepseek_r1():
             messages = ensure_alternating_roles(messages)

-        kwargs = dict(
-            model=self.name,
-            stream=stream,
-        )
+        kwargs = dict(model=self.name, stream=stream, tools=[])

         if self.use_temperature is not False:
             if temperature is None:

@@ -903,8 +900,11 @@ class Model(ModelSettings):
         if self.is_ollama() and "num_ctx" not in kwargs:
             num_ctx = int(self.token_count(messages) * 1.25) + 8192
             kwargs["num_ctx"] = num_ctx
-        key = json.dumps(kwargs, sort_keys=True).encode()

+        if tools:
+            kwargs["tools"] = kwargs["tools"] + tools
+
+        key = json.dumps(kwargs, sort_keys=True).encode()
         # dump(kwargs)

         hash_object = hashlib.sha1(key)
aider/website/docs/config/mcp.md (new file, 86 lines)

@@ -0,0 +1,86 @@
---
parent: Configuration
nav_order: 120
description: Configure Model Context Protocol (MCP) servers for enhanced AI capabilities.
---

# Model Context Protocol (MCP)

Model Context Protocol (MCP) servers extend aider's capabilities by providing additional tools and functionality to the AI models. MCP servers can add features like git operations, context retrieval, and other specialized tools.

## Configuring MCP Servers

Aider supports configuring MCP servers using the MCP Server Configuration schema. Please
see the [Model Context Protocol documentation](https://modelcontextprotocol.io/introduction)
for more information.

You have several ways of sharing your MCP server configuration with Aider: config files, command-line flags, and environment variables.

{: .note }
Today, Aider only supports connecting to MCP servers using the stdio transport.

### Config Files

You can configure MCP servers in your `.aider.conf.yml` file:

```yaml
mcp-servers: |
  {
    "mcpServers": {
      "git": {
        "command": "uvx",
        "args": ["mcp-server-git"]
      }
    }
  }
```

Or specify a configuration file:

```yaml
mcp-servers-file: /path/to/mcp.json
```

These options are configurable in any of Aider's config file formats.

### Flags

You can specify MCP servers directly on the command line.

#### Using a JSON String

```bash
aider --mcp-servers '{"mcpServers":{"git":{"command":"uvx","args":["mcp-server-git"]}}}'
```

#### Using a configuration file

Alternatively, you can store your MCP server configurations in a JSON file and reference it with the `--mcp-servers-file` option:

```bash
aider --mcp-servers-file mcp.json
```

### Environment Variables

You can also configure MCP servers using environment variables in your `.env` file:

```
AIDER_MCP_SERVERS={"mcpServers":{"git":{"command":"uvx","args":["mcp-server-git"]}}}
```

Or specify a configuration file:

```
AIDER_MCP_SERVERS_FILE=/path/to/mcp.json
```

## Troubleshooting

If you encounter issues with MCP servers:

1. Use the `--verbose` flag to see detailed information about MCP server loading
2. Check that the specified executables are installed and available in your PATH
3. Verify that your JSON configuration is valid

For more information about specific MCP servers and their capabilities, refer to their respective documentation.
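The docs reference an `mcp.json` file without showing its contents. Per the loader in aider/mcp/__init__.py and the fixture written by the new TestMain test below, the file carries the same top-level `mcpServers` object as the JSON string, for example:

```json
{
  "mcpServers": {
    "git": {
      "command": "uvx",
      "args": ["mcp-server-git"]
    }
  }
}
```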
@@ -20,7 +20,10 @@ anyio==4.9.0
     # via
     #   -c requirements/common-constraints.txt
     #   httpx
+    #   mcp
     #   openai
+    #   sse-starlette
+    #   starlette
     #   watchfiles
 attrs==25.3.0
     # via

@@ -60,6 +63,7 @@ click==8.1.8
     # via
     #   -c requirements/common-constraints.txt
     #   litellm
+    #   uvicorn
 configargparse==1.7
     # via
     #   -c requirements/common-constraints.txt

@@ -154,6 +158,7 @@ h11==0.16.0
     # via
     #   -c requirements/common-constraints.txt
     #   httpcore
+    #   uvicorn
 hf-xet==1.1.0
     # via
     #   -c requirements/common-constraints.txt

@@ -171,7 +176,12 @@ httpx==0.28.1
     # via
     #   -c requirements/common-constraints.txt
     #   litellm
+    #   mcp
     #   openai
+httpx-sse==0.4.0
+    # via
+    #   -c requirements/common-constraints.txt
+    #   mcp
 huggingface-hub==0.31.1
     # via
     #   -c requirements/common-constraints.txt

@@ -229,6 +239,10 @@ mccabe==0.7.0
     # via
     #   -c requirements/common-constraints.txt
     #   flake8
+mcp==1.6.0
+    # via
+    #   -c requirements/common-constraints.txt
+    #   -r requirements/requirements.in
 mdurl==0.1.2
     # via
     #   -c requirements/common-constraints.txt

@@ -338,11 +352,17 @@ pydantic==2.11.4
     #   -c requirements/common-constraints.txt
     #   google-generativeai
     #   litellm
+    #   mcp
     #   openai
+    #   pydantic-settings
 pydantic-core==2.33.2
     # via
     #   -c requirements/common-constraints.txt
     #   pydantic
+pydantic-settings==2.9.1
+    # via
+    #   -c requirements/common-constraints.txt
+    #   mcp
 pydub==0.25.1
     # via
     #   -c requirements/common-constraints.txt

@@ -375,6 +395,7 @@ python-dotenv==1.1.0
     # via
     #   -c requirements/common-constraints.txt
     #   litellm
+    #   pydantic-settings
 pyyaml==6.0.2
     # via
     #   -c requirements/common-constraints.txt

@@ -449,6 +470,15 @@ soupsieve==2.7
     # via
     #   -c requirements/common-constraints.txt
     #   beautifulsoup4
+sse-starlette==2.3.3
+    # via
+    #   -c requirements/common-constraints.txt
+    #   mcp
+starlette==0.46.2
+    # via
+    #   -c requirements/common-constraints.txt
+    #   mcp
+    #   sse-starlette
 tiktoken==0.9.0
     # via
     #   -c requirements/common-constraints.txt

@@ -498,6 +528,7 @@ typing-inspection==0.4.0
     # via
     #   -c requirements/common-constraints.txt
     #   pydantic
+    #   pydantic-settings
 uritemplate==4.1.1
     # via
     #   -c requirements/common-constraints.txt

@@ -507,6 +538,10 @@ urllib3==2.4.0
     #   -c requirements/common-constraints.txt
     #   mixpanel
     #   requests
+uvicorn==0.34.2
+    # via
+    #   -c requirements/common-constraints.txt
+    #   mcp
 watchfiles==1.0.5
     # via
     #   -c requirements/common-constraints.txt
@@ -16,7 +16,10 @@ annotated-types==0.7.0
 anyio==4.9.0
     # via
     #   httpx
+    #   mcp
     #   openai
+    #   sse-starlette
+    #   starlette
     #   watchfiles
 attrs==25.3.0
     # via

@@ -59,6 +62,7 @@ click==8.1.8
     #   pip-tools
     #   streamlit
     #   typer
+    #   uvicorn
 codespell==2.4.1
     # via -r requirements/requirements-dev.in
 cogapp==3.4.1

@@ -171,7 +175,9 @@ grpcio==1.71.0
 grpcio-status==1.71.0
     # via google-api-core
 h11==0.16.0
-    # via httpcore
+    # via
+    #   httpcore
+    #   uvicorn
 hf-xet==1.1.0
     # via huggingface-hub
 httpcore==1.0.9

@@ -184,8 +190,11 @@ httpx==0.28.1
     # via
     #   litellm
     #   llama-index-core
+    #   mcp
     #   openai
-huggingface-hub[inference]==0.31.1
+httpx-sse==0.4.0
+    # via mcp
+huggingface-hub[inference]==0.30.2
     # via
     #   llama-index-embeddings-huggingface
     #   sentence-transformers

@@ -253,6 +262,8 @@ matplotlib==3.10.3
     # via -r requirements/requirements-dev.in
 mccabe==0.7.0
     # via flake8
+mcp==1.6.0
+    # via -r requirements/requirements.in
 mdurl==0.1.2
     # via markdown-it-py
 mixpanel==4.10.1

@@ -391,9 +402,13 @@ pydantic==2.11.4
     #   google-generativeai
     #   litellm
     #   llama-index-core
+    #   mcp
     #   openai
+    #   pydantic-settings
 pydantic-core==2.33.2
     # via pydantic
+pydantic-settings==2.9.1
+    # via mcp
 pydeck==0.9.1
     # via streamlit
 pydub==0.25.1

@@ -429,7 +444,9 @@ python-dateutil==2.9.0.post0
     #   pandas
     #   posthog
 python-dotenv==1.1.0
-    # via litellm
+    # via
+    #   litellm
+    #   pydantic-settings
 pytz==2025.2
     # via pandas
 pyyaml==6.0.2

@@ -509,6 +526,12 @@ soupsieve==2.7
     # via beautifulsoup4
 sqlalchemy[asyncio]==2.0.40
     # via llama-index-core
+sse-starlette==2.3.3
+    # via mcp
+starlette==0.46.2
+    # via
+    #   mcp
+    #   sse-starlette
 streamlit==1.45.0
     # via -r requirements/requirements-browser.in
 sympy==1.14.0

@@ -583,7 +606,9 @@ typing-inspect==0.9.0
     #   dataclasses-json
     #   llama-index-core
 typing-inspection==0.4.0
-    # via pydantic
+    # via
+    #   pydantic
+    #   pydantic-settings
 tzdata==2025.2
     # via pandas
 uritemplate==4.1.1

@@ -594,6 +619,8 @@ urllib3==2.4.0
     #   requests
 uv==0.7.3
     # via -r requirements/requirements-dev.in
+uvicorn==0.34.2
+    # via mcp
 virtualenv==20.31.2
     # via pre-commit
 watchfiles==1.0.5
@@ -30,6 +30,7 @@ pillow
 shtab
 oslex
 google-generativeai
+mcp>=1.0.0

 # The proper dependency is networkx[default], but this brings
 # in matplotlib and a bunch of other deps
@@ -2,7 +2,7 @@ import os
 import tempfile
 import unittest
 from pathlib import Path
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import git

@@ -575,6 +575,7 @@ Once I have these, I can show you precisely how to do the thing.
         fname = Path("file.txt")

         io = InputOutput(yes=True)
+        io.tool_warning = MagicMock()
         coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)])

         self.assertTrue(fname.exists())
@@ -1433,6 +1434,324 @@ This command will print 'Hello, World!' to the console"""
         # (because user rejected the changes)
         mock_editor.run.assert_not_called()

+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    def test_mcp_server_connection(self, mock_mcp_client):
+        """Test that the coder connects to MCP servers for tools."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+
+            # Create mock MCP server
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+            mock_server.connect = MagicMock()
+            mock_server.disconnect = MagicMock()
+
+            # Setup mock for initialize_mcp_tools
+            mock_tools = [("test_server", [{"function": {"name": "test_tool"}}])]
+
+            # Create coder with mock MCP server
+            with patch.object(Coder, "initialize_mcp_tools", return_value=mock_tools):
+                coder = Coder.create(self.GPT35, "diff", io=io, mcp_servers=[mock_server])
+
+                # Manually set mcp_tools since we're bypassing initialize_mcp_tools
+                coder.mcp_tools = mock_tools
+
+                # Verify that mcp_tools contains the expected data
+                self.assertIsNotNone(coder.mcp_tools)
+                self.assertEqual(len(coder.mcp_tools), 1)
+                self.assertEqual(coder.mcp_tools[0][0], "test_server")
+
+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    def test_coder_creation_with_partial_failed_mcp_server(self, mock_mcp_client):
+        """Test that a coder can still be created even if an MCP server fails to initialize."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            io.tool_warning = MagicMock()
+
+            # Create mock MCP servers - one working, one failing
+            working_server = AsyncMock()
+            working_server.name = "working_server"
+            working_server.connect = AsyncMock()
+            working_server.disconnect = AsyncMock()
+
+            failing_server = AsyncMock()
+            failing_server.name = "failing_server"
+            failing_server.connect = AsyncMock()
+            failing_server.disconnect = AsyncMock()
+
+            # Mock load_mcp_tools to succeed for working_server and fail for failing_server
+            async def mock_load_mcp_tools(session, format):
+                if session == await working_server.connect():
+                    return [{"function": {"name": "working_tool"}}]
+                else:
+                    raise Exception("Failed to load tools")
+
+            mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools)
+
+            # Create coder with both servers
+            coder = Coder.create(
+                self.GPT35,
+                "diff",
+                io=io,
+                mcp_servers=[working_server, failing_server],
+                verbose=True,
+            )
+
+            # Verify that coder was created successfully
+            self.assertIsInstance(coder, Coder)
+
+            # Verify that only the working server's tools were added
+            self.assertIsNotNone(coder.mcp_tools)
+            self.assertEqual(len(coder.mcp_tools), 1)
+            self.assertEqual(coder.mcp_tools[0][0], "working_server")
+
+            # Verify that the tool list contains only working tools
+            tool_list = coder.get_tool_list()
+            self.assertEqual(len(tool_list), 1)
+            self.assertEqual(tool_list[0]["function"]["name"], "working_tool")
+
+            # Verify that the warning was logged for the failing server
+            io.tool_warning.assert_called_with(
+                "Error initializing MCP server failing_server:\nFailed to load tools"
+            )
+
+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    def test_coder_creation_with_all_failed_mcp_server(self, mock_mcp_client):
+        """Test that a coder can still be created even if all MCP servers fail to initialize."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            io.tool_warning = MagicMock()
+
+            failing_server = AsyncMock()
+            failing_server.name = "failing_server"
+            failing_server.connect = AsyncMock()
+            failing_server.disconnect = AsyncMock()
+
+            # Mock load_mcp_tools to always fail
+            async def mock_load_mcp_tools(session, format):
+                raise Exception("Failed to load tools")
+
+            mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools)
+
+            # Create coder with the failing server
+            coder = Coder.create(
+                self.GPT35,
+                "diff",
+                io=io,
+                mcp_servers=[failing_server],
+                verbose=True,
+            )
+
+            # Verify that coder was created successfully
+            self.assertIsInstance(coder, Coder)
+
+            # Verify that no tools were added
+            self.assertIsNotNone(coder.mcp_tools)
+            self.assertEqual(len(coder.mcp_tools), 0)
+
+            # Verify that the tool list is empty
+            tool_list = coder.get_tool_list()
+            self.assertEqual(len(tool_list), 0)
+
+            # Verify that the warning was logged for the failing server
+            io.tool_warning.assert_called_with(
+                "Error initializing MCP server failing_server:\nFailed to load tools"
+            )
+
+    def test_process_tool_calls_none_response(self):
+        """Test that process_tool_calls handles None response correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Test with None response
+            result = coder.process_tool_calls(None)
+            self.assertFalse(result)
+
+    def test_process_tool_calls_no_tool_calls(self):
+        """Test that process_tool_calls handles response with no tool calls."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Create a response with no tool calls
+            response = MagicMock()
+            response.choices = [MagicMock()]
+            response.choices[0].message = MagicMock()
+            response.choices[0].message.tool_calls = []
+
+            result = coder.process_tool_calls(response)
+            self.assertFalse(result)
+
+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    @patch("asyncio.run")
+    def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client):
+        """Test that process_tool_calls processes tool calls correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            io.confirm_ask = MagicMock(return_value=True)
+
+            # Create mock MCP server
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+
+            # Create a tool call
+            tool_call = MagicMock()
+            tool_call.id = "test_id"
+            tool_call.type = "function"
+            tool_call.function = MagicMock()
+            tool_call.function.name = "test_tool"
+            tool_call.function.arguments = '{"param": "value"}'
+
+            # Create a response with tool calls
+            response = MagicMock()
+            response.choices = [MagicMock()]
+            response.choices[0].message = MagicMock()
+            response.choices[0].message.tool_calls = [tool_call]
+            response.choices[0].message.to_dict = MagicMock(
+                return_value={"role": "assistant", "tool_calls": [{"id": "test_id"}]}
+            )
+
+            # Create coder with mock MCP tools and servers
+            coder = Coder.create(self.GPT35, "diff", io=io)
+            coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])]
+            coder.mcp_servers = [mock_server]
+
+            # Mock asyncio.run to return tool responses
+            tool_responses = [
+                [{"role": "tool", "tool_call_id": "test_id", "content": "Tool execution result"}]
+            ]
+            mock_asyncio_run.return_value = tool_responses
+
+            # Test process_tool_calls
+            result = coder.process_tool_calls(response)
+            self.assertTrue(result)
+
+            # Verify that asyncio.run was called
+            mock_asyncio_run.assert_called_once()
+
+            # Verify that the messages were added
+            self.assertEqual(len(coder.cur_messages), 2)
+            self.assertEqual(coder.cur_messages[0]["role"], "assistant")
+            self.assertEqual(coder.cur_messages[1]["role"], "tool")
+            self.assertEqual(coder.cur_messages[1]["tool_call_id"], "test_id")
+            self.assertEqual(coder.cur_messages[1]["content"], "Tool execution result")
+
+    def test_process_tool_calls_max_calls_exceeded(self):
+        """Test that process_tool_calls handles max tool calls exceeded."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            io.tool_warning = MagicMock()
+
+            # Create a tool call
+            tool_call = MagicMock()
+            tool_call.id = "test_id"
+            tool_call.type = "function"
+            tool_call.function = MagicMock()
+            tool_call.function.name = "test_tool"
+
+            # Create a response with tool calls
+            response = MagicMock()
+            response.choices = [MagicMock()]
+            response.choices[0].message = MagicMock()
+            response.choices[0].message.tool_calls = [tool_call]
+
+            # Create mock MCP server
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+
+            # Create coder with max tool calls exceeded
+            coder = Coder.create(self.GPT35, "diff", io=io)
+            coder.num_tool_calls = coder.max_tool_calls
+            coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])]
+            coder.mcp_servers = [mock_server]
+
+            # Test process_tool_calls
+            result = coder.process_tool_calls(response)
+            self.assertFalse(result)
+
+            # Verify that warning was shown
+            io.tool_warning.assert_called_once_with(
+                f"Only {coder.max_tool_calls} tool calls allowed, stopping."
+            )
+
+    def test_process_tool_calls_user_rejects(self):
+        """Test that process_tool_calls handles user rejection."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            io.confirm_ask = MagicMock(return_value=False)
+
+            # Create a tool call
+            tool_call = MagicMock()
+            tool_call.id = "test_id"
+            tool_call.type = "function"
+            tool_call.function = MagicMock()
+            tool_call.function.name = "test_tool"
+
+            # Create a response with tool calls
+            response = MagicMock()
+            response.choices = [MagicMock()]
+            response.choices[0].message = MagicMock()
+            response.choices[0].message.tool_calls = [tool_call]
+
+            # Create mock MCP server
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+
+            # Create coder with mock MCP tools
+            coder = Coder.create(self.GPT35, "diff", io=io)
+            coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])]
+            coder.mcp_servers = [mock_server]
+
+            # Test process_tool_calls
+            result = coder.process_tool_calls(response)
+            self.assertFalse(result)
+
+            # Verify that confirm_ask was called
+            io.confirm_ask.assert_called_once_with("Run tools?")
+
+            # Verify that no messages were added
+            self.assertEqual(len(coder.cur_messages), 0)
+
+    @patch("asyncio.run")
+    def test_execute_tool_calls(self, mock_asyncio_run):
+        """Test that _execute_tool_calls executes tool calls correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+            coder = Coder.create(self.GPT35, "diff", io=io)
+
+            # Create mock server and tool call
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+
+            tool_call = MagicMock()
+            tool_call.id = "test_id"
+            tool_call.type = "function"
+            tool_call.function = MagicMock()
+            tool_call.function.name = "test_tool"
+            tool_call.function.arguments = '{"param": "value"}'
+
+            # Create server_tool_calls
+            server_tool_calls = {mock_server: [tool_call]}
+
+            # Mock asyncio.run to return tool responses
+            tool_responses = [
+                [{"role": "tool", "tool_call_id": "test_id", "content": "Tool execution result"}]
+            ]
+            mock_asyncio_run.return_value = tool_responses
+
+            # Test _execute_tool_calls directly
+            result = coder._execute_tool_calls(server_tool_calls)
+
+            # Verify that asyncio.run was called
+            mock_asyncio_run.assert_called_once()
+
+            # Verify that the correct tool responses were returned
+            self.assertEqual(len(result), 1)
+            self.assertEqual(result[0]["role"], "tool")
+            self.assertEqual(result[0]["tool_call_id"], "test_id")
+            self.assertEqual(result[0]["content"], "Tool execution result")
+

 if __name__ == "__main__":
     unittest.main()
@@ -1363,3 +1363,57 @@ class TestMain(TestCase):
         )
         for call in mock_io_instance.tool_warning.call_args_list:
             self.assertNotIn("Cost estimates may be inaccurate", call[0][0])
+
+    @patch("aider.coders.Coder.create")
+    def test_mcp_servers_parsing(self, mock_coder_create):
+        # Setup mock coder
+        mock_coder_instance = MagicMock()
+        mock_coder_create.return_value = mock_coder_instance
+
+        # Test with --mcp-servers option
+        with GitTemporaryDirectory():
+            main(
+                [
+                    "--mcp-servers",
+                    '{"mcpServers":{"git":{"command":"uvx","args":["mcp-server-git"]}}}',
+                    "--exit",
+                    "--yes",
+                ],
+                input=DummyInput(),
+                output=DummyOutput(),
+            )
+
+            # Verify that Coder.create was called with mcp_servers parameter
+            mock_coder_create.assert_called_once()
+            _, kwargs = mock_coder_create.call_args
+            self.assertIn("mcp_servers", kwargs)
+            self.assertIsNotNone(kwargs["mcp_servers"])
+            # At least one server should be in the list
+            self.assertTrue(len(kwargs["mcp_servers"]) > 0)
+            # First server should have a name attribute
+            self.assertTrue(hasattr(kwargs["mcp_servers"][0], "name"))
+
+        # Test with --mcp-servers-file option
+        mock_coder_create.reset_mock()
+
+        with GitTemporaryDirectory():
+            # Create a temporary MCP servers file
+            mcp_file = Path("mcp_servers.json")
+            mcp_content = {"mcpServers": {"git": {"command": "uvx", "args": ["mcp-server-git"]}}}
+            mcp_file.write_text(json.dumps(mcp_content))
+
+            main(
+                ["--mcp-servers-file", str(mcp_file), "--exit", "--yes"],
+                input=DummyInput(),
+                output=DummyOutput(),
+            )
+
+            # Verify that Coder.create was called with mcp_servers parameter
+            mock_coder_create.assert_called_once()
+            _, kwargs = mock_coder_create.call_args
+            self.assertIn("mcp_servers", kwargs)
+            self.assertIsNotNone(kwargs["mcp_servers"])
+            # At least one server should be in the list
+            self.assertTrue(len(kwargs["mcp_servers"]) > 0)
+            # First server should have a name attribute
+            self.assertTrue(hasattr(kwargs["mcp_servers"][0], "name"))