mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-08 21:55:00 +00:00
Include a tree-sitter based outline of open files for the LLM
This commit is contained in:
parent
17f06c86b2
commit
1e01482be6
2 changed files with 147 additions and 20 deletions
|
@ -7,11 +7,22 @@ import random
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
import platform
|
import platform
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
import fnmatch
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
import traceback
|
||||||
|
import platform
|
||||||
import locale
|
import locale
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from xml.etree.ElementTree import ParseError
|
from xml.etree.ElementTree import ParseError
|
||||||
|
# Add necessary imports if not already present
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from .base_coder import Coder
|
from .base_coder import Coder
|
||||||
from .editblock_coder import find_original_update_blocks, do_replace, find_similar_lines
|
from .editblock_coder import find_original_update_blocks, do_replace, find_similar_lines
|
||||||
|
@ -61,28 +72,104 @@ class NavigatorCoder(Coder):
|
||||||
# Enable enhanced context blocks by default
|
# Enable enhanced context blocks by default
|
||||||
self.use_enhanced_context = True
|
self.use_enhanced_context = True
|
||||||
|
|
||||||
|
def get_context_symbol_outline(self):
    """
    Generate a symbol outline for files currently in context using Tree-sitter,
    bypassing the cache for freshness.

    Only definition tags are listed: a tag is kept when it is not a reference
    (``kind != "ref"``) and its specific kind (falling back to ``kind``) is in
    ``self.repo_map.definition_kinds``.

    Returns:
        The ``<context name="symbol_outline">`` block as a string, or ``None``
        when enhanced context is disabled, no repo map is available, or an
        unexpected error occurred (reported through ``self.io.tool_error``).
    """
    if not self.use_enhanced_context or not self.repo_map:
        return None

    try:
        lines = [
            '<context name="symbol_outline">',
            "## Symbol Outline (Current Context)",
            "",
            (
                "Code definitions (classes, functions, methods, etc.) found in files"
                " currently in chat context."
            ),
            "",
        ]

        # A set union keeps a file from being outlined twice if it is listed in both.
        files_to_outline = sorted(set(self.abs_fnames) | set(self.abs_read_only_fnames))
        if not files_to_outline:
            lines += ["No files currently in context.", "</context>"]
            return "\n".join(lines)

        definition_kinds = self.repo_map.definition_kinds
        definitions_by_file = {}

        for abs_fname in files_to_outline:
            rel_fname = self.get_rel_fname(abs_fname)
            try:
                # Call get_tags_raw directly to bypass cache and ensure freshness
                tags = list(self.repo_map.get_tags_raw(abs_fname, rel_fname))
            except Exception as e:
                # Best effort: one unparsable file must not hide the others.
                self.io.tool_warning(f"Could not get symbols for {rel_fname}: {e}")
                continue

            for tag in tags:
                # specific_kind is only the last component of the capture name, so
                # a reference such as "name.reference.class" would otherwise look
                # like a class definition. Skip references explicitly.
                if tag.kind == "ref":
                    continue
                shown_kind = tag.specific_kind or tag.kind
                if shown_kind and shown_kind.lower() in definition_kinds:
                    definitions_by_file.setdefault(rel_fname, []).append((tag, shown_kind))

        if not definitions_by_file:
            lines.append("No symbols found in the current context files.")
        else:
            for rel_fname in sorted(definitions_by_file):
                lines.append(f"### {rel_fname}")
                entries = sorted(
                    definitions_by_file[rel_fname],
                    key=lambda entry: (entry[0].line, entry[0].name),
                )
                # Simple flat list for now; scope indentation could be added later.
                for tag, shown_kind in entries:
                    # Lines are stored 0-based; a negative line means "unknown".
                    line_info = f", line {tag.line + 1}" if tag.line >= 0 else ""
                    lines.append(f"- {tag.name} ({shown_kind}{line_info})")
                lines.append("")  # Blank line between files

        lines.append("</context>")
        return "\n".join(lines).strip()

    except Exception as e:
        self.io.tool_error(f"Error generating symbol outline: {str(e)}")
        return None
|
||||||
|
|
||||||
def format_chat_chunks(self):
|
def format_chat_chunks(self):
|
||||||
"""
|
"""
|
||||||
Override parent's format_chat_chunks to include enhanced context blocks with a
|
Override parent's format_chat_chunks to include enhanced context blocks with a
|
||||||
cleaner, more hierarchical structure for better organization.
|
cleaner, more hierarchical structure for better organization.
|
||||||
"""
|
"""
|
||||||
# First get the normal chat chunks from the parent method
|
# First get the normal chat chunks from the parent method
|
||||||
chunks = super().format_chat_chunks()
|
chunks = super().format_chat_chunks() # Calls BaseCoder's format_chat_chunks
|
||||||
|
|
||||||
# If enhanced context blocks are enabled, prepend them to the current messages
|
# If enhanced context blocks are enabled, prepend them to the current messages
|
||||||
if self.use_enhanced_context:
|
if self.use_enhanced_context:
|
||||||
# Create environment info context block
|
# Create environment info context block
|
||||||
env_context = self.get_environment_info()
|
env_context = self.get_environment_info()
|
||||||
|
|
||||||
# Get directory structure
|
|
||||||
dir_structure = self.get_directory_structure()
|
|
||||||
|
|
||||||
# Get git status
|
|
||||||
git_status = self.get_git_status()
|
|
||||||
|
|
||||||
# Get current context summary
|
# Get current context summary
|
||||||
context_summary = self.get_context_summary()
|
context_summary = self.get_context_summary()
|
||||||
|
|
||||||
|
# Get directory structure
|
||||||
|
dir_structure = self.get_directory_structure()
|
||||||
|
|
||||||
|
# Get git status
|
||||||
|
git_status = self.get_git_status()
|
||||||
|
|
||||||
|
# Get symbol outline for current context files
|
||||||
|
symbol_outline = self.get_context_symbol_outline() # New call
|
||||||
|
|
||||||
# Collect all context blocks that exist
|
# Collect all context blocks that exist
|
||||||
context_blocks = []
|
context_blocks = []
|
||||||
if env_context:
|
if env_context:
|
||||||
|
@ -93,21 +180,24 @@ class NavigatorCoder(Coder):
|
||||||
context_blocks.append(dir_structure)
|
context_blocks.append(dir_structure)
|
||||||
if git_status:
|
if git_status:
|
||||||
context_blocks.append(git_status)
|
context_blocks.append(git_status)
|
||||||
|
if symbol_outline: # Add the new block if it was generated
|
||||||
# If we have any context blocks, prepend them to the current messages
|
context_blocks.append(symbol_outline)
|
||||||
|
|
||||||
|
# If we have any context blocks, prepend them to the system message
|
||||||
if context_blocks:
|
if context_blocks:
|
||||||
context_message = "\n\n".join(context_blocks)
|
context_message = "\n\n".join(context_blocks)
|
||||||
# Prepend to system context but don't overwrite existing system content
|
# Prepend to system context but don't overwrite existing system content
|
||||||
if chunks.system:
|
if chunks.system:
|
||||||
# If we already have system messages, append our context to the first one
|
# If we already have system messages, append our context to the first one
|
||||||
original_content = chunks.system[0]["content"]
|
original_content = chunks.system[0]["content"]
|
||||||
|
# Ensure there's separation between our blocks and the original prompt
|
||||||
chunks.system[0]["content"] = context_message + "\n\n" + original_content
|
chunks.system[0]["content"] = context_message + "\n\n" + original_content
|
||||||
else:
|
else:
|
||||||
# Otherwise, create a new system message
|
# Otherwise, create a new system message
|
||||||
chunks.system = [dict(role="system", content=context_message)]
|
chunks.system = [dict(role="system", content=context_message)]
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def get_context_summary(self):
|
def get_context_summary(self):
|
||||||
"""
|
"""
|
||||||
Generate a summary of the current file context, including editable and read-only files,
|
Generate a summary of the current file context, including editable and read-only files,
|
||||||
|
@ -3317,4 +3407,4 @@ Just reply with fixed versions of the {blocks} above that failed to match.
|
||||||
|
|
||||||
return "\n".join(diff_lines_output)
|
return "\n".join(diff_lines_output)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"[Diff generation error: {e}]"
|
return f"[Diff generation error: {e}]"
|
||||||
|
|
|
@ -25,15 +25,23 @@ from aider.utils import Spinner
|
||||||
warnings.simplefilter("ignore", category=FutureWarning)
|
warnings.simplefilter("ignore", category=FutureWarning)
|
||||||
from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
|
from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402
|
||||||
|
|
||||||
Tag = namedtuple("Tag", "rel_fname fname line name kind".split())
|
# Define the Tag namedtuple with a default for specific_kind to maintain compatibility
|
||||||
|
# with cached entries that might have been created with the old definition
|
||||||
|
class TagBase(
    namedtuple(
        "TagBase",
        ["rel_fname", "fname", "line", "name", "kind", "specific_kind"],
    )
):
    """
    Immutable record describing one symbol tag found in a file.

    ``specific_kind`` is optional and defaults to ``None``, so code that builds
    a tag from only the first five fields keeps working.
    """

    __slots__ = ()

    def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None):
        # Pass every field on explicitly; only specific_kind may be omitted.
        values = (rel_fname, fname, line, name, kind, specific_kind)
        return super().__new__(cls, *values)
