mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-08 21:55:00 +00:00
Use a parens parser and then Python's ast.parse to parse tool calls robustly
This commit is contained in:
parent
4339c73774
commit
765002d486
1 changed files with 135 additions and 47 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
import ast
|
||||||
import re
|
import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
|
@ -350,66 +351,149 @@ class NavigatorCoder(Coder):
|
||||||
call_count = 0
|
call_count = 0
|
||||||
max_calls = self.max_tool_calls
|
max_calls = self.max_tool_calls
|
||||||
|
|
||||||
# Regex to find tool calls: [tool_call(name, key=value, key="value", ...)]
|
# Find tool calls using a more robust method
|
||||||
# It captures the tool name and the arguments string.
|
processed_content = ""
|
||||||
# It handles quoted and unquoted values.
|
last_index = 0
|
||||||
tool_call_pattern = re.compile(
|
start_marker = "[tool_call("
|
||||||
r"\[tool_call\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*" # Tool name
|
end_marker = "]" # The parenthesis balancing finds the ')', we just need the final ']'
|
||||||
r"(?:,\s*(.*?))?\s*\)\]" # Optional arguments string (non-greedy)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Regex to parse key=value pairs within the arguments string
|
while True:
|
||||||
# Handles key=value, key="value", key='value'
|
start_pos = content.find(start_marker, last_index)
|
||||||
# Allows values to contain commas if quoted
|
if start_pos == -1:
|
||||||
args_pattern = re.compile(
|
processed_content += content[last_index:]
|
||||||
r"([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*" # Key
|
break
|
||||||
r"(?:\"(.*?)\"|\'(.*?)\'|([^,\s\'\"]*))" # Value (quoted or unquoted)
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_indices = set() # Keep track of processed match ranges
|
# Append content before the tool call
|
||||||
|
processed_content += content[last_index:start_pos]
|
||||||
|
|
||||||
for match in tool_call_pattern.finditer(content):
|
scan_start_pos = start_pos + len(start_marker)
|
||||||
# Skip overlapping matches if a previous match already covered this area
|
paren_level = 1
|
||||||
if any(match.start() >= start and match.end() <= end for start, end in processed_indices):
|
in_single_quotes = False
|
||||||
|
in_double_quotes = False
|
||||||
|
escaped = False
|
||||||
|
end_paren_pos = -1
|
||||||
|
|
||||||
|
# Scan to find the matching closing parenthesis, respecting quotes
|
||||||
|
for i in range(scan_start_pos, len(content)):
|
||||||
|
char = content[i]
|
||||||
|
|
||||||
|
if escaped:
|
||||||
|
escaped = False
|
||||||
|
elif char == '\\':
|
||||||
|
escaped = True
|
||||||
|
elif char == "'" and not in_double_quotes:
|
||||||
|
in_single_quotes = not in_single_quotes
|
||||||
|
elif char == '"' and not in_single_quotes:
|
||||||
|
in_double_quotes = not in_double_quotes
|
||||||
|
elif char == '(' and not in_single_quotes and not in_double_quotes:
|
||||||
|
paren_level += 1
|
||||||
|
elif char == ')' and not in_single_quotes and not in_double_quotes:
|
||||||
|
paren_level -= 1
|
||||||
|
if paren_level == 0:
|
||||||
|
end_paren_pos = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check for the end marker after the closing parenthesis, skipping whitespace
|
||||||
|
expected_end_marker_start = end_paren_pos + 1
|
||||||
|
actual_end_marker_start = -1
|
||||||
|
end_marker_found = False
|
||||||
|
if end_paren_pos != -1: # Only search if we found a closing parenthesis
|
||||||
|
for j in range(expected_end_marker_start, len(content)):
|
||||||
|
if not content[j].isspace():
|
||||||
|
actual_end_marker_start = j
|
||||||
|
# Check if the found character is the end marker ']'
|
||||||
|
if content[actual_end_marker_start] == end_marker:
|
||||||
|
end_marker_found = True
|
||||||
|
break # Stop searching after first non-whitespace char
|
||||||
|
|
||||||
|
if not end_marker_found:
|
||||||
|
# Malformed call: couldn't find matching ')' or the subsequent ']'
|
||||||
|
self.io.tool_warning(f"Malformed tool call starting at index {start_pos}. Skipping (end_paren_pos={end_paren_pos}, end_marker_found={end_marker_found}).")
|
||||||
|
# Append the start marker itself to processed content so it's not lost
|
||||||
|
processed_content += start_marker
|
||||||
|
last_index = scan_start_pos # Continue searching after the marker
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Found a potential tool call
|
||||||
|
# Adjust full_match_str and last_index based on the actual end marker ']' position
|
||||||
|
full_match_str = content[start_pos : actual_end_marker_start + 1] # End marker ']' is 1 char
|
||||||
|
inner_content = content[scan_start_pos:end_paren_pos].strip()
|
||||||
|
last_index = actual_end_marker_start + 1 # Move past the processed call (including ']')
|
||||||
|
|
||||||
|
|
||||||
call_count += 1
|
call_count += 1
|
||||||
if call_count > max_calls:
|
if call_count > max_calls:
|
||||||
self.io.tool_warning(f"Exceeded maximum tool calls ({max_calls}). Skipping remaining calls.")
|
self.io.tool_warning(f"Exceeded maximum tool calls ({max_calls}). Skipping remaining calls.")
|
||||||
break
|
# Don't append the skipped call to processed_content
|
||||||
|
continue # Skip processing this call
|
||||||
|
|
||||||
tool_name = match.group(1)
|
|
||||||
args_str = match.group(2) or ""
|
|
||||||
full_match_str = match.group(0)
|
|
||||||
|
|
||||||
# We no longer need to handle Continue separately, as we'll continue if any tool calls exist
|
|
||||||
# Just track that a tool call was found
|
|
||||||
tool_calls_found = True
|
tool_calls_found = True
|
||||||
|
tool_name = None
|
||||||
# Extract parameters
|
|
||||||
params = {}
|
params = {}
|
||||||
suppressed_arg_values = ["..."] # Values to ignore during parsing
|
result_message = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for arg_match in args_pattern.finditer(args_str):
|
# Wrap the inner content to make it parseable as a function call
|
||||||
key = arg_match.group(1)
|
# Example: ToolName, key="value" becomes f(ToolName, key="value")
|
||||||
# Value can be in group 2 (double quotes), 3 (single quotes), or 4 (unquoted)
|
parse_str = f"f({inner_content})"
|
||||||
value = arg_match.group(2) or arg_match.group(3) or arg_match.group(4)
|
parsed_ast = ast.parse(parse_str)
|
||||||
|
|
||||||
# Check if the value is suppressed
|
# Validate AST structure
|
||||||
if value in suppressed_arg_values:
|
if not isinstance(parsed_ast, ast.Module) or not parsed_ast.body or not isinstance(parsed_ast.body[0], ast.Expr):
|
||||||
|
raise ValueError("Unexpected AST structure")
|
||||||
|
call_node = parsed_ast.body[0].value
|
||||||
|
if not isinstance(call_node, ast.Call):
|
||||||
|
raise ValueError("Expected a Call node")
|
||||||
|
|
||||||
|
# Extract tool name (should be the first positional argument)
|
||||||
|
if not call_node.args or not isinstance(call_node.args[0], ast.Name):
|
||||||
|
raise ValueError("Tool name not found or invalid")
|
||||||
|
tool_name = call_node.args[0].id
|
||||||
|
|
||||||
|
# Extract keyword arguments
|
||||||
|
for keyword in call_node.keywords:
|
||||||
|
key = keyword.arg
|
||||||
|
value_node = keyword.value
|
||||||
|
# Extract value based on AST node type
|
||||||
|
if isinstance(value_node, ast.Constant):
|
||||||
|
value = value_node.value
|
||||||
|
elif isinstance(value_node, ast.Name): # Handle unquoted values like True/False/None or variables (though variables are unlikely here)
|
||||||
|
value = value_node.id
|
||||||
|
# Add more types if needed (e.g., ast.List, ast.Dict)
|
||||||
|
else:
|
||||||
|
# Attempt to reconstruct the source for complex types, or raise error
|
||||||
|
try:
|
||||||
|
# Note: ast.unparse requires Python 3.9+
|
||||||
|
# If using older Python, might need a different approach or limit supported types
|
||||||
|
value = ast.unparse(value_node)
|
||||||
|
except AttributeError: # Handle case where ast.unparse is not available
|
||||||
|
raise ValueError(f"Unsupported argument type for key '{key}': {type(value_node)}")
|
||||||
|
except Exception as ue:
|
||||||
|
raise ValueError(f"Could not unparse value for key '{key}': {ue}")
|
||||||
|
|
||||||
|
|
||||||
|
# Check for suppressed values (e.g., "...")
|
||||||
|
suppressed_arg_values = ["..."]
|
||||||
|
if isinstance(value, str) and value in suppressed_arg_values:
|
||||||
self.io.tool_warning(f"Skipping suppressed argument value '{value}' for key '{key}' in tool '{tool_name}'")
|
self.io.tool_warning(f"Skipping suppressed argument value '{value}' for key '{key}' in tool '{tool_name}'")
|
||||||
continue # Skip this argument
|
continue
|
||||||
|
|
||||||
params[key] = value if value is not None else ""
|
params[key] = value
|
||||||
except Exception as e:
|
|
||||||
result_messages.append(f"[Result ({tool_name}): Error parsing arguments '{args_str}': {e}]")
|
|
||||||
# Remove the malformed call from the content
|
except (SyntaxError, ValueError) as e:
|
||||||
modified_content = modified_content.replace(full_match_str, "", 1)
|
result_message = f"Error parsing tool call '{inner_content}': {e}"
|
||||||
processed_indices.add((match.start(), match.end()))
|
self.io.tool_error(f"Failed to parse tool call: {full_match_str}\nError: {e}")
|
||||||
continue # Skip execution if args parsing failed
|
# Don't append the malformed call to processed_content
|
||||||
|
result_messages.append(f"[Result (Parse Error): {result_message}]")
|
||||||
|
continue # Skip execution
|
||||||
|
except Exception as e: # Catch any other unexpected parsing errors
|
||||||
|
result_message = f"Unexpected error parsing tool call '{inner_content}': {e}"
|
||||||
|
self.io.tool_error(f"Unexpected error during parsing: {full_match_str}\nError: {e}\n{traceback.format_exc()}")
|
||||||
|
result_messages.append(f"[Result (Parse Error): {result_message}]")
|
||||||
|
continue
|
||||||
|
|
||||||
# Execute the tool based on its name
|
# Execute the tool based on its name
|
||||||
result_message = None
|
|
||||||
try:
|
try:
|
||||||
# Normalize tool name for case-insensitive matching
|
# Normalize tool name for case-insensitive matching
|
||||||
norm_tool_name = tool_name.lower()
|
norm_tool_name = tool_name.lower()
|
||||||
|
@ -474,14 +558,18 @@ class NavigatorCoder(Coder):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result_message = f"Error executing {tool_name}: {str(e)}"
|
result_message = f"Error executing {tool_name}: {str(e)}"
|
||||||
self.io.tool_error(f"Error during {tool_name} execution: {e}")
|
self.io.tool_error(f"Error during {tool_name} execution: {e}\n{traceback.format_exc()}")
|
||||||
|
|
||||||
if result_message:
|
if result_message:
|
||||||
result_messages.append(f"[Result ({tool_name}): {result_message}]")
|
result_messages.append(f"[Result ({tool_name}): {result_message}]")
|
||||||
|
|
||||||
# Remove the processed tool call from the content
|
# Note: We don't add the tool call string back to processed_content
|
||||||
modified_content = modified_content.replace(full_match_str, "", 1)
|
|
||||||
processed_indices.add((match.start(), match.end()))
|
# Update internal counter
|
||||||
|
self.tool_call_count += call_count
|
||||||
|
|
||||||
|
# Return the content with tool calls removed
|
||||||
|
modified_content = processed_content
|
||||||
|
|
||||||
# Update internal counter
|
# Update internal counter
|
||||||
self.tool_call_count += call_count
|
self.tool_call_count += call_count
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue