mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-08 13:44:59 +00:00
Capture TreeSitter ranges for tools to use
This commit is contained in:
parent
d6e58ce063
commit
b51abd7fe7
4 changed files with 126 additions and 11 deletions
|
@ -27,11 +27,11 @@ from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
|
||||||
|
|
||||||
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
|
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
|
||||||
# with cached entries that might have been created with the old definition
|
# with cached entries that might have been created with the old definition
|
||||||
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind")):
|
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind start_line end_line start_byte end_byte")):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None):
|
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None, start_line=None, end_line=None, start_byte=None, end_byte=None):
|
||||||
# Provide a default value for specific_kind to handle old cached objects
|
# Provide a default value for specific_kind to handle old cached objects
|
||||||
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind)
|
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind, start_line, end_line, start_byte, end_byte)
|
||||||
|
|
||||||
Tag = TagBase
|
Tag = TagBase
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
|
||||||
|
|
||||||
CACHE_VERSION = 5
|
CACHE_VERSION = 5
|
||||||
if USING_TSL_PACK:
|
if USING_TSL_PACK:
|
||||||
CACHE_VERSION = 6
|
CACHE_VERSION = 7
|
||||||
|
|
||||||
|
|
||||||
class RepoMap:
|
class RepoMap:
|
||||||
|
@ -247,6 +247,51 @@ class RepoMap:
|
||||||
self.io.tool_warning(f"File not found error: {fname}")
|
self.io.tool_warning(f"File not found error: {fname}")
|
||||||
|
|
||||||
def get_tags(self, fname, rel_fname):
|
def get_tags(self, fname, rel_fname):
|
||||||
|
def get_symbol_definition_location(self, file_path, symbol_name):
|
||||||
|
"""
|
||||||
|
Finds the unique definition location (start/end line) for a symbol in a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): The relative path to the file.
|
||||||
|
symbol_name (str): The name of the symbol to find.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (start_line, end_line) (0-based) if a unique definition is found.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolError: If the symbol is not found, not unique, or not a definition.
|
||||||
|
"""
|
||||||
|
abs_path = self.io.root_abs_path(file_path) # Assuming io has this helper or similar
|
||||||
|
rel_path = self.get_rel_fname(abs_path) # Ensure we use consistent relative path
|
||||||
|
|
||||||
|
tags = self.get_tags(abs_path, rel_path)
|
||||||
|
if not tags:
|
||||||
|
raise ToolError(f"Symbol '{symbol_name}' not found in '{file_path}' (no tags).")
|
||||||
|
|
||||||
|
definitions = []
|
||||||
|
for tag in tags:
|
||||||
|
# Check if it's a definition and the name matches
|
||||||
|
if tag.kind == "def" and tag.name == symbol_name:
|
||||||
|
# Ensure we have valid location info
|
||||||
|
if tag.start_line is not None and tag.end_line is not None and tag.start_line >= 0:
|
||||||
|
definitions.append(tag)
|
||||||
|
|
||||||
|
if not definitions:
|
||||||
|
# Check if it exists as a non-definition tag
|
||||||
|
non_defs = [tag for tag in tags if tag.name == symbol_name and tag.kind != "def"]
|
||||||
|
if non_defs:
|
||||||
|
raise ToolError(f"Symbol '{symbol_name}' found in '{file_path}', but not as a unique definition (found as {non_defs[0].kind}).")
|
||||||
|
else:
|
||||||
|
raise ToolError(f"Symbol '{symbol_name}' definition not found in '{file_path}'.")
|
||||||
|
|
||||||
|
if len(definitions) > 1:
|
||||||
|
# Provide more context about ambiguity if possible
|
||||||
|
lines = sorted([d.start_line + 1 for d in definitions]) # 1-based for user message
|
||||||
|
raise ToolError(f"Symbol '{symbol_name}' is ambiguous in '{file_path}'. Found definitions on lines: {', '.join(map(str, lines))}.")
|
||||||
|
|
||||||
|
# Unique definition found
|
||||||
|
definition_tag = definitions[0]
|
||||||
|
return definition_tag.start_line, definition_tag.end_line
|
||||||
# Check if the file is in the cache and if the modification time has not changed
|
# Check if the file is in the cache and if the modification time has not changed
|
||||||
file_mtime = self.get_mtime(fname)
|
file_mtime = self.get_mtime(fname)
|
||||||
if file_mtime is None:
|
if file_mtime is None:
|
||||||
|
@ -345,7 +390,11 @@ class RepoMap:
|
||||||
name=node.text.decode("utf-8"),
|
name=node.text.decode("utf-8"),
|
||||||
kind=kind,
|
kind=kind,
|
||||||
specific_kind=specific_kind,
|
specific_kind=specific_kind,
|
||||||
line=node.start_point[0],
|
line=node.start_point[0], # Legacy line number
|
||||||
|
start_line=node.start_point[0],
|
||||||
|
end_line=node.end_point[0],
|
||||||
|
start_byte=node.start_byte,
|
||||||
|
end_byte=node.end_byte,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield result
|
yield result
|
||||||
|
@ -375,7 +424,11 @@ class RepoMap:
|
||||||
name=token,
|
name=token,
|
||||||
kind="ref",
|
kind="ref",
|
||||||
specific_kind="name", # Default for pygments fallback
|
specific_kind="name", # Default for pygments fallback
|
||||||
line=-1,
|
line=-1, # Pygments doesn't give precise locations easily
|
||||||
|
start_line=-1,
|
||||||
|
end_line=-1,
|
||||||
|
start_byte=-1,
|
||||||
|
end_byte=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_ranked_tags(
|
def get_ranked_tags(
|
||||||
|
@ -875,4 +928,4 @@ if __name__ == "__main__":
|
||||||
repo_map = rm.get_ranked_tags_map(chat_fnames, other_fnames)
|
repo_map = rm.get_ranked_tags_map(chat_fnames, other_fnames)
|
||||||
|
|
||||||
dump(len(repo_map))
|
dump(len(repo_map))
|
||||||
print(repo_map)
|
print(repo_map)
|
|
@ -30,7 +30,14 @@ def _execute_delete_block(coder, file_path, start_pattern, end_pattern=None, lin
|
||||||
|
|
||||||
# 3. Determine the end line, passing pattern_desc for better error messages
|
# 3. Determine the end line, passing pattern_desc for better error messages
|
||||||
start_line, end_line = determine_line_range(
|
start_line, end_line = determine_line_range(
|
||||||
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
|
coder=coder,
|
||||||
|
file_path=rel_path,
|
||||||
|
lines=lines,
|
||||||
|
start_pattern_line_index=start_line_idx,
|
||||||
|
end_pattern=end_pattern,
|
||||||
|
line_count=line_count,
|
||||||
|
target_symbol=None, # DeleteBlock uses patterns, not symbols
|
||||||
|
pattern_desc=pattern_desc
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Prepare the deletion
|
# 4. Prepare the deletion
|
||||||
|
|
|
@ -44,7 +44,14 @@ def _execute_indent_lines(coder, file_path, start_pattern, end_pattern=None, lin
|
||||||
|
|
||||||
# 3. Determine the end line
|
# 3. Determine the end line
|
||||||
start_line, end_line = determine_line_range(
|
start_line, end_line = determine_line_range(
|
||||||
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
|
coder=coder,
|
||||||
|
file_path=rel_path,
|
||||||
|
lines=lines,
|
||||||
|
start_pattern_line_index=start_line_idx,
|
||||||
|
end_pattern=end_pattern,
|
||||||
|
line_count=line_count,
|
||||||
|
target_symbol=None, # IndentLines uses patterns, not symbols
|
||||||
|
pattern_desc=pattern_desc
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Validate and prepare indentation
|
# 4. Validate and prepare indentation
|
||||||
|
|
|
@ -85,11 +85,59 @@ def select_occurrence_index(indices, occurrence, pattern_desc="Pattern"):
|
||||||
|
|
||||||
return indices[target_idx]
|
return indices[target_idx]
|
||||||
|
|
||||||
def determine_line_range(lines, start_pattern_line_index, end_pattern=None, line_count=None, pattern_desc="Block"):
|
def determine_line_range(
|
||||||
|
coder, # Added: Need coder to access repo_map
|
||||||
|
file_path, # Added: Need file_path for repo_map lookup
|
||||||
|
lines,
|
||||||
|
start_pattern_line_index=None, # Made optional
|
||||||
|
end_pattern=None,
|
||||||
|
line_count=None,
|
||||||
|
target_symbol=None, # Added: New parameter for symbol targeting
|
||||||
|
pattern_desc="Block",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Determines the end line index based on end_pattern or line_count.
|
Determines the end line index based on end_pattern or line_count.
|
||||||
Raises ToolError if end_pattern is not found or line_count is invalid.
|
Raises ToolError if end_pattern is not found or line_count is invalid.
|
||||||
"""
|
"""
|
||||||
|
# Parameter validation: Ensure only one targeting method is used
|
||||||
|
targeting_methods = [
|
||||||
|
target_symbol is not None,
|
||||||
|
start_pattern_line_index is not None,
|
||||||
|
# Note: line_count and end_pattern depend on start_pattern_line_index
|
||||||
|
]
|
||||||
|
if sum(targeting_methods) > 1:
|
||||||
|
raise ToolError("Cannot specify target_symbol along with start_pattern.")
|
||||||
|
if sum(targeting_methods) == 0:
|
||||||
|
raise ToolError("Must specify either target_symbol or start_pattern.") # Or line numbers for line-based tools, handled elsewhere
|
||||||
|
|
||||||
|
if target_symbol:
|
||||||
|
if end_pattern or line_count:
|
||||||
|
raise ToolError("Cannot specify end_pattern or line_count when using target_symbol.")
|
||||||
|
try:
|
||||||
|
# Use repo_map to find the symbol's definition range
|
||||||
|
start_line, end_line = coder.repo_map.get_symbol_definition_location(file_path, target_symbol)
|
||||||
|
return start_line, end_line
|
||||||
|
except AttributeError: # Use specific exception
|
||||||
|
# Check if repo_map exists and is initialized before accessing methods
|
||||||
|
if not hasattr(coder, 'repo_map') or coder.repo_map is None:
|
||||||
|
raise ToolError("RepoMap is not available or not initialized.")
|
||||||
|
# If repo_map exists, the error might be from get_symbol_definition_location itself
|
||||||
|
# Re-raise ToolErrors directly
|
||||||
|
raise
|
||||||
|
except ToolError as e:
|
||||||
|
# Propagate specific ToolErrors from repo_map (not found, ambiguous, etc.)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
# Catch other unexpected errors during symbol lookup
|
||||||
|
raise ToolError(f"Unexpected error looking up symbol '{target_symbol}': {e}")
|
||||||
|
|
||||||
|
# --- Existing logic for pattern/line_count based targeting ---
|
||||||
|
# Ensure start_pattern_line_index is provided if not using target_symbol
|
||||||
|
if start_pattern_line_index is None:
|
||||||
|
raise ToolError("Internal error: start_pattern_line_index is required when not using target_symbol.")
|
||||||
|
|
||||||
|
# Assign start_line here for the pattern-based logic path
|
||||||
|
start_line = start_pattern_line_index # Start of existing logic
|
||||||
start_line = start_pattern_line_index
|
start_line = start_pattern_line_index
|
||||||
end_line = -1
|
end_line = -1
|
||||||
|
|
||||||
|
@ -189,4 +237,4 @@ def format_tool_result(coder, tool_name, success_message, change_id=None, diff_s
|
||||||
# except ToolError as e:
|
# except ToolError as e:
|
||||||
# return handle_tool_error(coder, "MyTool", e, add_traceback=False) # Don't need traceback for ToolErrors
|
# return handle_tool_error(coder, "MyTool", e, add_traceback=False) # Don't need traceback for ToolErrors
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# return handle_tool_error(coder, "MyTool", e)
|
# return handle_tool_error(coder, "MyTool", e)
|
Loading…
Add table
Add a link
Reference in a new issue