|
||||||
|
|
||||||
|
Tag = TagBase
|
||||||
|
|
||||||
|
|
||||||
SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
|
SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError)
|
||||||
|
|
||||||
|
|
||||||
CACHE_VERSION = 3
|
CACHE_VERSION = 5
|
||||||
if USING_TSL_PACK:
|
if USING_TSL_PACK:
|
||||||
CACHE_VERSION = 4
|
CACHE_VERSION = 6
|
||||||
|
|
||||||
|
|
||||||
class RepoMap:
|
class RepoMap:
|
||||||
|
@ -41,6 +49,17 @@ class RepoMap:
|
||||||
|
|
||||||
warned_files = set()
|
warned_files = set()
|
||||||
|
|
||||||
|
# Define kinds that typically represent definitions across languages
|
||||||
|
# Used by NavigatorCoder to filter tags for the symbol outline
|
||||||
|
definition_kinds = {
|
||||||
|
"class", "struct", "enum", "interface", "trait", # Structure definitions
|
||||||
|
"function", "method", "constructor", # Function/method definitions
|
||||||
|
"module", "namespace", # Module/namespace definitions
|
||||||
|
"constant", "variable", # Top-level/class variable definitions (consider refining)
|
||||||
|
"type", # Type definitions
|
||||||
|
# Add more based on tree-sitter queries if needed
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
map_tokens=1024,
|
map_tokens=1024,
|
||||||
|
@ -242,10 +261,23 @@ class RepoMap:
|
||||||
|
|
||||||
if val is not None and val.get("mtime") == file_mtime:
|
if val is not None and val.get("mtime") == file_mtime:
|
||||||
try:
|
try:
|
||||||
return self.TAGS_CACHE[cache_key]["data"]
|
# Get the cached data
|
||||||
|
data = self.TAGS_CACHE[cache_key]["data"]
|
||||||
|
|
||||||
|
# Let our Tag class handle compatibility with old cache formats
|
||||||
|
# No need for special handling as TagBase.__new__ will supply default specific_kind
|
||||||
|
|
||||||
|
return data
|
||||||
except SQLITE_ERRORS as e:
|
except SQLITE_ERRORS as e:
|
||||||
self.tags_cache_error(e)
|
self.tags_cache_error(e)
|
||||||
return self.TAGS_CACHE[cache_key]["data"]
|
return self.TAGS_CACHE[cache_key]["data"]
|
||||||
|
except (TypeError, AttributeError) as e:
|
||||||
|
# If we hit an error related to missing fields in old cached Tag objects,
|
||||||
|
# force a cache refresh for this file
|
||||||
|
if self.verbose:
|
||||||
|
self.io.tool_warning(f"Cache format error for {fname}, refreshing: {e}")
|
||||||
|
# Return empty list to trigger cache refresh
|
||||||
|
return []
|
||||||
|
|
||||||
# miss!
|
# miss!
|
||||||
data = list(self.get_tags_raw(fname, rel_fname))
|
data = list(self.get_tags_raw(fname, rel_fname))
|
||||||
|
@ -304,11 +336,15 @@ class RepoMap:
|
||||||
|
|
||||||
saw.add(kind)
|
saw.add(kind)
|
||||||
|
|
||||||
|
# Extract specific kind from the tag, e.g., 'function' from 'name.definition.function'
|
||||||
|
specific_kind = tag.split('.')[-1] if '.' in tag else None
|
||||||
|
|
||||||
result = Tag(
|
result = Tag(
|
||||||
rel_fname=rel_fname,
|
rel_fname=rel_fname,
|
||||||
fname=fname,
|
fname=fname,
|
||||||
name=node.text.decode("utf-8"),
|
name=node.text.decode("utf-8"),
|
||||||
kind=kind,
|
kind=kind,
|
||||||
|
specific_kind=specific_kind,
|
||||||
line=node.start_point[0],
|
line=node.start_point[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -338,6 +374,7 @@ class RepoMap:
|
||||||
fname=fname,
|
fname=fname,
|
||||||
name=token,
|
name=token,
|
||||||
kind="ref",
|
kind="ref",
|
||||||
|
specific_kind="name", # Default for pygments fallback
|
||||||
line=-1,
|
line=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue