Introduce and call MCP servers

This commit is contained in:
Quinlan Jager 2025-04-30 23:53:13 -07:00
parent c5414e2601
commit 162f49c98e
5 changed files with 121 additions and 2 deletions

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python
import asyncio
import base64
import hashlib
import json
@ -111,6 +112,8 @@ class Coder:
ignore_mentions = None
chat_language = None
file_watcher = None
mcp_servers = None
mcp_tools = None
@classmethod
def create(
@ -323,6 +326,7 @@ class Coder:
file_watcher=None,
auto_copy_context=False,
auto_accept_architect=True,
mcp_servers=None,
):
# Fill in a dummy Analytics if needed, but it is never .enable()'d
self.analytics = analytics if analytics is not None else Analytics()
@ -349,6 +353,7 @@ class Coder:
self.detect_urls = detect_urls
self.num_cache_warming_pings = num_cache_warming_pings
self.mcp_servers = mcp_servers
if not fnames:
fnames = []
@ -508,6 +513,30 @@ class Coder:
self.auto_test = auto_test
self.test_cmd = test_cmd
# Instantiate MCP tools
if self.mcp_servers:
from litellm import experimental_mcp_client
tools = []
print("GETTING SERVER TOOLS")
for server in self.mcp_servers:
print(f"Getting server tools: {server.name}")
async def get_server_tools(all_tools):
try:
session = await server.connect() # Use connect() directly
server_tools = await experimental_mcp_client.load_mcp_tools(
session=session, format="openai"
)
return all_tools + server_tools
finally:
await server.disconnect()
tools = asyncio.run(get_server_tools(tools))
self.mcp_tools = tools
print("All TOOLS")
print(tools)
# validate the functions jsonschema
if self.functions:
from jsonschema import Draft7Validator
@ -1461,6 +1490,8 @@ class Coder:
self.reflected_message = add_rel_files_message
return
print(content)
try:
if self.reply_completed():
return
@ -1688,12 +1719,15 @@ class Coder:
self.io.log_llm_history("TO LLM", format_messages(messages))
completion = None
try:
hash_object, completion = model.send_completion(
messages,
functions,
self.stream,
self.temperature,
# This could include any tools, but for now it is just MCP tools
tools=self.mcp_tools,
)
self.chat_completion_call_hashes.append(hash_object.hexdigest())

View file

@ -29,6 +29,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info
from aider.history import ChatSummary
from aider.io import InputOutput
from aider.llm import litellm # noqa: F401; properly init litellm on launch
from aider.mcp.server import McpServer
from aider.models import ModelSettings
from aider.onboarding import offer_openrouter_oauth, select_default_model
from aider.repo import ANY_GIT_ERROR, GitRepo
@ -956,6 +957,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
analytics.event("auto_commits", enabled=bool(args.auto_commits))
try:
fetch_server = McpServer({"name": "fetch", "command": "uvx", "args": ["mcp-server-fetch"]})
git_server = McpServer({"name": "git", "command": "uvx", "args": ["mcp-server-git"]})
coder = Coder.create(
main_model=main_model,
edit_format=args.edit_format,
@ -988,6 +993,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
detect_urls=args.detect_urls,
auto_copy_context=args.copy_paste,
auto_accept_architect=args.auto_accept_architect,
mcp_servers=[fetch_server, git_server],
)
except UnknownEditFormat as err:
io.tool_error(str(err))

0
aider/mcp/__init__.py Normal file
View file

76
aider/mcp/server.py Normal file
View file

@ -0,0 +1,76 @@
import asyncio
import logging
import os
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class McpServer:
"""
A client for MCP servers that provides tools to Aider coders. An McpServer class
is initialized per configured MCP Server
Current usage:
conn = await session.connect() # Use connect() directly
tools = await experimental_mcp_client.load_mcp_tools(session=s, format="openai")
await session.disconnect()
print(tools)
"""
def __init__(self, server_config):
"""Initialize the MCP tool provider.
Args:
server_config: Configuration for the MCP server
"""
self.config = server_config
self.name = server_config.get("name", "unnamed-server")
self.session = None
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.exit_stack = AsyncExitStack()
async def connect(self):
"""Connect to the MCP server and return the session.
If a session is already active, returns the existing session.
Otherwise, establishes a new connection and initializes the session.
Returns:
ClientSession: The active session
"""
if self.session is not None:
logging.info(f"Using existing session for MCP server: {self.name}")
return self.session
logging.info(f"Establishing new connection to MCP server: {self.name}")
command = self.config["command"]
server_params = StdioServerParameters(
command=command,
args=self.config["args"],
env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
)
try:
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read, write = stdio_transport
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
return session
except Exception as e:
logging.error(f"Error initializing server {self.name}: {e}")
await self.disconnect()
raise
async def disconnect(self):
"""Disconnect from the MCP server and clean up resources."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
self.stdio_context = None
except Exception as e:
logging.error(f"Error during cleanup of server {self.name}: {e}")

View file

@ -767,7 +767,7 @@ class Model(ModelSettings):
def is_ollama(self):
return self.name.startswith("ollama/") or self.name.startswith("ollama_chat/")
def send_completion(self, messages, functions, stream, temperature=None):
def send_completion(self, messages, functions, stream, temperature=None, tools=None):
if os.environ.get("AIDER_SANITY_CHECK_TURNS"):
sanity_check_messages(messages)
@ -797,8 +797,11 @@ class Model(ModelSettings):
if self.is_ollama() and "num_ctx" not in kwargs:
num_ctx = int(self.token_count(messages) * 1.25) + 8192
kwargs["num_ctx"] = num_ctx
key = json.dumps(kwargs, sort_keys=True).encode()
if tools:
kwargs["tools"] = kwargs["tools"] + tools
key = json.dumps(kwargs, sort_keys=True).encode()
# dump(kwargs)
hash_object = hashlib.sha1(key)