Mirror of https://github.com/Aider-AI/aider.git (synced 2025-06-02 02:34:59 +00:00)

Commit 2c24084cb0: Models may use tools during completions
Parent: 162f49c98e
4 changed files with 272 additions and 31 deletions
@@ -20,6 +20,8 @@ from json.decoder import JSONDecodeError
 from pathlib import Path
 from typing import List

+from litellm import experimental_mcp_client
+
 from aider import __version__, models, prompts, urls, utils
 from aider.analytics import Analytics
 from aider.commands import Commands
@@ -100,6 +102,7 @@ class Coder:
     test_outcome = None
     multi_response_content = ""
     partial_response_content = ""
+    partial_response_tool_call = []
     commit_before_message = []
     message_cost = 0.0
     message_tokens_sent = 0
@@ -515,28 +518,7 @@ class Coder:

-        # Instantiate MCP tools
-        if self.mcp_servers:
-            from litellm import experimental_mcp_client
-
-            tools = []
-            print("GETTING SERVER TOOLS")
-            for server in self.mcp_servers:
-                print(f"Getting server tools: {server.name}")
-
-                async def get_server_tools(all_tools):
-                    try:
-                        session = await server.connect()  # Use connect() directly
-                        server_tools = await experimental_mcp_client.load_mcp_tools(
-                            session=session, format="openai"
-                        )
-                        return all_tools + server_tools
-                    finally:
-                        await server.disconnect()
-
-                tools = asyncio.run(get_server_tools(tools))
-
-            self.mcp_tools = tools
-            print("All TOOLS")
-            print(tools)
+        self.initialize_mcp_tools()
         # validate the functions jsonschema
         if self.functions:
             from jsonschema import Draft7Validator
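Both the removed block and its replacement assume that each configured server object exposes async connect() and disconnect() methods, with connect() returning a live MCP session that litellm's experimental_mcp_client can drive. The commit does not show the McpServer class itself, so the sketch below is only a guess at the minimal interface this code relies on, written against the official `mcp` Python SDK; everything in it other than the connect()/disconnect() contract is an assumption.

    # Hypothetical sketch of the server wrapper the Coder code above assumes.
    # Assumes the `mcp` Python SDK; the real McpServer is not part of this diff.
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client


    class McpServer:
        def __init__(self, config):
            self.name = config["name"]
            self.config = config
            self._stdio_ctx = None
            self._session_ctx = None

        async def connect(self):
            params = StdioServerParameters(
                command=self.config["command"], args=self.config.get("args", [])
            )
            # Enter the context managers by hand so the session can be kept open
            # between connect() and disconnect(); real code might use AsyncExitStack.
            self._stdio_ctx = stdio_client(params)
            read, write = await self._stdio_ctx.__aenter__()
            self._session_ctx = ClientSession(read, write)
            session = await self._session_ctx.__aenter__()
            await session.initialize()
            return session

        async def disconnect(self):
            if self._session_ctx:
                await self._session_ctx.__aexit__(None, None, None)
            if self._stdio_ctx:
                await self._stdio_ctx.__aexit__(None, None, None)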
@@ -1360,6 +1342,7 @@ class Coder:

         chunks = self.format_messages()
         messages = chunks.all_messages()

         if not self.check_tokens(messages):
             return
         self.warm_cache(chunks)
@@ -1490,7 +1473,19 @@ class Coder:
             self.reflected_message = add_rel_files_message
             return

+        print(content)
+        tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call)
+
+        if tool_call_response:
+            tool_responses = self.execute_tool_calls(tool_call_response)
+
+            # Add the assistant message with tool calls
+            self.cur_messages.append(tool_call_response.choices[0].message)
+
+            # Add all tool responses
+            for tool_response in tool_responses:
+                self.cur_messages.append(tool_response)
+
+            return self.run(with_message="Continue", preproc=False)

         try:
             if self.reply_completed():
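The added flow rebuilds a complete response from the buffered tool-call chunks, executes the matching MCP tools, and then extends the chat history with one assistant message carrying the tool_calls followed by one role "tool" message per call, before re-running the coder with "Continue". The snippet below is a minimal standalone sketch of that message shape in the OpenAI chat format; the IDs and result text are invented for illustration.

    # Sketch of the message sequence appended to cur_messages above.
    # The call id, tool name, and content are made up for illustration.
    assistant_msg = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "git_status", "arguments": "{}"},
            }
        ],
    }
    tool_msg = {
        "role": "tool",
        "tool_call_id": "call_123",  # must match the id in the assistant message
        "content": "On branch main, nothing to commit",
    }

    cur_messages = []
    cur_messages.append(assistant_msg)  # the assistant turn that requested the tool
    cur_messages.append(tool_msg)       # the result for that tool_call_id

The next completion request then sees the tool result in context and can carry on with the original task, which is what the recursive run(with_message="Continue") achieves here.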
@@ -1548,6 +1543,129 @@ class Coder:
                 self.reflected_message = test_errors
                 return

+    async def _exec_server_tool(self, server, tool_call):
+        """Execute a tool call on an MCP server."""
+        try:
+            session = await server.connect()  # Use connect() directly
+            call_result = await experimental_mcp_client.call_openai_tool(
+                session=session,
+                openai_tool=tool_call,
+            )
+            return (str(call_result.content[0].text),)
+        finally:
+            await server.disconnect()
+
+    def execute_tool_calls(self, tool_call_response):
+        """Process tool calls from the response and execute them if they match MCP tools.
+        Returns a list of tool response messages."""
+        tool_responses = []
+        tool_calls = tool_call_response.choices[0].message.tool_calls
+
+        # First, collect all tool calls grouped by server
+        server_tool_calls = {}
+
+        for tool_call in tool_calls:
+            # Check if this tool_call matches any MCP tool
+            if self.mcp_tools:
+                for server_name, server_tools in self.mcp_tools:
+                    for tool in server_tools:
+                        if tool.get("function", {}).get("name") == tool_call.function.name:
+                            self.io.tool_output(
+                                f"Found MCP tool: {tool_call.function.name} from server"
+                                f" {server_name}"
+                            )
+                            self.io.tool_output(f"Tool arguments: {tool_call.function.arguments}")
+                            self.io.tool_output(f"Tool ID: {tool_call.id}")
+                            self.io.tool_output(f"Tool type: {tool_call.type}")
+
+                            # Find the corresponding server
+                            for server in self.mcp_servers:
+                                if server.name == server_name:
+                                    if server not in server_tool_calls:
+                                        server_tool_calls[server] = []
+                                    server_tool_calls[server].append(tool_call)
+                                    break
+
+        # Define the coroutine to execute all tool calls for a single server
+        async def _exec_server_tools(server, tool_calls_list):
+            tool_responses = []
+            try:
+                # Connect to the server once
+                session = await server.connect()
+                # Execute all tool calls for this server
+                for tool_call in tool_calls_list:
+                    call_result = await experimental_mcp_client.call_openai_tool(
+                        session=session,
+                        openai_tool=tool_call,
+                    )
+                    result_text = str(call_result.content[0].text)
+                    tool_responses.append(
+                        {"role": "tool", "tool_call_id": tool_call.id, "content": result_text}
+                    )
+            finally:
+                await server.disconnect()
+            return tool_responses

+        # Execute all tool calls concurrently
+        async def _execute_all_tool_calls():
+            tasks = []
+            for server, tool_calls_list in server_tool_calls.items():
+                tasks.append(_exec_server_tools(server, tool_calls_list))
+            # Wait for all tasks to complete
+            results = await asyncio.gather(*tasks)
+            return results
+
+        # Run the async execution and collect results
+        if server_tool_calls:
+            all_results = asyncio.run(_execute_all_tool_calls())
+            # Flatten the results from all servers
+            for server_results in all_results:
+                tool_responses.extend(server_results)
+
+        return tool_responses
+
+    def initialize_mcp_tools(self):
+        """Initialize tools from all configured MCP servers."""
+        tools = []
+
+        async def get_server_tools(server):
+            try:
+                session = await server.connect()
+                server_tools = await experimental_mcp_client.load_mcp_tools(
+                    session=session, format="openai"
+                )
+                return (server.name, server_tools)
+            finally:
+                await server.disconnect()
+
+        async def get_all_server_tools():
+            tasks = [get_server_tools(server) for server in self.mcp_servers]
+            results = await asyncio.gather(*tasks)
+            return results
+
+        if self.mcp_servers:
+            tools = asyncio.run(get_all_server_tools())
+
+            self.io.tool_output("MCP server configured:")
+            for server_name, server_tools in tools:
+                self.io.tool_output(f" - {server_name}")
+
+                if self.verbose:
+                    for tool in server_tools:
+                        tool_name = tool.get("function", {}).get("name", "unknown")
+                        tool_desc = tool.get("function", {}).get("description", "").split("\n")[0]
+                        self.io.tool_output(f" - {tool_name}: {tool_desc}")
+
+        self.mcp_tools = tools
+
+    def get_tool_list(self):
+        """Get a flattened list of all MCP tools."""
+        tool_list = []
+        if self.mcp_tools:
+            for _, server_tools in self.mcp_tools:
+                tool_list.extend(server_tools)
+        return tool_list
+
     def reply_completed(self):
         pass

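Both execute_tool_calls and initialize_mcp_tools follow the same pattern: group the pending async work per server, run one coroutine per server, and fan them out with asyncio.gather inside a single asyncio.run call, flattening the per-server results afterwards. The snippet below is a stripped-down, MCP-free sketch of that pattern; fetch_for_server is a stand-in for load_mcp_tools or call_openai_tool and the server names are invented.

    import asyncio

    # Stand-in for per-server work such as load_mcp_tools or call_openai_tool.
    async def fetch_for_server(name, items):
        await asyncio.sleep(0)  # pretend to talk to the server
        return [(name, item) for item in items]

    def run_grouped(server_items):
        # server_items maps a server to its queued work,
        # mirroring the server_tool_calls dict built above.
        async def _run_all():
            tasks = [fetch_for_server(name, items) for name, items in server_items.items()]
            return await asyncio.gather(*tasks)

        results = []
        for per_server in asyncio.run(_run_all()):
            results.extend(per_server)  # flatten, like tool_responses.extend(...)
        return results

    print(run_grouped({"git": ["status"], "fetch": ["https://example.com"]}))

Connecting once per server and issuing all of that server's calls inside one session (as _exec_server_tools does) avoids repeated connect/disconnect round trips while still letting different servers run concurrently.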
@@ -1721,13 +1839,15 @@ class Coder:
         completion = None

         try:
+            tool_list = self.get_tool_list()
+
             hash_object, completion = model.send_completion(
                 messages,
                 functions,
                 self.stream,
                 self.temperature,
                 # This could include any tools, but for now it is just MCP tools
-                tools=self.mcp_tools,
+                tools=tool_list,
             )
             self.chat_completion_call_hashes.append(hash_object.hexdigest())

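get_tool_list() hands send_completion a flat list of OpenAI-format tool definitions, which is what load_mcp_tools(..., format="openai") produces per server. For reference, one entry in that list has roughly the shape below; the concrete name, description, and parameter schema come from whichever MCP server advertises the tool, so the fetch_url example here is invented.

    # Approximate shape of one entry in tool_list (OpenAI function-tool format).
    # The name and schema are illustrative, not from a real MCP server.
    example_tool = {
        "type": "function",
        "function": {
            "name": "fetch_url",
            "description": "Fetch a URL and return its contents as text.",
            "parameters": {
                "type": "object",
                "properties": {"url": {"type": "string"}},
                "required": ["url"],
            },
        },
    }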
@@ -1825,6 +1945,7 @@ class Coder:

     def show_send_output_stream(self, completion):
         received_content = False
+        self.partial_response_tool_call = []

         for chunk in completion:
             if len(chunk.choices) == 0:
@@ -1836,6 +1957,9 @@ class Coder:
             ):
                 raise FinishReasonLength()

+            if chunk.choices[0].delta.tool_calls:
+                self.partial_response_tool_call.append(chunk)
+
             try:
                 func = chunk.choices[0].delta.function_call
                 # dump(func)
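During streaming, any chunk whose delta carries tool_calls is buffered whole in self.partial_response_tool_call; once the stream ends, litellm.stream_chunk_builder stitches those chunks back into a single non-streaming response whose message exposes the assembled tool_calls (this is what the send_message hunk earlier consumes). Below is a hedged sketch of that round trip in isolation, assuming litellm is installed and credentials for the named model are configured; the model name and tool are placeholders.

    # Sketch: buffer streamed tool-call chunks, then rebuild a full response.
    # Assumes litellm is installed and the model/API key are configured.
    import litellm

    tool_list = [
        {
            "type": "function",
            "function": {"name": "fetch_url", "parameters": {"type": "object", "properties": {}}},
        }
    ]

    chunks = []
    response = litellm.completion(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Fetch the project homepage."}],
        tools=tool_list,
        stream=True,
    )
    for chunk in response:
        if chunk.choices and chunk.choices[0].delta.tool_calls:
            chunks.append(chunk)

    rebuilt = litellm.stream_chunk_builder(chunks)
    if rebuilt:
        for call in rebuilt.choices[0].message.tool_calls:
            print(call.function.name, call.function.arguments)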
@@ -1844,6 +1968,7 @@ class Coder:
                         self.partial_response_function_call[k] += v
                     else:
                         self.partial_response_function_call[k] = v

+                received_content = True
             except AttributeError:
                 pass
@@ -957,9 +957,14 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
     analytics.event("auto_commits", enabled=bool(args.auto_commits))

     try:
-        fetch_server = McpServer({"name": "fetch", "command": "uvx", "args": ["mcp-server-fetch"]})
-
         git_server = McpServer({"name": "git", "command": "uvx", "args": ["mcp-server-git"]})
+        context_seven_server = McpServer(
+            {
+                "name": "context7",
+                "command": "deno",
+                "args": ["run", "--allow-net", "npm:@upstash/context7-mcp"],
+            }
+        )

         coder = Coder.create(
             main_model=main_model,
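Each McpServer is built from a plain dict giving the server a name plus the command and args used to launch it over stdio (uvx mcp-server-git for the git server, deno run npm:@upstash/context7-mcp for context7). The fetch server removed in this hunk follows exactly the same shape; rewritten on multiple lines for readability:

    # The removed fetch server, reformatted to show the config-dict shape.
    fetch_server = McpServer(
        {
            "name": "fetch",
            "command": "uvx",
            "args": ["mcp-server-fetch"],
        }
    )

The servers are still hard-coded in main() at this point in the experiment rather than read from user configuration.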
@@ -993,7 +998,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
             detect_urls=args.detect_urls,
             auto_copy_context=args.copy_paste,
             auto_accept_architect=args.auto_accept_architect,
-            mcp_servers=[fetch_server, git_server],
+            mcp_servers=[context_seven_server, git_server],
         )
     except UnknownEditFormat as err:
         io.tool_error(str(err))
@@ -774,10 +774,7 @@ class Model(ModelSettings):
         if self.is_deepseek_r1():
             messages = ensure_alternating_roles(messages)

-        kwargs = dict(
-            model=self.name,
-            stream=stream,
-        )
+        kwargs = dict(model=self.name, stream=stream, tools=[])

         if self.use_temperature is not False:
             if temperature is None:
@@ -1287,6 +1287,120 @@ This command will print 'Hello, World!' to the console."""
         # (because user rejected the changes)
         mock_editor.run.assert_not_called()

+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    def test_mcp_server_connection(self, mock_mcp_client):
+        """Test that the coder connects to MCP servers for tools."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+
+            # Create mock MCP server
+            mock_server = MagicMock()
+            mock_server.name = "test_server"
+            mock_server.connect = MagicMock()
+            mock_server.disconnect = MagicMock()
+
+            # Setup mock for initialize_mcp_tools
+            mock_tools = [("test_server", [{"function": {"name": "test_tool"}}])]
+
+            # Create coder with mock MCP server
+            with patch.object(Coder, "initialize_mcp_tools", return_value=mock_tools):
+                coder = Coder.create(self.GPT35, "diff", io=io, mcp_servers=[mock_server])
+
+                # Manually set mcp_tools since we're bypassing initialize_mcp_tools
+                coder.mcp_tools = mock_tools
+
+                # Verify that mcp_tools contains the expected data
+                self.assertIsNotNone(coder.mcp_tools)
+                self.assertEqual(len(coder.mcp_tools), 1)
+                self.assertEqual(coder.mcp_tools[0][0], "test_server")
+
+                # Test execute_tool_calls
+                tool_call = MagicMock()
+                tool_call.function.name = "test_tool"
+                tool_call.function.arguments = "{}"
+                tool_call.id = "test_id"
+                tool_call.type = "function"
+
+                response = MagicMock()
+                response.choices = [MagicMock()]
+                response.choices[0].message.tool_calls = [tool_call]
+
+                # Setup mock for call_openai_tool
+                mock_call_result = MagicMock()
+                mock_call_result.content = [MagicMock()]
+                mock_call_result.content[0].text = "Tool execution result"
+                mock_mcp_client.call_openai_tool.return_value = mock_call_result
+
+                # Mock the async execution directly
+                with patch.object(
+                    coder,
+                    "execute_tool_calls",
+                    return_value=[
+                        {
+                            "role": "tool",
+                            "tool_call_id": "test_id",
+                            "content": "Tool execution result",
+                        }
+                    ],
+                ):
+                    tool_responses = coder.execute_tool_calls(response)
+
+                    # Verify tool responses
+                    self.assertEqual(len(tool_responses), 1)
+                    self.assertEqual(tool_responses[0]["role"], "tool")
+                    self.assertEqual(tool_responses[0]["tool_call_id"], "test_id")
+                    self.assertEqual(tool_responses[0]["content"], "Tool execution result")
+
+    @patch("aider.coders.base_coder.experimental_mcp_client")
+    def test_initialize_mcp_tools(self, mock_mcp_client):
+        """Test that the coder initializes MCP tools correctly."""
+        with GitTemporaryDirectory():
+            io = InputOutput(yes=True)
+
+            # Create mock MCP servers
+            mock_server1 = MagicMock()
+            mock_server1.name = "server1"
+            mock_server1.connect = MagicMock()
+            mock_server1.disconnect = MagicMock()
+
+            mock_server2 = MagicMock()
+            mock_server2.name = "server2"
+            mock_server2.connect = MagicMock()
+            mock_server2.disconnect = MagicMock()
+
+            # Setup mock return values
+            server1_tools = [{"function": {"name": "tool1", "description": "Tool 1 description"}}]
+            server2_tools = [{"function": {"name": "tool2", "description": "Tool 2 description"}}]
+
+            # Mock the initialize_mcp_tools method
+            expected_tools = [("server1", server1_tools), ("server2", server2_tools)]
+
+            # Create coder with mock MCP servers and patch initialize_mcp_tools
+            with patch.object(Coder, "initialize_mcp_tools"):
+                coder = Coder.create(
+                    self.GPT35,
+                    "diff",
+                    io=io,
+                    mcp_servers=[mock_server1, mock_server2],
+                    verbose=True,
+                )
+
+                # Manually set mcp_tools to expected value
+                coder.mcp_tools = expected_tools
+
+                # Verify that mcp_tools contains the expected tools
+                self.assertEqual(len(coder.mcp_tools), 2)
+                self.assertEqual(coder.mcp_tools[0][0], "server1")
+                self.assertEqual(coder.mcp_tools[0][1], server1_tools)
+                self.assertEqual(coder.mcp_tools[1][0], "server2")
+                self.assertEqual(coder.mcp_tools[1][1], server2_tools)
+
+                # Test get_tool_list
+                tool_list = coder.get_tool_list()
+                self.assertEqual(len(tool_list), 2)
+                self.assertEqual(tool_list[0], server1_tools[0])
+                self.assertEqual(tool_list[1], server2_tools[0])
+

 if __name__ == "__main__":
     unittest.main()
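Both tests sidestep the async plumbing by patching initialize_mcp_tools and execute_tool_calls, which makes sense given that a plain MagicMock's connect() is not awaitable and would fail inside asyncio.run. If the async path itself ever needed exercising, unittest.mock.AsyncMock (Python 3.8+) is the usual tool for awaitable mocks; the small sketch below is not part of this commit, just an illustration of that approach.

    # Sketch: making a mocked MCP server awaitable with AsyncMock (not in this commit).
    import asyncio
    from unittest.mock import AsyncMock, MagicMock

    mock_server = MagicMock()
    mock_server.name = "test_server"
    mock_server.connect = AsyncMock(return_value=MagicMock())  # awaitable connect()
    mock_server.disconnect = AsyncMock()

    async def use_server(server):
        session = await server.connect()
        await server.disconnect()
        return session

    asyncio.run(use_server(mock_server))
    mock_server.connect.assert_awaited_once()
    mock_server.disconnect.assert_awaited_once()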