feat: Improve RepoMap caching with MD5 and error handling

Amir Elaguizy (aider) 2025-04-28 14:24:10 -05:00
parent e205629a94
commit 300854ac58

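The heart of the change, reduced to a standalone sketch before the raw diff (the helper names below are illustrative, not part of the commit): a cached entry is trusted only when both the recorded mtime and an MD5 digest of the file's raw bytes match the file on disk. Hashing bytes rather than decoded text keeps the digest independent of encoding, and the MD5 check catches edits that leave the timestamp unchanged.

```python
import hashlib
import os
from pathlib import Path

def content_md5(path):
    # Hash the raw bytes so the digest does not depend on text encoding
    return hashlib.md5(Path(path).read_bytes()).hexdigest()

def cache_entry_is_valid(path, entry):
    # Valid only if both checks pass; an mtime match alone is not trusted,
    # since content can change while the timestamp is preserved
    if entry.get("mtime") != os.path.getmtime(path):
        return False
    return entry.get("md5") == content_md5(path)
```

MD5 is a reasonable choice here because the hash only needs to detect changes; it is not a security boundary.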

@@ -7,6 +7,7 @@ import sqlite3
 import sys
 import time
 import warnings
+import hashlib
 from collections import Counter, defaultdict, namedtuple
 from importlib import resources
 from pathlib import Path
@@ -221,6 +222,10 @@ class RepoMap:
     def save_tags_cache(self):
         pass
 
+    def _calculate_md5(self, content_bytes):
+        """Calculates the MD5 hash of byte content."""
+        return hashlib.md5(content_bytes).hexdigest()
+
     def get_mtime(self, fname):
         try:
             return os.path.getmtime(fname)
@@ -228,39 +233,106 @@ class RepoMap:
             self.io.tool_warning(f"File not found error: {fname}")
 
     def get_tags(self, fname, rel_fname):
-        # Check if the file is in the cache and if the modification time has not changed
-        file_mtime = self.get_mtime(fname)
-        if file_mtime is None:
-            return []
-
-        cache_key = fname
-        try:
-            val = self.TAGS_CACHE.get(cache_key)  # Issue #1308
-        except SQLITE_ERRORS as e:
-            self.tags_cache_error(e)
-            val = self.TAGS_CACHE.get(cache_key)
-
-        if val is not None and val.get("mtime") == file_mtime:
-            try:
-                return self.TAGS_CACHE[cache_key]["data"]
-            except SQLITE_ERRORS as e:
-                self.tags_cache_error(e)
-                return self.TAGS_CACHE[cache_key]["data"]
-
-        # miss!
-        data = list(self.get_tags_raw(fname, rel_fname))
-
-        # Update the cache
-        try:
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
-            self.save_tags_cache()
-        except SQLITE_ERRORS as e:
-            self.tags_cache_error(e)
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
-
-        return data
-
-    def get_tags_raw(self, fname, rel_fname):
+        # 1. Get mtime and read content (handle errors)
+        try:
+            file_mtime = os.path.getmtime(fname)
+            # Read as bytes for consistent hashing
+            content_bytes = Path(fname).read_bytes()
+            # Decode for parsing, handle potential errors
+            code = content_bytes.decode(self.io.encoding, errors='replace')
+        except (FileNotFoundError, IsADirectoryError, OSError) as e:
+            # File inaccessible, warn and ensure it's removed from cache if present
+            if fname not in self.warned_files:  # Avoid repeated warnings
+                self.io.tool_warning(f"RepoMap: Error accessing file {fname}: {e}")
+                self.warned_files.add(fname)
+            cache_key = fname
+            try:
+                # Use TAGS_CACHE.pop() which handles key errors gracefully
+                self.TAGS_CACHE.pop(cache_key, None)
+            except SQLITE_ERRORS as sql_e:
+                self.tags_cache_error(sql_e)
+                # If fallback to dict, try pop again
+                if isinstance(self.TAGS_CACHE, dict):
+                    self.TAGS_CACHE.pop(cache_key, None)
+            return []  # Return empty list if file can't be accessed
+
+        # 2. Calculate MD5
+        current_md5 = self._calculate_md5(content_bytes)
+
+        # 3. Check cache
+        cache_key = fname
+        cached_data = None
+        try:
+            # Use get() with a default to avoid KeyError if key missing
+            cached_data = self.TAGS_CACHE.get(cache_key, None)
+        except SQLITE_ERRORS as e:
+            self.tags_cache_error(e)
+            # Try get() again after potential recovery
+            if isinstance(self.TAGS_CACHE, dict):
+                cached_data = self.TAGS_CACHE.get(cache_key, None)
+
+        # 4. Compare mtime and MD5
+        cache_valid = False
+        log_message = None
+        if cached_data:
+            cached_mtime = cached_data.get("mtime")
+            cached_md5 = cached_data.get("md5")
+
+            if cached_mtime == file_mtime:
+                if cached_md5 == current_md5:
+                    # mtime and MD5 match, cache is valid
+                    cache_valid = True
+                else:
+                    # MD5 mismatch! Content changed despite same mtime.
+                    log_message = (
+                        f"RepoMap: Content change detected for {rel_fname} (MD5 mismatch,"
+                        f" mtime {file_mtime}). Re-parsing."
+                    )
+            else:
+                # mtime mismatch - file definitely changed or cache is stale
+                log_message = (
+                    f"RepoMap: File change detected for {rel_fname} (mtime mismatch: cached"
+                    f" {cached_mtime}, current {file_mtime}). Re-parsing."
+                )
+
+        # 5. Return cached data or re-parse
+        if cache_valid:
+            try:
+                # Ensure data exists in the valid cache entry
+                return cached_data.get("data", [])
+            except Exception as e:
+                # Handle potential issues reading cached data, force re-parse
+                self.io.tool_warning(f"RepoMap: Error reading cached data for {fname}: {e}")
+                cache_valid = False  # Force re-parse
+
+        # Cache is invalid or file changed - log if needed and re-parse
+        if log_message:
+            self.io.tool_warning(log_message)
+
+        # Call the raw tag parsing function (passing the already read code)
+        tags_data = list(self.get_tags_raw(fname, rel_fname, code))  # Pass code here
+
+        # Update the cache with new mtime, md5, and data
+        try:
+            self.TAGS_CACHE[cache_key] = {
+                "mtime": file_mtime,
+                "md5": current_md5,
+                "data": tags_data,
+            }
+            # self.save_tags_cache() # save_tags_cache might be a no-op now with diskcache
+        except SQLITE_ERRORS as e:
+            self.tags_cache_error(e)
+            # Try saving to in-memory dict if disk cache failed
+            if isinstance(self.TAGS_CACHE, dict):
+                self.TAGS_CACHE[cache_key] = {
+                    "mtime": file_mtime,
+                    "md5": current_md5,
+                    "data": tags_data,
+                }
+
+        return tags_data
+
+    def get_tags_raw(self, fname, rel_fname, code):
         lang = filename_to_lang(fname)
         if not lang:
             return
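A pattern worth noting before the next hunk: every `TAGS_CACHE` access above is wrapped in `try/except SQLITE_ERRORS`, with `tags_cache_error()` expected to swap the broken disk-backed cache for a plain dict, after which the operation is retried. A minimal sketch of that degrade-and-retry strategy, assuming a diskcache-style SQLite-backed store (the class and its names are invented for illustration; `tags_cache_error` itself is not shown in this diff):

```python
import sqlite3

# Assumption: SQLITE_ERRORS groups the low-level errors a SQLite-backed
# cache such as diskcache can raise.
SQLITE_ERRORS = (sqlite3.OperationalError, sqlite3.DatabaseError)

class FallbackTagsCache:
    """Illustrative only: degrade from a disk-backed cache to a dict."""

    def __init__(self, store):
        self.store = store  # e.g. diskcache.Cache(path)

    def _recover(self, err):
        # Counterpart of tags_cache_error(): once we switch to a plain
        # dict, later accesses can no longer raise SQLITE_ERRORS.
        print(f"tags cache unusable ({err}); falling back to in-memory dict")
        self.store = {}

    def get(self, key, default=None):
        try:
            return self.store.get(key, default)
        except SQLITE_ERRORS as e:
            self._recover(e)
            return self.store.get(key, default)

    def set(self, key, value):
        try:
            self.store[key] = value
        except SQLITE_ERRORS as e:
            self._recover(e)
            self.store[key] = value
```

The cost of the fallback is that entries are no longer persisted across runs, but the repo map keeps working for the rest of the session.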
@@ -269,22 +341,40 @@ class RepoMap:
             language = get_language(lang)
             parser = get_parser(lang)
         except Exception as err:
-            print(f"Skipping file {fname}: {err}")
+            # Use io object for output
+            self.io.tool_warning(f"Skipping file {fname} for tags: {err}")
             return
 
-        query_scm = get_scm_fname(lang)
-        if not query_scm.exists():
+        query_scm_path = get_scm_fname(lang)  # Renamed variable for clarity
+        if not query_scm_path or not query_scm_path.exists():
+            # self.io.tool_warning(f"No tags query file found for language: {lang}") # Optional: more verbose logging
             return
-        query_scm = query_scm.read_text()
+        try:
+            query_scm = query_scm_path.read_text()
+        except Exception as e:
+            self.io.tool_warning(f"Error reading tags query file {query_scm_path}: {e}")
+            return
 
-        code = self.io.read_text(fname)
-        if not code:
+        # code = self.io.read_text(fname) # <-- REMOVE THIS LINE
+        if not code:  # Check the passed code content
             return
-        tree = parser.parse(bytes(code, "utf-8"))
+        try:
+            tree = parser.parse(bytes(code, "utf-8"))
+        except Exception as e:
+            self.io.tool_warning(f"Error parsing {fname} with tree-sitter: {e}")
+            return
 
         # Run the tags queries
-        query = language.query(query_scm)
-        captures = query.captures(tree.root_node)
+        try:
+            query = language.query(query_scm)
+            captures = query.captures(tree.root_node)
+        except Exception as e:
+            # Catch errors during query execution as well
+            self.io.tool_warning(f"Error running tags query on {fname}: {e}")
+            return
 
         saw = set()
         if USING_TSL_PACK: