Include a tree-sitter based outline of open files for the LLM

This commit is contained in:
Amar Sood (tekacs) 2025-04-12 05:48:25 -04:00
parent 17f06c86b2
commit 1e01482be6
2 changed files with 147 additions and 20 deletions

View file

@ -7,11 +7,22 @@ import random
import subprocess
import traceback
import platform
import ast
import re
import fnmatch
import os
import time
import random
import subprocess
import traceback
import platform
import locale
from datetime import datetime
from pathlib import Path
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ParseError
# Add necessary imports if not already present
from collections import defaultdict
from .base_coder import Coder
from .editblock_coder import find_original_update_blocks, do_replace, find_similar_lines
@ -61,28 +72,104 @@ class NavigatorCoder(Coder):
# Enable enhanced context blocks by default
self.use_enhanced_context = True
def get_context_symbol_outline(self):
    """
    Build the ``symbol_outline`` context block for the files currently in chat.

    Tags are read with ``repo_map.get_tags_raw`` for every editable and
    read-only file, bypassing the tags cache so the outline is fresh. A tag is
    listed only when its ``specific_kind`` (or ``kind`` when that is empty),
    lowercased, appears in ``repo_map.definition_kinds``. Entries are grouped
    per file and ordered by line, then name.

    Returns:
        The block as a string wrapped in ``<context name="symbol_outline">``,
        or None when enhanced context is disabled, no repo map is available,
        or an unexpected error occurs (reported through ``io.tool_error``).
    """
    if not self.use_enhanced_context or not self.repo_map:
        return None

    try:
        parts = [
            '<context name="symbol_outline">\n',
            "## Symbol Outline (Current Context)\n\n",
            "Code definitions (classes, functions, methods, etc.) found in files"
            " currently in chat context.\n\n",
        ]

        files_to_outline = list(self.abs_fnames) + list(self.abs_read_only_fnames)
        if not files_to_outline:
            parts.append("No files currently in context.\n</context>")
            return "".join(parts)

        definition_kinds = self.repo_map.definition_kinds
        # rel_fname -> [(tag, kind shown to the reader), ...]
        definitions_by_file = {}

        for abs_fname in sorted(files_to_outline):
            rel_fname = self.get_rel_fname(abs_fname)
            try:
                # Call get_tags_raw directly to bypass the cache and ensure freshness.
                tags = list(self.repo_map.get_tags_raw(abs_fname, rel_fname))
            except Exception as e:
                # Best effort: one unparsable file must not hide the others.
                self.io.tool_warning(f"Could not get symbols for {rel_fname}: {e}")
                continue

            for tag in tags:
                # Prefer specific_kind, fall back to kind.
                # NOTE(review): a reference whose specific_kind happens to be in
                # definition_kinds would also pass this filter - confirm whether
                # tag.kind should additionally be required to mark a definition.
                kind = tag.specific_kind or tag.kind
                if kind and kind.lower() in definition_kinds:
                    definitions_by_file.setdefault(rel_fname, []).append((tag, kind))

        if not definitions_by_file:
            # Covers both "no tags at all" and "tags, but none are definitions",
            # so the block never ends up silently empty.
            parts.append("No symbols found in the current context files.\n")
        else:
            for rel_fname in sorted(definitions_by_file):
                parts.append(f"### {rel_fname}\n")
                entries = sorted(
                    definitions_by_file[rel_fname],
                    key=lambda entry: (entry[0].line, entry[0].name),
                )
                # Simple flat list for now; could be enhanced later (e.g. scope indentation).
                for tag, kind in entries:
                    # Tags without a known position carry a negative line; omit it.
                    line_info = f", line {tag.line + 1}" if tag.line >= 0 else ""
                    parts.append(f"- {tag.name} ({kind}{line_info})\n")
                parts.append("\n")  # Blank line between files

        parts.append("</context>")
        return "".join(parts).strip()
    except Exception as e:
        self.io.tool_error(f"Error generating symbol outline: {str(e)}")
        return None
def format_chat_chunks(self):
"""
Override parent's format_chat_chunks to include enhanced context blocks with a
Override parent's format_chat_chunks to include enhanced context blocks with a
cleaner, more hierarchical structure for better organization.
"""
# First get the normal chat chunks from the parent method
chunks = super().format_chat_chunks()
chunks = super().format_chat_chunks() # Calls BaseCoder's format_chat_chunks
# If enhanced context blocks are enabled, prepend them to the current messages
if self.use_enhanced_context:
# Create environment info context block
env_context = self.get_environment_info()
# Get directory structure
dir_structure = self.get_directory_structure()
# Get git status
git_status = self.get_git_status()
# Get current context summary
context_summary = self.get_context_summary()
# Get directory structure
dir_structure = self.get_directory_structure()
# Get git status
git_status = self.get_git_status()
# Get symbol outline for current context files
symbol_outline = self.get_context_symbol_outline() # New call
# Collect all context blocks that exist
context_blocks = []
if env_context:
@ -93,21 +180,24 @@ class NavigatorCoder(Coder):
context_blocks.append(dir_structure)
if git_status:
context_blocks.append(git_status)
# If we have any context blocks, prepend them to the current messages
if symbol_outline: # Add the new block if it was generated
context_blocks.append(symbol_outline)
# If we have any context blocks, prepend them to the system message
if context_blocks:
context_message = "\n\n".join(context_blocks)
# Prepend to system context but don't overwrite existing system content
if chunks.system:
# If we already have system messages, append our context to the first one
original_content = chunks.system[0]["content"]
# Ensure there's separation between our blocks and the original prompt
chunks.system[0]["content"] = context_message + "\n\n" + original_content
else:
# Otherwise, create a new system message
chunks.system = [dict(role="system", content=context_message)]
return chunks
def get_context_summary(self):
"""
Generate a summary of the current file context, including editable and read-only files,
@ -3317,4 +3407,4 @@ Just reply with fixed versions of the {blocks} above that failed to match.
return "\n".join(diff_lines_output)
except Exception as e:
return f"[Diff generation error: {e}]"
return f"[Diff generation error: {e}]"

View file

@ -25,15 +25,23 @@ from aider.utils import Spinner
warnings.simplefilter("ignore", category=FutureWarning)
from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
Tag = namedtuple("Tag", "rel_fname fname line name kind".split())
# A tag record with six fields. The sixth, specific_kind, is optional so that
# five-argument construction (the old Tag signature) keeps working; this is
# intended to stay compatible with entries created under the old definition.
class TagBase(
    namedtuple(
        "TagBase",
        ("rel_fname", "fname", "line", "name", "kind", "specific_kind"),
    )
):
    """Immutable tag record; ``specific_kind`` defaults to None when omitted."""

    # Keep instances as plain tuples without a per-instance __dict__.
    __slots__ = ()

    def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None):
        # Forward all six values to the namedtuple constructor, filling in the
        # default for specific_kind when the caller supplied only five.
        values = (rel_fname, fname, line, name, kind, specific_kind)
        return super().__new__(cls, *values)


# Public name used throughout the module.
Tag = TagBase
SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
CACHE_VERSION = 3
CACHE_VERSION = 5
if USING_TSL_PACK:
CACHE_VERSION = 4
CACHE_VERSION = 6
class RepoMap:
@ -41,6 +49,17 @@ class RepoMap:
warned_files = set()
# Kind names that are treated as definitions, independent of language.
# NavigatorCoder.get_context_symbol_outline lowercases each tag's specific_kind
# (falling back to kind) and lists the tag only if the result is in this set.
# specific_kind is the last dot-separated component of the tree-sitter capture
# name, e.g. 'function' from 'name.definition.function'.
definition_kinds = {
"class", "struct", "enum", "interface", "trait", # Structure definitions
"function", "method", "constructor", # Function/method definitions
"module", "namespace", # Module/namespace definitions
"constant", "variable", # Top-level/class variable definitions (consider refining)
"type", # Type definitions
# Add more kinds here if the tree-sitter tag queries emit them
}
def __init__(
self,
map_tokens=1024,
@ -242,10 +261,23 @@ class RepoMap:
if val is not None and val.get("mtime") == file_mtime:
try:
return self.TAGS_CACHE[cache_key]["data"]
# Get the cached data
data = self.TAGS_CACHE[cache_key]["data"]
# Let our Tag class handle compatibility with old cache formats
# No need for special handling as TagBase.__new__ will supply default specific_kind
return data
except SQLITE_ERRORS as e:
self.tags_cache_error(e)
return self.TAGS_CACHE[cache_key]["data"]
except (TypeError, AttributeError) as e:
# If we hit an error related to missing fields in old cached Tag objects,
# force a cache refresh for this file
if self.verbose:
self.io.tool_warning(f"Cache format error for {fname}, refreshing: {e}")
# Return empty list to trigger cache refresh
return []
# miss!
data = list(self.get_tags_raw(fname, rel_fname))
@ -304,11 +336,15 @@ class RepoMap:
saw.add(kind)
# Extract specific kind from the tag, e.g., 'function' from 'name.definition.function'
specific_kind = tag.split('.')[-1] if '.' in tag else None
result = Tag(
rel_fname=rel_fname,
fname=fname,
name=node.text.decode("utf-8"),
kind=kind,
specific_kind=specific_kind,
line=node.start_point[0],
)
@ -338,6 +374,7 @@ class RepoMap:
fname=fname,
name=token,
kind="ref",
specific_kind="name", # Default for pygments fallback
line=-1,
)