async def get_server_tools(all_tools):
    """Load this server's tools (OpenAI tool format) and return all_tools plus them."""
    try:
        session = await server.connect()
        fetched = await experimental_mcp_client.load_mcp_tools(
            session=session, format="openai"
        )
        return all_tools + fetched
    finally:
        # Always tear the connection down, even if tool loading raised.
        await server.disconnect()
import asyncio
import logging
import os
from contextlib import AsyncExitStack

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


class McpServer:
    """A client for one configured MCP server, providing tools to Aider coders.

    One McpServer instance is created per configured MCP server.

    Typical usage:

        session = await server.connect()
        tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai")
        await server.disconnect()
    """

    def __init__(self, server_config):
        """Initialize the MCP tool provider.

        Args:
            server_config: Configuration dict for the MCP server. Expected keys:
                "command" and "args" (how to launch the server process over stdio),
                optional "name" (defaults to "unnamed-server") and optional "env"
                (extra environment variables overlaid on the parent environment).
        """
        self.config = server_config
        self.name = server_config.get("name", "unnamed-server")
        # Lazily established in connect(); None while disconnected.
        self.session = None
        # Serializes disconnect() so concurrent cleanups cannot race.
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        # Owns the stdio transport and session; closed as a unit in disconnect().
        self.exit_stack = AsyncExitStack()

    async def connect(self):
        """Connect to the MCP server and return the session.

        If a session is already active, returns the existing session.
        Otherwise, spawns the server process, establishes the stdio transport,
        and initializes a new session.

        Returns:
            ClientSession: The active session.

        Raises:
            Exception: Re-raises any transport/session setup error after
                cleaning up partially-acquired resources.
        """
        if self.session is not None:
            logging.info(f"Using existing session for MCP server: {self.name}")
            return self.session

        logging.info(f"Establishing new connection to MCP server: {self.name}")
        server_params = StdioServerParameters(
            command=self.config["command"],
            args=self.config["args"],
            # Inherit the parent environment and overlay any configured vars;
            # pass None (SDK default) when no custom env is configured.
            env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
        )

        try:
            stdio_transport = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            read, write = stdio_transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
            return session
        except Exception as e:
            logging.error(f"Error initializing server {self.name}: {e}")
            # Release whatever was partially set up before propagating.
            await self.disconnect()
            raise

    async def disconnect(self):
        """Disconnect from the MCP server and clean up resources.

        Safe to call multiple times; the AsyncExitStack is empty (and reusable)
        after aclose(), so a later connect() can establish a fresh session.
        """
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
            except Exception as e:
                logging.error(f"Error during cleanup of server {self.name}: {e}")