mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-08 13:44:59 +00:00
Capture TreeSitter ranges for tools to use
This commit is contained in:
parent
d6e58ce063
commit
b51abd7fe7
4 changed files with 126 additions and 11 deletions
|
@ -27,11 +27,11 @@ from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
|
|||
|
||||
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
|
||||
# with cached entries that might have been created with the old definition
|
||||
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind")):
|
||||
class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind start_line end_line start_byte end_byte")):
|
||||
__slots__ = ()
|
||||
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None):
|
||||
def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None, start_line=None, end_line=None, start_byte=None, end_byte=None):
|
||||
# Provide a default value for specific_kind to handle old cached objects
|
||||
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind)
|
||||
return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind, start_line, end_line, start_byte, end_byte)
|
||||
|
||||
Tag = TagBase
|
||||
|
||||
|
@ -41,7 +41,7 @@ SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
|
|||
|
||||
CACHE_VERSION = 5
|
||||
if USING_TSL_PACK:
|
||||
CACHE_VERSION = 6
|
||||
CACHE_VERSION = 7
|
||||
|
||||
|
||||
class RepoMap:
|
||||
|
@ -247,6 +247,51 @@ class RepoMap:
|
|||
self.io.tool_warning(f"File not found error: {fname}")
|
||||
|
||||
def get_tags(self, fname, rel_fname):
|
||||
def get_symbol_definition_location(self, file_path, symbol_name):
|
||||
"""
|
||||
Finds the unique definition location (start/end line) for a symbol in a file.
|
||||
|
||||
Args:
|
||||
file_path (str): The relative path to the file.
|
||||
symbol_name (str): The name of the symbol to find.
|
||||
|
||||
Returns:
|
||||
tuple: (start_line, end_line) (0-based) if a unique definition is found.
|
||||
|
||||
Raises:
|
||||
ToolError: If the symbol is not found, not unique, or not a definition.
|
||||
"""
|
||||
abs_path = self.io.root_abs_path(file_path) # Assuming io has this helper or similar
|
||||
rel_path = self.get_rel_fname(abs_path) # Ensure we use consistent relative path
|
||||
|
||||
tags = self.get_tags(abs_path, rel_path)
|
||||
if not tags:
|
||||
raise ToolError(f"Symbol '{symbol_name}' not found in '{file_path}' (no tags).")
|
||||
|
||||
definitions = []
|
||||
for tag in tags:
|
||||
# Check if it's a definition and the name matches
|
||||
if tag.kind == "def" and tag.name == symbol_name:
|
||||
# Ensure we have valid location info
|
||||
if tag.start_line is not None and tag.end_line is not None and tag.start_line >= 0:
|
||||
definitions.append(tag)
|
||||
|
||||
if not definitions:
|
||||
# Check if it exists as a non-definition tag
|
||||
non_defs = [tag for tag in tags if tag.name == symbol_name and tag.kind != "def"]
|
||||
if non_defs:
|
||||
raise ToolError(f"Symbol '{symbol_name}' found in '{file_path}', but not as a unique definition (found as {non_defs[0].kind}).")
|
||||
else:
|
||||
raise ToolError(f"Symbol '{symbol_name}' definition not found in '{file_path}'.")
|
||||
|
||||
if len(definitions) > 1:
|
||||
# Provide more context about ambiguity if possible
|
||||
lines = sorted([d.start_line + 1 for d in definitions]) # 1-based for user message
|
||||
raise ToolError(f"Symbol '{symbol_name}' is ambiguous in '{file_path}'. Found definitions on lines: {', '.join(map(str, lines))}.")
|
||||
|
||||
# Unique definition found
|
||||
definition_tag = definitions[0]
|
||||
return definition_tag.start_line, definition_tag.end_line
|
||||
# Check if the file is in the cache and if the modification time has not changed
|
||||
file_mtime = self.get_mtime(fname)
|
||||
if file_mtime is None:
|
||||
|
@ -345,7 +390,11 @@ class RepoMap:
|
|||
name=node.text.decode("utf-8"),
|
||||
kind=kind,
|
||||
specific_kind=specific_kind,
|
||||
line=node.start_point[0],
|
||||
line=node.start_point[0], # Legacy line number
|
||||
start_line=node.start_point[0],
|
||||
end_line=node.end_point[0],
|
||||
start_byte=node.start_byte,
|
||||
end_byte=node.end_byte,
|
||||
)
|
||||
|
||||
yield result
|
||||
|
@ -375,7 +424,11 @@ class RepoMap:
|
|||
name=token,
|
||||
kind="ref",
|
||||
specific_kind="name", # Default for pygments fallback
|
||||
line=-1,
|
||||
line=-1, # Pygments doesn't give precise locations easily
|
||||
start_line=-1,
|
||||
end_line=-1,
|
||||
start_byte=-1,
|
||||
end_byte=-1,
|
||||
)
|
||||
|
||||
def get_ranked_tags(
|
||||
|
|
|
@ -30,7 +30,14 @@ def _execute_delete_block(coder, file_path, start_pattern, end_pattern=None, lin
|
|||
|
||||
# 3. Determine the end line, passing pattern_desc for better error messages
|
||||
start_line, end_line = determine_line_range(
|
||||
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
|
||||
coder=coder,
|
||||
file_path=rel_path,
|
||||
lines=lines,
|
||||
start_pattern_line_index=start_line_idx,
|
||||
end_pattern=end_pattern,
|
||||
line_count=line_count,
|
||||
target_symbol=None, # DeleteBlock uses patterns, not symbols
|
||||
pattern_desc=pattern_desc
|
||||
)
|
||||
|
||||
# 4. Prepare the deletion
|
||||
|
|
|
@ -44,7 +44,14 @@ def _execute_indent_lines(coder, file_path, start_pattern, end_pattern=None, lin
|
|||
|
||||
# 3. Determine the end line
|
||||
start_line, end_line = determine_line_range(
|
||||
lines, start_line_idx, end_pattern, line_count, pattern_desc=pattern_desc
|
||||
coder=coder,
|
||||
file_path=rel_path,
|
||||
lines=lines,
|
||||
start_pattern_line_index=start_line_idx,
|
||||
end_pattern=end_pattern,
|
||||
line_count=line_count,
|
||||
target_symbol=None, # IndentLines uses patterns, not symbols
|
||||
pattern_desc=pattern_desc
|
||||
)
|
||||
|
||||
# 4. Validate and prepare indentation
|
||||
|
|
|
@ -85,11 +85,59 @@ def select_occurrence_index(indices, occurrence, pattern_desc="Pattern"):
|
|||
|
||||
return indices[target_idx]
|
||||
|
||||
def determine_line_range(lines, start_pattern_line_index, end_pattern=None, line_count=None, pattern_desc="Block"):
|
||||
def determine_line_range(
|
||||
coder, # Added: Need coder to access repo_map
|
||||
file_path, # Added: Need file_path for repo_map lookup
|
||||
lines,
|
||||
start_pattern_line_index=None, # Made optional
|
||||
end_pattern=None,
|
||||
line_count=None,
|
||||
target_symbol=None, # Added: New parameter for symbol targeting
|
||||
pattern_desc="Block",
|
||||
):
|
||||
"""
|
||||
Determines the end line index based on end_pattern or line_count.
|
||||
Raises ToolError if end_pattern is not found or line_count is invalid.
|
||||
"""
|
||||
# Parameter validation: Ensure only one targeting method is used
|
||||
targeting_methods = [
|
||||
target_symbol is not None,
|
||||
start_pattern_line_index is not None,
|
||||
# Note: line_count and end_pattern depend on start_pattern_line_index
|
||||
]
|
||||
if sum(targeting_methods) > 1:
|
||||
raise ToolError("Cannot specify target_symbol along with start_pattern.")
|
||||
if sum(targeting_methods) == 0:
|
||||
raise ToolError("Must specify either target_symbol or start_pattern.") # Or line numbers for line-based tools, handled elsewhere
|
||||
|
||||
if target_symbol:
|
||||
if end_pattern or line_count:
|
||||
raise ToolError("Cannot specify end_pattern or line_count when using target_symbol.")
|
||||
try:
|
||||
# Use repo_map to find the symbol's definition range
|
||||
start_line, end_line = coder.repo_map.get_symbol_definition_location(file_path, target_symbol)
|
||||
return start_line, end_line
|
||||
except AttributeError: # Use specific exception
|
||||
# Check if repo_map exists and is initialized before accessing methods
|
||||
if not hasattr(coder, 'repo_map') or coder.repo_map is None:
|
||||
raise ToolError("RepoMap is not available or not initialized.")
|
||||
# If repo_map exists, the error might be from get_symbol_definition_location itself
|
||||
# Re-raise ToolErrors directly
|
||||
raise
|
||||
except ToolError as e:
|
||||
# Propagate specific ToolErrors from repo_map (not found, ambiguous, etc.)
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors during symbol lookup
|
||||
raise ToolError(f"Unexpected error looking up symbol '{target_symbol}': {e}")
|
||||
|
||||
# --- Existing logic for pattern/line_count based targeting ---
|
||||
# Ensure start_pattern_line_index is provided if not using target_symbol
|
||||
if start_pattern_line_index is None:
|
||||
raise ToolError("Internal error: start_pattern_line_index is required when not using target_symbol.")
|
||||
|
||||
# Assign start_line here for the pattern-based logic path
|
||||
start_line = start_pattern_line_index # Start of existing logic
|
||||
start_line = start_pattern_line_index
|
||||
end_line = -1
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue