commit f287f777cf
Amir Elaguizy 2025-05-13 16:48:46 -07:00 committed by GitHub
4 changed files with 326 additions and 33 deletions

aider/repomap.py

@@ -7,6 +7,7 @@ import sqlite3
 import sys
 import time
 import warnings
+import hashlib
 from collections import Counter, defaultdict, namedtuple
 from importlib import resources
 from pathlib import Path
@@ -223,6 +224,10 @@ class RepoMap:
     def save_tags_cache(self):
         pass

+    def _calculate_md5(self, content_bytes):
+        """Calculates the MD5 hash of byte content."""
+        return hashlib.md5(content_bytes).hexdigest()
+
     def get_mtime(self, fname):
         try:
             return os.path.getmtime(fname)
@@ -230,39 +235,106 @@ class RepoMap:
         self.io.tool_warning(f"File not found error: {fname}")

     def get_tags(self, fname, rel_fname):
-        # Check if the file is in the cache and if the modification time has not changed
-        file_mtime = self.get_mtime(fname)
-        if file_mtime is None:
-            return []
-
-        cache_key = fname
+        # 1. Get mtime and read content (handle errors)
         try:
-            val = self.TAGS_CACHE.get(cache_key)  # Issue #1308
+            file_mtime = os.path.getmtime(fname)
+            # Read as bytes for consistent hashing
+            content_bytes = Path(fname).read_bytes()
+            # Decode for parsing, handle potential errors
+            code = content_bytes.decode(self.io.encoding, errors='replace')
+        except (FileNotFoundError, IsADirectoryError, OSError) as e:
+            # File inaccessible, warn and ensure it's removed from cache if present
+            if fname not in self.warned_files:  # Avoid repeated warnings
+                self.io.tool_warning(f"RepoMap: Error accessing file {fname}: {e}")
+                self.warned_files.add(fname)
+            cache_key = fname
+            try:
+                # Use TAGS_CACHE.pop() which handles key errors gracefully
+                self.TAGS_CACHE.pop(cache_key, None)
+            except SQLITE_ERRORS as sql_e:
+                self.tags_cache_error(sql_e)
+                # If fallback to dict, try pop again
+                if isinstance(self.TAGS_CACHE, dict):
+                    self.TAGS_CACHE.pop(cache_key, None)
+            return []  # Return empty list if file can't be accessed
+
+        # 2. Calculate MD5
+        current_md5 = self._calculate_md5(content_bytes)
+
+        # 3. Check cache
+        cache_key = fname
+        cached_data = None
+        try:
+            # Use get() with a default to avoid KeyError if key missing
+            cached_data = self.TAGS_CACHE.get(cache_key, None)
         except SQLITE_ERRORS as e:
             self.tags_cache_error(e)
-            val = self.TAGS_CACHE.get(cache_key)
+            # Try get() again after potential recovery
+            if isinstance(self.TAGS_CACHE, dict):
+                cached_data = self.TAGS_CACHE.get(cache_key, None)

-        if val is not None and val.get("mtime") == file_mtime:
-            try:
-                return self.TAGS_CACHE[cache_key]["data"]
-            except SQLITE_ERRORS as e:
-                self.tags_cache_error(e)
-                return self.TAGS_CACHE[cache_key]["data"]
+        # 4. Compare mtime and MD5
+        cache_valid = False
+        log_message = None
+        if cached_data:
+            cached_mtime = cached_data.get("mtime")
+            cached_md5 = cached_data.get("md5")
+
+            if cached_mtime == file_mtime:
+                if cached_md5 == current_md5:
+                    # mtime and MD5 match, cache is valid
+                    cache_valid = True
+                else:
+                    # MD5 mismatch! Content changed despite same mtime.
+                    log_message = (
+                        f"RepoMap: Content change detected for {rel_fname} (MD5 mismatch,"
+                        f" mtime {file_mtime}). Re-parsing."
+                    )
+            else:
+                # mtime mismatch - file definitely changed or cache is stale
+                log_message = (
+                    f"RepoMap: File change detected for {rel_fname} (mtime mismatch: cached"
+                    f" {cached_mtime}, current {file_mtime}). Re-parsing."
+                )

-        # miss!
-        data = list(self.get_tags_raw(fname, rel_fname))
+        # 5. Return cached data or re-parse
+        if cache_valid:
+            try:
+                # Ensure data exists in the valid cache entry
+                return cached_data.get("data", [])
+            except Exception as e:
+                # Handle potential issues reading cached data, force re-parse
+                self.io.tool_warning(f"RepoMap: Error reading cached data for {fname}: {e}")
+                cache_valid = False  # Force re-parse

-        # Update the cache
-        try:
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
-            self.save_tags_cache()
-        except SQLITE_ERRORS as e:
-            self.tags_cache_error(e)
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
+        # Cache is invalid or file changed - log if needed and re-parse
+        if log_message:
+            self.io.tool_warning(log_message)

-        return data
+        # Call the raw tag parsing function (passing the already read code)
+        tags_data = list(self.get_tags_raw(fname, rel_fname, code))  # Pass code here

-    def get_tags_raw(self, fname, rel_fname):
+        # Update the cache with new mtime, md5, and data
+        try:
+            self.TAGS_CACHE[cache_key] = {
+                "mtime": file_mtime,
+                "md5": current_md5,
+                "data": tags_data,
+            }
+            # self.save_tags_cache()  # save_tags_cache might be a no-op now with diskcache
+        except SQLITE_ERRORS as e:
+            self.tags_cache_error(e)
+            # Try saving to in-memory dict if disk cache failed
+            if isinstance(self.TAGS_CACHE, dict):
+                self.TAGS_CACHE[cache_key] = {
+                    "mtime": file_mtime,
+                    "md5": current_md5,
+                    "data": tags_data,
+                }
+
+        return tags_data
+
+    def get_tags_raw(self, fname, rel_fname, code):
         lang = filename_to_lang(fname)
         if not lang:
             return
@@ -271,22 +343,40 @@ class RepoMap:
         try:
             language = get_language(lang)
             parser = get_parser(lang)
         except Exception as err:
-            print(f"Skipping file {fname}: {err}")
+            # Use io object for output
+            self.io.tool_warning(f"Skipping file {fname} for tags: {err}")
             return

-        query_scm = get_scm_fname(lang)
-        if not query_scm.exists():
+        query_scm_path = get_scm_fname(lang)  # Renamed variable for clarity
+        if not query_scm_path or not query_scm_path.exists():
+            # self.io.tool_warning(f"No tags query file found for language: {lang}")  # Optional: more verbose logging
             return
-        query_scm = query_scm.read_text()
+        try:
+            query_scm = query_scm_path.read_text()
+        except Exception as e:
+            self.io.tool_warning(f"Error reading tags query file {query_scm_path}: {e}")
+            return

-        code = self.io.read_text(fname)
-        if not code:
+        # code = self.io.read_text(fname)  # <-- REMOVE THIS LINE
+        if not code:  # Check the passed code content
             return
-        tree = parser.parse(bytes(code, "utf-8"))
+        try:
+            tree = parser.parse(bytes(code, "utf-8"))
+        except Exception as e:
+            self.io.tool_warning(f"Error parsing {fname} with tree-sitter: {e}")
+            return

         # Run the tags queries
-        query = language.query(query_scm)
-        captures = query.captures(tree.root_node)
+        try:
+            query = language.query(query_scm)
+            captures = query.captures(tree.root_node)
+        except Exception as e:
+            # Catch errors during query execution as well
+            self.io.tool_warning(f"Error running tags query on {fname}: {e}")
+            return

         saw = set()
         if USING_TSL_PACK:
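
The net effect of the repomap.py changes is a two-level freshness check: the cheap mtime comparison is kept, but it is now backed by an MD5 digest of the file contents, so edits that leave the modification time untouched (deliberate os.utime calls, rapid successive writes, some build tools) still invalidate the cache. Below is a minimal, self-contained sketch of that validation pattern, using a plain dict as a stand-in for TAGS_CACHE and a hypothetical parse() in place of get_tags_raw(); it illustrates the technique, not aider's actual implementation:

    import hashlib
    import os
    from pathlib import Path

    cache = {}  # hypothetical in-memory stand-in for RepoMap.TAGS_CACHE

    def parse(code):
        # hypothetical stand-in for get_tags_raw()
        return [line for line in code.splitlines() if line.startswith("def ")]

    def get_tags(fname):
        try:
            mtime = os.path.getmtime(fname)
            content = Path(fname).read_bytes()
        except OSError:
            cache.pop(fname, None)  # drop stale entries for unreadable files
            return []

        md5 = hashlib.md5(content).hexdigest()
        entry = cache.get(fname)
        # An entry is fresh only if BOTH mtime and content digest match
        if entry and entry["mtime"] == mtime and entry["md5"] == md5:
            return entry["data"]

        data = parse(content.decode("utf-8", errors="replace"))
        cache[fname] = {"mtime": mtime, "md5": md5, "data": data}
        return data

The trade-off is visible in the diff: every get_tags() call now reads and hashes the whole file, even on a cache hit, in exchange for never serving stale tags.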

aider/resources/model-metadata.json

@@ -204,6 +204,42 @@
     "supports_tool_choice": true,
     "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
   },
+  "gemini/gemini-2.5-pro-preview-05-06": {
+    "max_tokens": 8192,
+    "max_input_tokens": 1048576,
+    "max_output_tokens": 64000,
+    "max_images_per_prompt": 3000,
+    "max_videos_per_prompt": 10,
+    "max_video_length": 1,
+    "max_audio_length_hours": 8.4,
+    "max_audio_per_prompt": 1,
+    "max_pdf_size_mb": 30,
+    "input_cost_per_image": 0,
+    "input_cost_per_video_per_second": 0,
+    "input_cost_per_audio_per_second": 0,
+    "input_cost_per_token": 0.00000125,
+    "input_cost_per_character": 0,
+    "input_cost_per_token_above_128k_tokens": 0,
+    "input_cost_per_character_above_128k_tokens": 0,
+    "input_cost_per_image_above_128k_tokens": 0,
+    "input_cost_per_video_per_second_above_128k_tokens": 0,
+    "input_cost_per_audio_per_second_above_128k_tokens": 0,
+    "output_cost_per_token": 0.000010,
+    "output_cost_per_character": 0,
+    "output_cost_per_token_above_128k_tokens": 0,
+    "output_cost_per_character_above_128k_tokens": 0,
+    "litellm_provider": "gemini",
+    "mode": "chat",
+    "supports_system_messages": true,
+    "supports_function_calling": true,
+    "supports_vision": true,
+    "supports_audio_input": true,
+    "supports_video_input": true,
+    "supports_pdf_input": true,
+    "supports_response_schema": true,
+    "supports_tool_choice": true,
+    "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
+  },
   "vertex_ai/gemini-2.5-pro-exp-03-25": {
     "max_tokens": 8192,
     "max_input_tokens": 1048576,

aider/resources/model-settings.yml

@@ -963,6 +963,12 @@
   use_repo_map: true
   weak_model_name: gemini/gemini-2.0-flash

+- name: gemini/gemini-2.5-pro-preview-05-06
+  overeager: true
+  edit_format: diff-fenced
+  use_repo_map: true
+  weak_model_name: gemini/gemini-2.0-flash
+
- name: gemini/gemini-2.5-pro-exp-03-25
  edit_format: diff-fenced
  use_repo_map: true
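
With this settings entry and the metadata above in place, the model becomes addressable by name; for example, assuming aider's usual --model flag:

    aider --model gemini/gemini-2.5-pro-preview-05-06

Note that the new preview entry carries overeager: true, a model-settings switch aider uses to nudge a model to stay within the scope of the request, while the older exp-03-25 entry below does not.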

tests/basic/test_repomap.py

@@ -4,6 +4,7 @@ import re
 import time
 import unittest
 from pathlib import Path
+from unittest.mock import patch

 import git
@@ -18,7 +19,167 @@ class TestRepoMap(unittest.TestCase):
     def setUp(self):
         self.GPT35 = Model("gpt-3.5-turbo")

+    # Helper function to calculate MD5 hash of a file
+    def _calculate_md5_for_file(self, file_path):
+        import hashlib
+
+        hasher = hashlib.md5()
+        with open(file_path, "rb") as f:
+            while True:
+                chunk = f.read(8192)
+                if not chunk:
+                    break
+                hasher.update(chunk)
+        return hasher.hexdigest()
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_md5_change_same_mtime(self, mock_tool_warning):
+        """Verify MD5 detection when mtime is unchanged."""
+        with GitTemporaryDirectory() as temp_dir:
+            # Create a test file
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            # Initialize RepoMap and populate cache
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            self.assertTrue(any(tag.name == "func_a" for tag in initial_tags))
+            initial_mtime = os.path.getmtime(abs_path)
+
+            # Modify content, reset mtime
+            new_content = "def func_b(): pass\n"
+            test_file.write_text(new_content)
+            os.utime(abs_path, (initial_mtime, initial_mtime))  # Reset mtime
+
+            # Call get_tags again
+            new_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called_once()
+            self.assertIn("MD5 mismatch", mock_tool_warning.call_args[0][0])
+            self.assertTrue(any(tag.name == "func_b" for tag in new_tags))
+            self.assertFalse(any(tag.name == "func_a" for tag in new_tags))
+
+            # Check cache update
+            cached_data = repo_map.TAGS_CACHE.get(abs_path)
+            self.assertIsNotNone(cached_data)
+            expected_md5 = self._calculate_md5_for_file(abs_path)
+            self.assertEqual(cached_data.get("md5"), expected_md5)
+            self.assertEqual(cached_data.get("mtime"), initial_mtime)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    @patch("aider.repomap.RepoMap.get_tags_raw")
+    def test_get_tags_no_change(self, mock_get_tags_raw, mock_tool_warning):
+        """Verify cache is used when file is unchanged."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Initial call to populate cache
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            mock_get_tags_raw.assert_called_once()  # Called once initially
+            mock_get_tags_raw.reset_mock()  # Reset for the next check
+
+            # Call get_tags again without changes
+            second_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_not_called()
+            mock_get_tags_raw.assert_not_called()  # Should not be called again
+            self.assertEqual(initial_tags, second_tags)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_mtime_change(self, mock_tool_warning):
+        """Verify standard mtime-based change detection still works."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Initial call
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            self.assertTrue(any(tag.name == "func_a" for tag in initial_tags))
+
+            # Modify content (mtime will change naturally)
+            time.sleep(0.01)  # Ensure mtime is different
+            new_content = "def func_b(): pass\n"
+            test_file.write_text(new_content)
+            new_mtime = os.path.getmtime(abs_path)
+
+            # Call get_tags again
+            new_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called_once()
+            self.assertIn("mtime mismatch", mock_tool_warning.call_args[0][0])
+            self.assertTrue(any(tag.name == "func_b" for tag in new_tags))
+            self.assertFalse(any(tag.name == "func_a" for tag in new_tags))
+
+            # Check cache update
+            cached_data = repo_map.TAGS_CACHE.get(abs_path)
+            self.assertIsNotNone(cached_data)
+            expected_md5 = self._calculate_md5_for_file(abs_path)
+            self.assertEqual(cached_data.get("md5"), expected_md5)
+            self.assertEqual(cached_data.get("mtime"), new_mtime)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_file_not_found_after_cache(self, mock_tool_warning):
+        """Verify graceful handling if a cached file becomes inaccessible."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            test_file.write_text("def func_a(): pass\n")
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Populate cache
+            repo_map.get_tags(abs_path, rel_path)
+            self.assertIn(abs_path, repo_map.TAGS_CACHE)
+
+            # Delete the file
+            os.remove(abs_path)
+
+            # Call get_tags again
+            result = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called()
+            # Check if any call contains "Error accessing file" or "FileNotFoundError"
+            warning_found = any(
+                "Error accessing file" in call[0][0] or "FileNotFoundError" in call[0][0]
+                for call in mock_tool_warning.call_args_list
+            )
+            self.assertTrue(warning_found, "Expected file access error warning not found")
+            self.assertEqual(result, [])
+            self.assertNotIn(abs_path, repo_map.TAGS_CACHE)
+
+            del repo_map  # Close cache
+
     def test_get_repo_map(self):
         # Create a temporary directory with sample files for testing
         test_files = [
             "test_file1.py",