Mirror of https://github.com/Aider-AI/aider.git
Synced 2025-06-05 12:14:59 +00:00

commit 162f49c98e
parent c5414e2601

Introduce and call MCP servers

5 changed files with 121 additions and 2 deletions
aider/coders/base_coder.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 
+import asyncio
 import base64
 import hashlib
 import json
@@ -111,6 +112,8 @@ class Coder:
     ignore_mentions = None
     chat_language = None
     file_watcher = None
+    mcp_servers = None
+    mcp_tools = None
 
     @classmethod
     def create(
@@ -323,6 +326,7 @@ class Coder:
         file_watcher=None,
         auto_copy_context=False,
         auto_accept_architect=True,
+        mcp_servers=None,
     ):
         # Fill in a dummy Analytics if needed, but it is never .enable()'d
         self.analytics = analytics if analytics is not None else Analytics()
@@ -349,6 +353,7 @@ class Coder:
         self.detect_urls = detect_urls
 
         self.num_cache_warming_pings = num_cache_warming_pings
+        self.mcp_servers = mcp_servers
 
         if not fnames:
             fnames = []
@@ -508,6 +513,30 @@ class Coder:
         self.auto_test = auto_test
         self.test_cmd = test_cmd
 
+        # Instantiate MCP tools
+        if self.mcp_servers:
+            from litellm import experimental_mcp_client
+
+            tools = []
+            print("GETTING SERVER TOOLS")
+            for server in self.mcp_servers:
+                print(f"Getting server tools: {server.name}")
+
+                async def get_server_tools(all_tools):
+                    try:
+                        session = await server.connect()  # Use connect() directly
+                        server_tools = await experimental_mcp_client.load_mcp_tools(
+                            session=session, format="openai"
+                        )
+                        return all_tools + server_tools
+                    finally:
+                        await server.disconnect()
+
+                tools = asyncio.run(get_server_tools(tools))
+
+            self.mcp_tools = tools
+            print("All TOOLS")
+            print(tools)
         # validate the functions jsonschema
         if self.functions:
             from jsonschema import Draft7Validator
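For readers tracing the startup flow above: each configured server is connected, its tools are exported in OpenAI function-calling format, and the session is torn down again before the next server is handled. A minimal standalone sketch of that cycle, assuming litellm's experimental_mcp_client and the McpServer class added later in this commit:

import asyncio

from litellm import experimental_mcp_client


async def load_tools_from(server):
    # Open a fresh stdio session, export the server's tools as OpenAI-format
    # dicts, and always disconnect, mirroring the try/finally in the diff.
    try:
        session = await server.connect()
        return await experimental_mcp_client.load_mcp_tools(
            session=session, format="openai"
        )
    finally:
        await server.disconnect()


def collect_mcp_tools(servers):
    # One asyncio.run() per server, as in Coder.__init__ above.
    tools = []
    for server in servers:
        tools += asyncio.run(load_tools_from(server))
    return tools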
@@ -1461,6 +1490,8 @@ class Coder:
             self.reflected_message = add_rel_files_message
             return
 
+        print(content)
+
         try:
             if self.reply_completed():
                 return
@@ -1688,12 +1719,15 @@ class Coder:
         self.io.log_llm_history("TO LLM", format_messages(messages))
 
         completion = None
         try:
             hash_object, completion = model.send_completion(
                 messages,
                 functions,
                 self.stream,
                 self.temperature,
+                # This could include any tools, but for now it is just MCP tools
+                tools=self.mcp_tools,
             )
             self.chat_completion_call_hashes.append(hash_object.hexdigest())
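The tools list passed to send_completion above is whatever load_mcp_tools(format="openai") returned, i.e. OpenAI function-calling tool entries. As a rough illustration (the name, description and schema here are made up; real entries come from each MCP server), one entry looks like:

# Illustrative only: shape of one OpenAI-format tool entry.
example_tool = {
    "type": "function",
    "function": {
        "name": "fetch",  # hypothetical tool name
        "description": "Fetch a URL and return its contents",
        "parameters": {  # JSON Schema describing the tool's arguments
            "type": "object",
            "properties": {"url": {"type": "string"}},
            "required": ["url"],
        },
    },
}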
aider/main.py

@@ -29,6 +29,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info
 from aider.history import ChatSummary
 from aider.io import InputOutput
 from aider.llm import litellm  # noqa: F401; properly init litellm on launch
+from aider.mcp.server import McpServer
 from aider.models import ModelSettings
 from aider.onboarding import offer_openrouter_oauth, select_default_model
 from aider.repo import ANY_GIT_ERROR, GitRepo
@@ -956,6 +957,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
     analytics.event("auto_commits", enabled=bool(args.auto_commits))
 
     try:
+        fetch_server = McpServer({"name": "fetch", "command": "uvx", "args": ["mcp-server-fetch"]})
+
+        git_server = McpServer({"name": "git", "command": "uvx", "args": ["mcp-server-git"]})
+
         coder = Coder.create(
             main_model=main_model,
             edit_format=args.edit_format,
@@ -988,6 +993,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
             detect_urls=args.detect_urls,
             auto_copy_context=args.copy_paste,
             auto_accept_architect=args.auto_accept_architect,
+            mcp_servers=[fetch_server, git_server],
         )
     except UnknownEditFormat as err:
         io.tool_error(str(err))
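As aider/mcp/server.py below shows, McpServer reads only the keys name, command, args and an optional env from this dict. A hypothetical entry using the env key (not part of this commit; mcp-server-time is only a placeholder package name) would look like:

# Hypothetical extra server, shown to illustrate the optional "env" key;
# connect() merges it over os.environ when present.
time_server = McpServer(
    {
        "name": "time",
        "command": "uvx",
        "args": ["mcp-server-time"],
        "env": {"TZ": "UTC"},
    }
)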
aider/mcp/__init__.py (new file, 0 additions)

aider/mcp/server.py (new file, 76 additions)
@@ -0,0 +1,76 @@
+import asyncio
+import logging
+import os
+from contextlib import AsyncExitStack
+
+from mcp import ClientSession, StdioServerParameters
+from mcp.client.stdio import stdio_client
+
+
+class McpServer:
+    """
+    A client for MCP servers that provides tools to Aider coders. An McpServer class
+    is initialized per configured MCP Server
+
+    Current usage:
+
+    conn = await session.connect()  # Use connect() directly
+    tools = await experimental_mcp_client.load_mcp_tools(session=s, format="openai")
+    await session.disconnect()
+    print(tools)
+    """
+
+    def __init__(self, server_config):
+        """Initialize the MCP tool provider.
+
+        Args:
+            server_config: Configuration for the MCP server
+        """
+        self.config = server_config
+        self.name = server_config.get("name", "unnamed-server")
+        self.session = None
+        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
+        self.exit_stack = AsyncExitStack()
+
+    async def connect(self):
+        """Connect to the MCP server and return the session.
+
+        If a session is already active, returns the existing session.
+        Otherwise, establishes a new connection and initializes the session.
+
+        Returns:
+            ClientSession: The active session
+        """
+        if self.session is not None:
+            logging.info(f"Using existing session for MCP server: {self.name}")
+            return self.session
+
+        logging.info(f"Establishing new connection to MCP server: {self.name}")
+        command = self.config["command"]
+        server_params = StdioServerParameters(
+            command=command,
+            args=self.config["args"],
+            env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
+        )
+
+        try:
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            read, write = stdio_transport
+            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
+            await session.initialize()
+            self.session = session
+            return session
+        except Exception as e:
+            logging.error(f"Error initializing server {self.name}: {e}")
+            await self.disconnect()
+            raise
+
+    async def disconnect(self):
+        """Disconnect from the MCP server and clean up resources."""
+        async with self._cleanup_lock:
+            try:
+                await self.exit_stack.aclose()
+                self.session = None
+                self.stdio_context = None
+            except Exception as e:
+                logging.error(f"Error during cleanup of server {self.name}: {e}")
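A minimal standalone usage sketch for this class, following the "Current usage" pattern in its docstring (assumes litellm's experimental_mcp_client and a local uvx install):

import asyncio

from litellm import experimental_mcp_client

from aider.mcp.server import McpServer


async def demo():
    server = McpServer(
        {"name": "fetch", "command": "uvx", "args": ["mcp-server-fetch"]}
    )
    try:
        # connect() spawns the stdio server and initializes a ClientSession
        session = await server.connect()
        tools = await experimental_mcp_client.load_mcp_tools(
            session=session, format="openai"
        )
        print(tools)
    finally:
        # disconnect() closes the AsyncExitStack and clears the session
        await server.disconnect()


asyncio.run(demo())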
aider/models.py

@@ -767,7 +767,7 @@ class Model(ModelSettings):
     def is_ollama(self):
         return self.name.startswith("ollama/") or self.name.startswith("ollama_chat/")
 
-    def send_completion(self, messages, functions, stream, temperature=None):
+    def send_completion(self, messages, functions, stream, temperature=None, tools=None):
         if os.environ.get("AIDER_SANITY_CHECK_TURNS"):
             sanity_check_messages(messages)
@@ -797,8 +797,11 @@ class Model(ModelSettings):
         if self.is_ollama() and "num_ctx" not in kwargs:
             num_ctx = int(self.token_count(messages) * 1.25) + 8192
             kwargs["num_ctx"] = num_ctx
 
+        if tools:
+            kwargs["tools"] = kwargs["tools"] + tools
+
         key = json.dumps(kwargs, sort_keys=True).encode()
         # dump(kwargs)
 
         hash_object = hashlib.sha1(key)
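Note that kwargs["tools"] + tools assumes a "tools" key is already present in kwargs, presumably set when functions is given earlier in this method. A defensive variant, sketched here rather than taken from the commit, would merge either way:

# Sketch only: tolerate a missing "tools" key in kwargs.
if tools:
    kwargs["tools"] = kwargs.get("tools", []) + tools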