mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-08 13:44:59 +00:00
Include a tree-sitter based outline of open files for the LLM
This commit is contained in:
parent
17f06c86b2
commit
1e01482be6
2 changed files with 147 additions and 20 deletions
|
@ -7,11 +7,22 @@ import random
|
|||
import subprocess
|
||||
import traceback
|
||||
import platform
|
||||
import ast
|
||||
import re
|
||||
import fnmatch
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import subprocess
|
||||
import traceback
|
||||
import platform
|
||||
import locale
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.etree.ElementTree import ParseError
|
||||
# Add necessary imports if not already present
|
||||
from collections import defaultdict
|
||||
|
||||
from .base_coder import Coder
|
||||
from .editblock_coder import find_original_update_blocks, do_replace, find_similar_lines
|
||||
|
@ -61,28 +72,104 @@ class NavigatorCoder(Coder):
|
|||
# Enable enhanced context blocks by default
|
||||
self.use_enhanced_context = True
|
||||
|
||||
def get_context_symbol_outline(self):
    """
    Generate a symbol outline for the files currently in chat context.

    Tags are read through ``self.repo_map.get_tags_raw`` (rather than a cached
    lookup) so the outline reflects the current file contents. Only tags whose
    kind (``specific_kind`` when set, otherwise ``kind``) is listed in
    ``self.repo_map.definition_kinds`` are shown.

    Returns:
        A string wrapped in ``<context name="symbol_outline">`` ... ``</context>``,
        or ``None`` when enhanced context is disabled, no repo map is available,
        or an unexpected error occurs while building the outline.
    """
    if not self.use_enhanced_context or not self.repo_map:
        return None

    try:
        parts = [
            '<context name="symbol_outline">\n',
            "## Symbol Outline (Current Context)\n\n",
            "Code definitions (classes, functions, methods, etc.) found in files currently in"
            " chat context.\n\n",
        ]

        # Union the two collections so a file that is present in both is only
        # parsed and listed once.
        files_to_outline = sorted(set(self.abs_fnames) | set(self.abs_read_only_fnames))
        if not files_to_outline:
            parts.append("No files currently in context.\n")
            parts.append("</context>")
            return "".join(parts)

        tags_by_file = {}
        for abs_fname in files_to_outline:
            rel_fname = self.get_rel_fname(abs_fname)
            try:
                # Call get_tags_raw directly to bypass cache and ensure freshness
                tags = list(self.repo_map.get_tags_raw(abs_fname, rel_fname))
            except Exception as e:
                # Best effort: one unparsable file must not hide the others.
                self.io.tool_warning(f"Could not get symbols for {rel_fname}: {e}")
                continue
            if tags:
                tags_by_file.setdefault(rel_fname, []).extend(tags)

        definition_kinds = self.repo_map.definition_kinds
        sections = []
        for rel_fname in sorted(tags_by_file):
            entries = []
            for tag in sorted(tags_by_file[rel_fname], key=lambda t: (t.line, t.name)):
                # Prefer the specific kind, falling back to the generic kind.
                kind = tag.specific_kind or tag.kind
                if not kind or kind.lower() not in definition_kinds:
                    continue
                # Tag lines are 0-based; a negative line means "unknown".
                line_info = f", line {tag.line + 1}" if tag.line >= 0 else ""
                entries.append(f"- {tag.name} ({kind}{line_info})\n")
            if entries:
                # Trailing blank line separates files.
                sections.append(f"### {rel_fname}\n" + "".join(entries) + "\n")

        if sections:
            parts.extend(sections)
        else:
            # Covers both "no tags at all" and "tags, but none is a definition",
            # so the block is never emitted without a body.
            parts.append("No symbols found in the current context files.\n")

        parts.append("</context>")
        return "".join(parts).strip()

    except Exception as e:
        self.io.tool_error(f"Error generating symbol outline: {str(e)}")
        return None
