Use a parens parser and then Python's ast.parse to parse tool calls robustly

This commit is contained in:
Amar Sood (tekacs) 2025-04-11 13:20:34 -04:00
parent 4339c73774
commit 765002d486

View file

@ -1,3 +1,4 @@
import ast
import re import re
import fnmatch import fnmatch
import os import os
@ -350,66 +351,149 @@ class NavigatorCoder(Coder):
call_count = 0 call_count = 0
max_calls = self.max_tool_calls max_calls = self.max_tool_calls
# Regex to find tool calls: [tool_call(name, key=value, key="value", ...)] # Find tool calls using a more robust method
# It captures the tool name and the arguments string. processed_content = ""
# It handles quoted and unquoted values. last_index = 0
tool_call_pattern = re.compile( start_marker = "[tool_call("
r"\[tool_call\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*" # Tool name end_marker = "]" # The parenthesis balancing finds the ')', we just need the final ']'
r"(?:,\s*(.*?))?\s*\)\]" # Optional arguments string (non-greedy)
)
# Regex to parse key=value pairs within the arguments string while True:
# Handles key=value, key="value", key='value' start_pos = content.find(start_marker, last_index)
# Allows values to contain commas if quoted if start_pos == -1:
args_pattern = re.compile( processed_content += content[last_index:]
r"([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*" # Key break
r"(?:\"(.*?)\"|\'(.*?)\'|([^,\s\'\"]*))" # Value (quoted or unquoted)
)
processed_indices = set() # Keep track of processed match ranges # Append content before the tool call
processed_content += content[last_index:start_pos]
for match in tool_call_pattern.finditer(content): scan_start_pos = start_pos + len(start_marker)
# Skip overlapping matches if a previous match already covered this area paren_level = 1
if any(match.start() >= start and match.end() <= end for start, end in processed_indices): in_single_quotes = False
in_double_quotes = False
escaped = False
end_paren_pos = -1
# Scan to find the matching closing parenthesis, respecting quotes
for i in range(scan_start_pos, len(content)):
char = content[i]
if escaped:
escaped = False
elif char == '\\':
escaped = True
elif char == "'" and not in_double_quotes:
in_single_quotes = not in_single_quotes
elif char == '"' and not in_single_quotes:
in_double_quotes = not in_double_quotes
elif char == '(' and not in_single_quotes and not in_double_quotes:
paren_level += 1
elif char == ')' and not in_single_quotes and not in_double_quotes:
paren_level -= 1
if paren_level == 0:
end_paren_pos = i
break
# Check for the end marker after the closing parenthesis, skipping whitespace
expected_end_marker_start = end_paren_pos + 1
actual_end_marker_start = -1
end_marker_found = False
if end_paren_pos != -1: # Only search if we found a closing parenthesis
for j in range(expected_end_marker_start, len(content)):
if not content[j].isspace():
actual_end_marker_start = j
# Check if the found character is the end marker ']'
if content[actual_end_marker_start] == end_marker:
end_marker_found = True
break # Stop searching after first non-whitespace char
if not end_marker_found:
# Malformed call: couldn't find matching ')' or the subsequent ']'
self.io.tool_warning(f"Malformed tool call starting at index {start_pos}. Skipping (end_paren_pos={end_paren_pos}, end_marker_found={end_marker_found}).")
# Append the start marker itself to processed content so it's not lost
processed_content += start_marker
last_index = scan_start_pos # Continue searching after the marker
continue continue
# Found a potential tool call
# Adjust full_match_str and last_index based on the actual end marker ']' position
full_match_str = content[start_pos : actual_end_marker_start + 1] # End marker ']' is 1 char
inner_content = content[scan_start_pos:end_paren_pos].strip()
last_index = actual_end_marker_start + 1 # Move past the processed call (including ']')
call_count += 1 call_count += 1
if call_count > max_calls: if call_count > max_calls:
self.io.tool_warning(f"Exceeded maximum tool calls ({max_calls}). Skipping remaining calls.") self.io.tool_warning(f"Exceeded maximum tool calls ({max_calls}). Skipping remaining calls.")
break # Don't append the skipped call to processed_content
continue # Skip processing this call
tool_name = match.group(1)
args_str = match.group(2) or ""
full_match_str = match.group(0)
# We no longer need to handle Continue separately, as we'll continue if any tool calls exist
# Just track that a tool call was found
tool_calls_found = True tool_calls_found = True
tool_name = None
# Extract parameters
params = {} params = {}
suppressed_arg_values = ["..."] # Values to ignore during parsing result_message = None
try: try:
for arg_match in args_pattern.finditer(args_str): # Wrap the inner content to make it parseable as a function call
key = arg_match.group(1) # Example: ToolName, key="value" becomes f(ToolName, key="value")
# Value can be in group 2 (double quotes), 3 (single quotes), or 4 (unquoted) parse_str = f"f({inner_content})"
value = arg_match.group(2) or arg_match.group(3) or arg_match.group(4) parsed_ast = ast.parse(parse_str)
# Check if the value is suppressed # Validate AST structure
if value in suppressed_arg_values: if not isinstance(parsed_ast, ast.Module) or not parsed_ast.body or not isinstance(parsed_ast.body[0], ast.Expr):
raise ValueError("Unexpected AST structure")
call_node = parsed_ast.body[0].value
if not isinstance(call_node, ast.Call):
raise ValueError("Expected a Call node")
# Extract tool name (should be the first positional argument)
if not call_node.args or not isinstance(call_node.args[0], ast.Name):
raise ValueError("Tool name not found or invalid")
tool_name = call_node.args[0].id
# Extract keyword arguments
for keyword in call_node.keywords:
key = keyword.arg
value_node = keyword.value
# Extract value based on AST node type
if isinstance(value_node, ast.Constant):
value = value_node.value
elif isinstance(value_node, ast.Name): # Handle unquoted values like True/False/None or variables (though variables are unlikely here)
value = value_node.id
# Add more types if needed (e.g., ast.List, ast.Dict)
else:
# Attempt to reconstruct the source for complex types, or raise error
try:
# Note: ast.unparse requires Python 3.9+
# If using older Python, might need a different approach or limit supported types
value = ast.unparse(value_node)
except AttributeError: # Handle case where ast.unparse is not available
raise ValueError(f"Unsupported argument type for key '{key}': {type(value_node)}")
except Exception as ue:
raise ValueError(f"Could not unparse value for key '{key}': {ue}")
# Check for suppressed values (e.g., "...")
suppressed_arg_values = ["..."]
if isinstance(value, str) and value in suppressed_arg_values:
self.io.tool_warning(f"Skipping suppressed argument value '{value}' for key '{key}' in tool '{tool_name}'") self.io.tool_warning(f"Skipping suppressed argument value '{value}' for key '{key}' in tool '{tool_name}'")
continue # Skip this argument continue
params[key] = value if value is not None else "" params[key] = value
except Exception as e:
result_messages.append(f"[Result ({tool_name}): Error parsing arguments '{args_str}': {e}]")
# Remove the malformed call from the content except (SyntaxError, ValueError) as e:
modified_content = modified_content.replace(full_match_str, "", 1) result_message = f"Error parsing tool call '{inner_content}': {e}"
processed_indices.add((match.start(), match.end())) self.io.tool_error(f"Failed to parse tool call: {full_match_str}\nError: {e}")
continue # Skip execution if args parsing failed # Don't append the malformed call to processed_content
result_messages.append(f"[Result (Parse Error): {result_message}]")
continue # Skip execution
except Exception as e: # Catch any other unexpected parsing errors
result_message = f"Unexpected error parsing tool call '{inner_content}': {e}"
self.io.tool_error(f"Unexpected error during parsing: {full_match_str}\nError: {e}\n{traceback.format_exc()}")
result_messages.append(f"[Result (Parse Error): {result_message}]")
continue
# Execute the tool based on its name # Execute the tool based on its name
result_message = None
try: try:
# Normalize tool name for case-insensitive matching # Normalize tool name for case-insensitive matching
norm_tool_name = tool_name.lower() norm_tool_name = tool_name.lower()
@ -474,14 +558,18 @@ class NavigatorCoder(Coder):
except Exception as e: except Exception as e:
result_message = f"Error executing {tool_name}: {str(e)}" result_message = f"Error executing {tool_name}: {str(e)}"
self.io.tool_error(f"Error during {tool_name} execution: {e}") self.io.tool_error(f"Error during {tool_name} execution: {e}\n{traceback.format_exc()}")
if result_message: if result_message:
result_messages.append(f"[Result ({tool_name}): {result_message}]") result_messages.append(f"[Result ({tool_name}): {result_message}]")
# Remove the processed tool call from the content # Note: We don't add the tool call string back to processed_content
modified_content = modified_content.replace(full_match_str, "", 1)
processed_indices.add((match.start(), match.end())) # Update internal counter
self.tool_call_count += call_count
# Return the content with tool calls removed
modified_content = processed_content
# Update internal counter # Update internal counter
self.tool_call_count += call_count self.tool_call_count += call_count