From 1e01482be6b1c24c2cf109751b1331e234a1f8d3 Mon Sep 17 00:00:00 2001 From: "Amar Sood (tekacs)" Date: Sat, 12 Apr 2025 05:48:25 -0400 Subject: [PATCH] Include a tree-sitter based outline of open files for the LLM --- aider/coders/navigator_coder.py | 122 +++++++++++++++++++++++++++----- aider/repomap.py | 45 ++++++++++-- 2 files changed, 147 insertions(+), 20 deletions(-) diff --git a/aider/coders/navigator_coder.py b/aider/coders/navigator_coder.py index 3495cba6c..6f824551b 100644 --- a/aider/coders/navigator_coder.py +++ b/aider/coders/navigator_coder.py @@ -7,11 +7,22 @@ import random import subprocess import traceback import platform +import ast +import re +import fnmatch +import os +import time +import random +import subprocess +import traceback +import platform import locale from datetime import datetime from pathlib import Path import xml.etree.ElementTree as ET from xml.etree.ElementTree import ParseError +# Add necessary imports if not already present +from collections import defaultdict from .base_coder import Coder from .editblock_coder import find_original_update_blocks, do_replace, find_similar_lines @@ -61,28 +72,104 @@ class NavigatorCoder(Coder): # Enable enhanced context blocks by default self.use_enhanced_context = True + def get_context_symbol_outline(self): + """ + Generate a symbol outline for files currently in context using Tree-sitter, + bypassing the cache for freshness. + """ + if not self.use_enhanced_context or not self.repo_map: + return None + + try: + result = "\n" + result += "## Symbol Outline (Current Context)\n\n" + result += "Code definitions (classes, functions, methods, etc.) found in files currently in chat context.\n\n" + + files_to_outline = list(self.abs_fnames) + list(self.abs_read_only_fnames) + if not files_to_outline: + result += "No files currently in context.\n" + result += "" + return result + + all_tags_by_file = defaultdict(list) + has_symbols = False + + # Use repo_map which should be initialized in BaseCoder + if not self.repo_map: + self.io.tool_warning("RepoMap not initialized, cannot generate symbol outline.") + return None # Or return a message indicating repo map is unavailable + + for abs_fname in sorted(files_to_outline): + rel_fname = self.get_rel_fname(abs_fname) + try: + # Call get_tags_raw directly to bypass cache and ensure freshness + tags = list(self.repo_map.get_tags_raw(abs_fname, rel_fname)) + if tags: + all_tags_by_file[rel_fname].extend(tags) + has_symbols = True + except Exception as e: + self.io.tool_warning(f"Could not get symbols for {rel_fname}: {e}") + + if not has_symbols: + result += "No symbols found in the current context files.\n" + else: + for rel_fname in sorted(all_tags_by_file.keys()): + tags = sorted(all_tags_by_file[rel_fname], key=lambda t: (t.line, t.name)) + + definition_tags = [] + for tag in tags: + # Use specific_kind first if available, otherwise fall back to kind + kind_to_check = tag.specific_kind or tag.kind + # Check if the kind represents a definition using the set from RepoMap + if kind_to_check and kind_to_check.lower() in self.repo_map.definition_kinds: + definition_tags.append(tag) + + if definition_tags: + result += f"### {rel_fname}\n" + # Simple list format for now, could be enhanced later (e.g., indentation for scope) + for tag in definition_tags: + # Display line number if available + line_info = f", line {tag.line + 1}" if tag.line >= 0 else "" + # Display the specific kind (which we checked) + kind_to_check = tag.specific_kind or tag.kind # Recalculate for safety + result += f"- {tag.name} ({kind_to_check}{line_info})\n" + result += "\n" # Add space between files + + result += "" + return result.strip() # Remove trailing newline if any + + except Exception as e: + self.io.tool_error(f"Error generating symbol outline: {str(e)}") + # Optionally include traceback for debugging if verbose + # if self.verbose: + # self.io.tool_error(traceback.format_exc()) + return None + def format_chat_chunks(self): """ - Override parent's format_chat_chunks to include enhanced context blocks with a + Override parent's format_chat_chunks to include enhanced context blocks with a cleaner, more hierarchical structure for better organization. """ # First get the normal chat chunks from the parent method - chunks = super().format_chat_chunks() - + chunks = super().format_chat_chunks() # Calls BaseCoder's format_chat_chunks + # If enhanced context blocks are enabled, prepend them to the current messages if self.use_enhanced_context: # Create environment info context block env_context = self.get_environment_info() - - # Get directory structure - dir_structure = self.get_directory_structure() - - # Get git status - git_status = self.get_git_status() - + # Get current context summary context_summary = self.get_context_summary() - + + # Get directory structure + dir_structure = self.get_directory_structure() + + # Get git status + git_status = self.get_git_status() + + # Get symbol outline for current context files + symbol_outline = self.get_context_symbol_outline() # New call + # Collect all context blocks that exist context_blocks = [] if env_context: @@ -93,21 +180,24 @@ class NavigatorCoder(Coder): context_blocks.append(dir_structure) if git_status: context_blocks.append(git_status) - - # If we have any context blocks, prepend them to the current messages + if symbol_outline: # Add the new block if it was generated + context_blocks.append(symbol_outline) + + # If we have any context blocks, prepend them to the system message if context_blocks: context_message = "\n\n".join(context_blocks) # Prepend to system context but don't overwrite existing system content if chunks.system: # If we already have system messages, append our context to the first one original_content = chunks.system[0]["content"] + # Ensure there's separation between our blocks and the original prompt chunks.system[0]["content"] = context_message + "\n\n" + original_content else: # Otherwise, create a new system message chunks.system = [dict(role="system", content=context_message)] - + return chunks - + def get_context_summary(self): """ Generate a summary of the current file context, including editable and read-only files, @@ -3317,4 +3407,4 @@ Just reply with fixed versions of the {blocks} above that failed to match. return "\n".join(diff_lines_output) except Exception as e: - return f"[Diff generation error: {e}]" \ No newline at end of file + return f"[Diff generation error: {e}]" diff --git a/aider/repomap.py b/aider/repomap.py index 598770d18..b21d65f02 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -25,15 +25,23 @@ from aider.utils import Spinner warnings.simplefilter("ignore", category=FutureWarning) from grep_ast.tsl import USING_TSL_PACK, get_language, get_parser # noqa: E402 -Tag = namedtuple("Tag", "rel_fname fname line name kind".split()) +# Define the Tag namedtuple with a default for specific_kind to maintain compatibility +# with cached entries that might have been created with the old definition +class TagBase(namedtuple("TagBase", "rel_fname fname line name kind specific_kind")): + __slots__ = () + def __new__(cls, rel_fname, fname, line, name, kind, specific_kind=None): + # Provide a default value for specific_kind to handle old cached objects + return super(TagBase, cls).__new__(cls, rel_fname, fname, line, name, kind, specific_kind) + +Tag = TagBase SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError, OSError) -CACHE_VERSION = 3 +CACHE_VERSION = 5 if USING_TSL_PACK: - CACHE_VERSION = 4 + CACHE_VERSION = 6 class RepoMap: @@ -41,6 +49,17 @@ class RepoMap: warned_files = set() + # Define kinds that typically represent definitions across languages + # Used by NavigatorCoder to filter tags for the symbol outline + definition_kinds = { + "class", "struct", "enum", "interface", "trait", # Structure definitions + "function", "method", "constructor", # Function/method definitions + "module", "namespace", # Module/namespace definitions + "constant", "variable", # Top-level/class variable definitions (consider refining) + "type", # Type definitions + # Add more based on tree-sitter queries if needed + } + def __init__( self, map_tokens=1024, @@ -242,10 +261,23 @@ class RepoMap: if val is not None and val.get("mtime") == file_mtime: try: - return self.TAGS_CACHE[cache_key]["data"] + # Get the cached data + data = self.TAGS_CACHE[cache_key]["data"] + + # Let our Tag class handle compatibility with old cache formats + # No need for special handling as TagBase.__new__ will supply default specific_kind + + return data except SQLITE_ERRORS as e: self.tags_cache_error(e) return self.TAGS_CACHE[cache_key]["data"] + except (TypeError, AttributeError) as e: + # If we hit an error related to missing fields in old cached Tag objects, + # force a cache refresh for this file + if self.verbose: + self.io.tool_warning(f"Cache format error for {fname}, refreshing: {e}") + # Return empty list to trigger cache refresh + return [] # miss! data = list(self.get_tags_raw(fname, rel_fname)) @@ -304,11 +336,15 @@ class RepoMap: saw.add(kind) + # Extract specific kind from the tag, e.g., 'function' from 'name.definition.function' + specific_kind = tag.split('.')[-1] if '.' in tag else None + result = Tag( rel_fname=rel_fname, fname=fname, name=node.text.decode("utf-8"), kind=kind, + specific_kind=specific_kind, line=node.start_point[0], ) @@ -338,6 +374,7 @@ class RepoMap: fname=fname, name=token, kind="ref", + specific_kind="name", # Default for pygments fallback line=-1, )