|
||||
|
||||
def format_chat_chunks(self):
|
||||
"""
|
||||
Override parent's format_chat_chunks to include enhanced context blocks with a
|
||||
Override parent's format_chat_chunks to include enhanced context blocks with a
|
||||
cleaner, more hierarchical structure for better organization.
|
||||
"""
|
||||
# First get the normal chat chunks from the parent method
|
||||
chunks = super().format_chat_chunks()
|
||||
|
||||
chunks = super().format_chat_chunks() # Calls BaseCoder's format_chat_chunks
|
||||
|
||||
# If enhanced context blocks are enabled, prepend them to the current messages
|
||||
if self.use_enhanced_context:
|
||||
# Create environment info context block
|
||||
env_context = self.get_environment_info()
|
||||
|
||||
# Get directory structure
|
||||
dir_structure = self.get_directory_structure()
|
||||
|
||||
# Get git status
|
||||
git_status = self.get_git_status()
|
||||
|
||||
|
||||
# Get current context summary
|
||||
context_summary = self.get_context_summary()
|
||||
|
||||
|
||||
# Get directory structure
|
||||
dir_structure = self.get_directory_structure()
|
||||
|
||||
# Get git status
|
||||
git_status = self.get_git_status()
|
||||
|
||||
# Get symbol outline for current context files
|
||||
symbol_outline = self.get_context_symbol_outline() # New call
|
||||
|
||||
# Collect all context blocks that exist
|
||||
context_blocks = []
|
||||
if env_context:
|
||||
|
@ -93,21 +180,24 @@ class NavigatorCoder(Coder):
|
|||
context_blocks.append(dir_structure)
|
||||
if git_status:
|
||||
context_blocks.append(git_status)
|
||||
|
||||
# If we have any context blocks, prepend them to the current messages
|
||||
if symbol_outline: # Add the new block if it was generated
|
||||
context_blocks.append(symbol_outline)
|
||||
|
||||
# If we have any context blocks, prepend them to the system message
|
||||
if context_blocks:
|
||||
context_message = "\n\n".join(context_blocks)
|
||||
# Prepend to system context but don't overwrite existing system content
|
||||
if chunks.system:
|
||||
# If we already have system messages, append our context to the first one
|
||||
original_content = chunks.system[0]["content"]
|
||||
# Ensure there's separation between our blocks and the original prompt
|
||||
chunks.system[0]["content"] = context_message + "\n\n" + original_content
|
||||
else:
|
||||
# Otherwise, create a new system message
|
||||
chunks.system = [dict(role="system", content=context_message)]
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def get_context_summary(self):
|
||||
"""
|
||||
Generate a summary of the current file context, including editable and read-only files,
|
||||
|
@ -3317,4 +3407,4 @@ Just reply with fixed versions of the {blocks} above that failed to match.
|
|||
|
||||
return "\n".join(diff_lines_output)
|
||||
except Exception as e:
|
||||
return f"[Diff generation error: {e}]"
|
||||
return f"[Diff generation error: {e}]"
|
||||
|
|
|
@ -25,15 +25,23 @@ from aider.utils import Spinner
|
|||
warnings.simplefilter("ignore", category=FutureWarning)
|
||||
from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
|
||||
|
||||
Tag = namedtuple("Tag", "rel_fname fname line name kind".split())
|
||||
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
|
||||
# with cached entries that might have been created with the old definition
|
||||
class TagBase(
    namedtuple(
        "TagBase",
        "rel_fname fname line name kind specific_kind",
        defaults=(None,),
    )
):
    """
    A single symbol tag extracted from a source file.

    Fields: ``rel_fname``, ``fname``, ``line``, ``name``, ``kind`` and
    ``specific_kind``. ``specific_kind`` defaults to ``None`` so that
    construction with only the first five fields (the shape used before
    ``specific_kind`` was added) keeps working.
    """

    # No per-instance __dict__; instances stay plain tuples.
    __slots__ = ()


Tag = TagBase
|
||||
|
||||
|
||||
SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
|
||||
|
||||
|
||||
CACHE_VERSION = 3
|
||||
CACHE_VERSION = 5
|
||||
if USING_TSL_PACK:
|
||||
CACHE_VERSION = 4
|
||||
CACHE_VERSION = 6
|
||||
|
||||
|
||||
class RepoMap:
|
||||
|
@ -41,6 +49,17 @@ class RepoMap:
|
|||
|
||||
warned_files = set()
|
||||
|
||||
# Define kinds that typically represent definitions across languages
|
||||
# Used by NavigatorCoder to filter tags for the symbol outline
|
||||
definition_kinds = {
|
||||
"class", "struct", "enum", "interface", "trait", # Structure definitions
|
||||
"function", "method", "constructor", # Function/method definitions
|
||||
"module", "namespace", # Module/namespace definitions
|
||||
"constant", "variable", # Top-level/class variable definitions (consider refining)
|
||||
"type", # Type definitions
|
||||
# Add more based on tree-sitter queries if needed
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
map_tokens=1024,
|
||||
|
@ -242,10 +261,23 @@ class RepoMap:
|
|||
|
||||
if val is not None and val.get("mtime") == file_mtime:
|
||||
try:
|
||||
return self.TAGS_CACHE[cache_key]["data"]
|
||||
# Get the cached data
|
||||
data = self.TAGS_CACHE[cache_key]["data"]
|
||||
|
||||
# Let our Tag class handle compatibility with old cache formats
|
||||
# No need for special handling as TagBase.__new__ will supply default specific_kind
|
||||
|
||||
return data
|
||||
except SQLITE_ERRORS as e:
|
||||
self.tags_cache_error(e)
|
||||
return self.TAGS_CACHE[cache_key]["data"]
|
||||
except (TypeError, AttributeError) as e:
|
||||
# If we hit an error related to missing fields in old cached Tag objects,
|
||||
# force a cache refresh for this file
|
||||
if self.verbose:
|
||||
self.io.tool_warning(f"Cache format error for {fname}, refreshing: {e}")
|
||||
# Return empty list to trigger cache refresh
|
||||
return []
|
||||
|
||||
# miss!
|
||||
data = list(self.get_tags_raw(fname, rel_fname))
|
||||
|
@ -304,11 +336,15 @@ class RepoMap:
|
|||
|
||||
saw.add(kind)
|
||||
|
||||
# Extract specific kind from the tag, e.g., 'function' from 'name.definition.function'
|
||||
specific_kind = tag.split('.')[-1] if '.' in tag else None
|
||||
|
||||
result = Tag(
|
||||
rel_fname=rel_fname,
|
||||
fname=fname,
|
||||
name=node.text.decode("utf-8"),
|
||||
kind=kind,
|
||||
specific_kind=specific_kind,
|
||||
line=node.start_point[0],
|
||||
)
|
||||
|
||||
|
@ -338,6 +374,7 @@ class RepoMap:
|
|||
fname=fname,
|
||||
name=token,
|
||||
kind="ref",
|
||||
specific_kind="name", # Default for pygments fallback
|
||||
line=-1,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue