Capture TreeSitter ranges for tools to use

This commit is contained in:
Amar Sood (tekacs) 2025-04-12 10:21:14 -04:00
parent d6e58ce063
commit b51abd7fe7
4 changed files with 126 additions and 11 deletions

View file

@ -27,11 +27,11 @@ from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
# with cached entries that might have been created with the old definition
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind")):
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind start_line end_line start_byte end_byte")):
__slots__ = ()
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None):
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None, start_line=None, end_line=None, start_byte=None, end_byte=None):
# Provide a default value for specific_kind to handle old cached objects
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind)
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind, start_line, end_line, start_byte, end_byte)
Tag = TagBase
@ -41,7 +41,7 @@ SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
CACHE_VERSION = 5
if USING_TSL_PACK:
CACHE_VERSION = 6
CACHE_VERSION = 7
class RepoMap:
@ -247,6 +247,51 @@ class RepoMap:
self.io.tool_warning(f"File not found error: {fname}")
def get_tags(self, fname, rel_fname):
def get_symbol_definition_location(self, file_path, symbol_name):
"""
Finds the unique definition location (start/end line) for a symbol in a file.
Args:
file_path (str): The relative path to the file.
symbol_name (str): The name of the symbol to find.
Returns:
tuple: (start_line, end_line) (0-based) if a unique definition is found.
Raises:
ToolError: If the symbol is not found, not unique, or not a definition.
"""
abs_path = self.io.root_abs_path(file_path) # Assuming io has this helper or similar
rel_path = self.get_rel_fname(abs_path) # Ensure we use consistent relative path
tags = self.get_tags(abs_path, rel_path)
if not tags:
raise ToolError(f"Symbol '{symbol_name}' not found in '{file_path}' (no tags).")
definitions = []
for tag in tags:
# Check if it's a definition and the name matches
if tag.kind == "def" and tag.name == symbol_name:
# Ensure we have valid location info
if tag.start_line is not None and tag.end_line is not None and tag.start_line >= 0:
definitions.append(tag)
if not definitions:
# Check if it exists as a non-definition tag
non_defs = [tag for tag in tags if tag.name == symbol_name and tag.kind != "def"]
if non_defs:
raise ToolError(f"Symbol '{symbol_name}' found in '{file_path}', but not as a unique definition (found as {non_defs[0].kind}).")
else:
raise ToolError(f"Symbol '{symbol_name}' definition not found in '{file_path}'.")
if len(definitions) > 1:
# Provide more context about ambiguity if possible
lines = sorted([d.start_line + 1 for d in definitions]) # 1-based for user message
raise ToolError(f"Symbol '{symbol_name}' is ambiguous in '{file_path}'. Found definitions on lines: {', '.join(map(str, lines))}.")
# Unique definition found
definition_tag = definitions[0]
return definition_tag.start_line, definition_tag.end_line
# Check if the file is in the cache and if the modification time has not changed
file_mtime = self.get_mtime(fname)
if file_mtime is None:
@ -345,7 +390,11 @@ class RepoMap:
name=node.text.decode("utf-8"),
kind=kind,
specific_kind=specific_kind,
line=node.start_point[0],
line=node.start_point[0], # Legacy line number
start_line=node.start_point[0],
end_line=node.end_point[0],
start_byte=node.start_byte,
end_byte=node.end_byte,
)
yield result
@ -375,7 +424,11 @@ class RepoMap:
name=token,
kind="ref",
specific_kind="name", # Default for pygments fallback
line=-1,
line=-1, # Pygments doesn't give precise locations easily
start_line=-1,
end_line=-1,
start_byte=-1,
end_byte=-1,
)
def get_ranked_tags(

View file

@ -30,7 +30,14 @@ def _execute_delete_block(coder, file_path, start_pattern, end_pattern=None, lin
# 3. Determine the end line, passing pattern_desc for better error messages
start_line, end_line = determine_line_range(
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
coder=coder,
file_path=rel_path,
lines=lines,
start_pattern_line_index=start_line_idx,
end_pattern=end_pattern,
line_count=line_count,
target_symbol=None, # DeleteBlock uses patterns, not symbols
pattern_desc=pattern_desc
)
# 4. Prepare the deletion

View file

@ -44,7 +44,14 @@ def _execute_indent_lines(coder, file_path, start_pattern, end_pattern=None, lin
# 3. Determine the end line
start_line, end_line = determine_line_range(
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
coder=coder,
file_path=rel_path,
lines=lines,
start_pattern_line_index=start_line_idx,
end_pattern=end_pattern,
line_count=line_count,
target_symbol=None, # IndentLines uses patterns, not symbols
pattern_desc=pattern_desc
)
# 4. Validate and prepare indentation

View file

@ -85,11 +85,59 @@ def select_occurrence_index(indices, occurrence, pattern_desc="Pattern"):
return indices[target_idx]
def determine_line_range(lines, start_pattern_line_index, end_pattern=None, line_count=None, pattern_desc="Block"):
def determine_line_range(
coder, # Added: Need coder to access repo_map
file_path, # Added: Need file_path for repo_map lookup
lines,
start_pattern_line_index=None, # Made optional
end_pattern=None,
line_count=None,
target_symbol=None, # Added: New parameter for symbol targeting
pattern_desc="Block",
):
"""
Determines the end line index based on end_pattern or line_count.
Raises ToolError if end_pattern is not found or line_count is invalid.
"""
# Parameter validation: Ensure only one targeting method is used
targeting_methods = [
target_symbol is not None,
start_pattern_line_index is not None,
# Note: line_count and end_pattern depend on start_pattern_line_index
]
if sum(targeting_methods) > 1:
raise ToolError("Cannot specify target_symbol along with start_pattern.")
if sum(targeting_methods) == 0:
raise ToolError("Must specify either target_symbol or start_pattern.") # Or line numbers for line-based tools, handled elsewhere
if target_symbol:
if end_pattern or line_count:
raise ToolError("Cannot specify end_pattern or line_count when using target_symbol.")
try:
# Use repo_map to find the symbol's definition range
start_line, end_line = coder.repo_map.get_symbol_definition_location(file_path, target_symbol)
return start_line, end_line
except AttributeError: # Use specific exception
# Check if repo_map exists and is initialized before accessing methods
if not hasattr(coder, 'repo_map') or coder.repo_map is None:
raise ToolError("RepoMap is not available or not initialized.")
# If repo_map exists, the error might be from get_symbol_definition_location itself
# Re-raise ToolErrors directly
raise
except ToolError as e:
# Propagate specific ToolErrors from repo_map (not found, ambiguous, etc.)
raise e
except Exception as e:
# Catch other unexpected errors during symbol lookup
raise ToolError(f"Unexpected error looking up symbol '{target_symbol}': {e}")
# --- Existing logic for pattern/line_count based targeting ---
# Ensure start_pattern_line_index is provided if not using target_symbol
if start_pattern_line_index is None:
raise ToolError("Internal error: start_pattern_line_index is required when not using target_symbol.")
# Assign start_line here for the pattern-based logic path
start_line = start_pattern_line_index # Start of existing logic
start_line = start_pattern_line_index
end_line = -1