feat: Improve RepoMap caching with MD5 and error handling

This commit is contained in:
Amir Elaguizy (aider) 2025-04-28 14:24:10 -05:00
parent e205629a94
commit 300854ac58

View file

@ -7,6 +7,7 @@ import sqlite3
import sys import sys
import time import time
import warnings import warnings
import hashlib
from collections import Counter, defaultdict, namedtuple from collections import Counter, defaultdict, namedtuple
from importlib import resources from importlib import resources
from pathlib import Path from pathlib import Path
@ -221,6 +222,10 @@ class RepoMap:
def save_tags_cache(self): def save_tags_cache(self):
pass pass
def _calculate_md5(self, content_bytes):
    """Return the MD5 hash of *content_bytes* as a lowercase hex digest string.

    Takes raw bytes (not str) so the digest is independent of any text
    encoding; used by the tags cache to detect content changes even when
    a file's mtime is unchanged.
    """
    return hashlib.md5(content_bytes).hexdigest()
def get_mtime(self, fname): def get_mtime(self, fname):
try: try:
return os.path.getmtime(fname) return os.path.getmtime(fname)
@ -228,39 +233,106 @@ class RepoMap:
self.io.tool_warning(f"File not found error: {fname}") self.io.tool_warning(f"File not found error: {fname}")
def get_tags(self, fname, rel_fname): def get_tags(self, fname, rel_fname):
# Check if the file is in the cache and if the modification time has not changed # 1. Get mtime and read content (handle errors)
file_mtime = self.get_mtime(fname)
if file_mtime is None:
return []
cache_key = fname
try: try:
val = self.TAGS_CACHE.get(cache_key) # Issue #1308 file_mtime = os.path.getmtime(fname)
except SQLITE_ERRORS as e: # Read as bytes for consistent hashing
self.tags_cache_error(e) content_bytes = Path(fname).read_bytes()
val = self.TAGS_CACHE.get(cache_key) # Decode for parsing, handle potential errors
code = content_bytes.decode(self.io.encoding, errors='replace')
if val is not None and val.get("mtime") == file_mtime: except (FileNotFoundError, IsADirectoryError, OSError) as e:
# File inaccessible, warn and ensure it's removed from cache if present
if fname not in self.warned_files: # Avoid repeated warnings
self.io.tool_warning(f"RepoMap: Error accessing file {fname}: {e}")
self.warned_files.add(fname)
cache_key = fname
try: try:
return self.TAGS_CACHE[cache_key]["data"] # Use TAGS_CACHE.pop() which handles key errors gracefully
except SQLITE_ERRORS as e: self.TAGS_CACHE.pop(cache_key, None)
self.tags_cache_error(e) except SQLITE_ERRORS as sql_e:
return self.TAGS_CACHE[cache_key]["data"] self.tags_cache_error(sql_e)
# If fallback to dict, try pop again
if isinstance(self.TAGS_CACHE, dict):
self.TAGS_CACHE.pop(cache_key, None)
return [] # Return empty list if file can't be accessed
# miss! # 2. Calculate MD5
data = list(self.get_tags_raw(fname, rel_fname)) current_md5 = self._calculate_md5(content_bytes)
# Update the cache # 3. Check cache
cache_key = fname
cached_data = None
try: try:
self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data} # Use get() with a default to avoid KeyError if key missing
self.save_tags_cache() cached_data = self.TAGS_CACHE.get(cache_key, None)
except SQLITE_ERRORS as e: except SQLITE_ERRORS as e:
self.tags_cache_error(e) self.tags_cache_error(e)
self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data} # Try get() again after potential recovery
if isinstance(self.TAGS_CACHE, dict):
cached_data = self.TAGS_CACHE.get(cache_key, None)
return data # 4. Compare mtime and MD5
cache_valid = False
log_message = None
if cached_data:
cached_mtime = cached_data.get("mtime")
cached_md5 = cached_data.get("md5")
def get_tags_raw(self, fname, rel_fname): if cached_mtime == file_mtime:
if cached_md5 == current_md5:
# mtime and MD5 match, cache is valid
cache_valid = True
else:
# MD5 mismatch! Content changed despite same mtime.
log_message = (
f"RepoMap: Content change detected for {rel_fname} (MD5 mismatch,"
f" mtime {file_mtime}). Re-parsing."
)
else:
# mtime mismatch - file definitely changed or cache is stale
log_message = (
f"RepoMap: File change detected for {rel_fname} (mtime mismatch: cached"
f" {cached_mtime}, current {file_mtime}). Re-parsing."
)
# 5. Return cached data or re-parse
if cache_valid:
try:
# Ensure data exists in the valid cache entry
return cached_data.get("data", [])
except Exception as e:
# Handle potential issues reading cached data, force re-parse
self.io.tool_warning(f"RepoMap: Error reading cached data for {fname}: {e}")
cache_valid = False # Force re-parse
# Cache is invalid or file changed - log if needed and re-parse
if log_message:
self.io.tool_warning(log_message)
# Call the raw tag parsing function (passing the already read code)
tags_data = list(self.get_tags_raw(fname, rel_fname, code)) # Pass code here
# Update the cache with new mtime, md5, and data
try:
self.TAGS_CACHE[cache_key] = {
"mtime": file_mtime,
"md5": current_md5,
"data": tags_data,
}
# self.save_tags_cache() # save_tags_cache might be a no-op now with diskcache
except SQLITE_ERRORS as e:
self.tags_cache_error(e)
# Try saving to in-memory dict if disk cache failed
if isinstance(self.TAGS_CACHE, dict):
self.TAGS_CACHE[cache_key] = {
"mtime": file_mtime,
"md5": current_md5,
"data": tags_data,
}
return tags_data
def get_tags_raw(self, fname, rel_fname, code):
lang = filename_to_lang(fname) lang = filename_to_lang(fname)
if not lang: if not lang:
return return
@ -269,22 +341,40 @@ class RepoMap:
language = get_language(lang) language = get_language(lang)
parser = get_parser(lang) parser = get_parser(lang)
except Exception as err: except Exception as err:
print(f"Skipping file {fname}: {err}") # Use io object for output
self.io.tool_warning(f"Skipping file {fname} for tags: {err}")
return return
query_scm = get_scm_fname(lang) query_scm_path = get_scm_fname(lang) # Renamed variable for clarity
if not query_scm.exists(): if not query_scm_path or not query_scm_path.exists():
# self.io.tool_warning(f"No tags query file found for language: {lang}") # Optional: more verbose logging
return
try:
query_scm = query_scm_path.read_text()
except Exception as e:
self.io.tool_warning(f"Error reading tags query file {query_scm_path}: {e}")
return return
query_scm = query_scm.read_text()
code = self.io.read_text(fname)
if not code: # code = self.io.read_text(fname) # <-- REMOVE THIS LINE
if not code: # Check the passed code content
return return
tree = parser.parse(bytes(code, "utf-8")) try:
tree = parser.parse(bytes(code, "utf-8"))
except Exception as e:
self.io.tool_warning(f"Error parsing {fname} with tree-sitter: {e}")
return
# Run the tags queries # Run the tags queries
query = language.query(query_scm) try:
captures = query.captures(tree.root_node) query = language.query(query_scm)
captures = query.captures(tree.root_node)
except Exception as e:
# Catch errors during query execution as well
self.io.tool_warning(f"Error running tags query on {fname}: {e}")
return
saw = set() saw = set()
if USING_TSL_PACK: if USING_TSL_PACK: