Amir Elaguizy 2025-05-13 16:48:46 -07:00 committed by GitHub
commit f287f777cf
4 changed files with 326 additions and 33 deletions

aider/repomap.py

@@ -7,6 +7,7 @@ import sqlite3
 import sys
 import time
 import warnings
+import hashlib
 from collections import Counter, defaultdict, namedtuple
 from importlib import resources
 from pathlib import Path
@@ -223,6 +224,10 @@ class RepoMap:
     def save_tags_cache(self):
         pass
 
+    def _calculate_md5(self, content_bytes):
+        """Calculates the MD5 hash of byte content."""
+        return hashlib.md5(content_bytes).hexdigest()
+
     def get_mtime(self, fname):
         try:
             return os.path.getmtime(fname)
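Worth noting: the helper hashes the raw bytes rather than decoded text, which keeps the digest stable regardless of encoding, and MD5 is adequate here because the hash only detects change, not tampering. A minimal standalone sketch (not part of the diff) of both hashing styles; the chunked variant matches the test helper added later in this commit, and the file name is illustrative:

import hashlib
from pathlib import Path

def md5_of_bytes(content_bytes: bytes) -> str:
    # One-shot digest, mirrors _calculate_md5 above
    return hashlib.md5(content_bytes).hexdigest()

def md5_of_file(path: str, chunk_size: int = 8192) -> str:
    # Chunked digest avoids loading large files into memory;
    # it yields the same hex digest as md5_of_bytes(Path(path).read_bytes())
    hasher = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()

data = b"def func_a(): pass\n"
Path("demo.py").write_bytes(data)  # illustrative file
assert md5_of_bytes(data) == md5_of_file("demo.py")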
@@ -230,39 +235,106 @@ class RepoMap:
             self.io.tool_warning(f"File not found error: {fname}")
 
     def get_tags(self, fname, rel_fname):
-        # Check if the file is in the cache and if the modification time has not changed
-        file_mtime = self.get_mtime(fname)
-        if file_mtime is None:
-            return []
-
-        cache_key = fname
-        try:
-            val = self.TAGS_CACHE.get(cache_key)  # Issue #1308
-        except SQLITE_ERRORS as e:
-            self.tags_cache_error(e)
-            val = self.TAGS_CACHE.get(cache_key)
-
-        if val is not None and val.get("mtime") == file_mtime:
-            try:
-                return self.TAGS_CACHE[cache_key]["data"]
-            except SQLITE_ERRORS as e:
-                self.tags_cache_error(e)
-                return self.TAGS_CACHE[cache_key]["data"]
-
-        # miss!
-        data = list(self.get_tags_raw(fname, rel_fname))
-
-        # Update the cache
-        try:
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
-            self.save_tags_cache()
-        except SQLITE_ERRORS as e:
-            self.tags_cache_error(e)
-            self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
-
-        return data
+        # 1. Get mtime and read content (handle errors)
+        try:
+            file_mtime = os.path.getmtime(fname)
+            # Read as bytes for consistent hashing
+            content_bytes = Path(fname).read_bytes()
+            # Decode for parsing, handle potential errors
+            code = content_bytes.decode(self.io.encoding, errors='replace')
+        except (FileNotFoundError, IsADirectoryError, OSError) as e:
+            # File inaccessible, warn and ensure it's removed from cache if present
+            if fname not in self.warned_files:  # Avoid repeated warnings
+                self.io.tool_warning(f"RepoMap: Error accessing file {fname}: {e}")
+                self.warned_files.add(fname)
+            cache_key = fname
+            try:
+                # Use TAGS_CACHE.pop() which handles key errors gracefully
+                self.TAGS_CACHE.pop(cache_key, None)
+            except SQLITE_ERRORS as sql_e:
+                self.tags_cache_error(sql_e)
+                # If fallback to dict, try pop again
+                if isinstance(self.TAGS_CACHE, dict):
+                    self.TAGS_CACHE.pop(cache_key, None)
+            return []  # Return empty list if file can't be accessed
+
+        # 2. Calculate MD5
+        current_md5 = self._calculate_md5(content_bytes)
+
+        # 3. Check cache
+        cache_key = fname
+        cached_data = None
+        try:
+            # Use get() with a default to avoid KeyError if key missing
+            cached_data = self.TAGS_CACHE.get(cache_key, None)
+        except SQLITE_ERRORS as e:
+            self.tags_cache_error(e)
+            # Try get() again after potential recovery
+            if isinstance(self.TAGS_CACHE, dict):
+                cached_data = self.TAGS_CACHE.get(cache_key, None)
+
+        # 4. Compare mtime and MD5
+        cache_valid = False
+        log_message = None
+        if cached_data:
+            cached_mtime = cached_data.get("mtime")
+            cached_md5 = cached_data.get("md5")
+            if cached_mtime == file_mtime:
+                if cached_md5 == current_md5:
+                    # mtime and MD5 match, cache is valid
+                    cache_valid = True
+                else:
+                    # MD5 mismatch! Content changed despite same mtime.
+                    log_message = (
+                        f"RepoMap: Content change detected for {rel_fname} (MD5 mismatch,"
+                        f" mtime {file_mtime}). Re-parsing."
+                    )
+            else:
+                # mtime mismatch - file definitely changed or cache is stale
+                log_message = (
+                    f"RepoMap: File change detected for {rel_fname} (mtime mismatch: cached"
+                    f" {cached_mtime}, current {file_mtime}). Re-parsing."
+                )
+
+        # 5. Return cached data or re-parse
+        if cache_valid:
+            try:
+                # Ensure data exists in the valid cache entry
+                return cached_data.get("data", [])
+            except Exception as e:
+                # Handle potential issues reading cached data, force re-parse
+                self.io.tool_warning(f"RepoMap: Error reading cached data for {fname}: {e}")
+                cache_valid = False  # Force re-parse
+
+        # Cache is invalid or file changed - log if needed and re-parse
+        if log_message:
+            self.io.tool_warning(log_message)
+
+        # Call the raw tag parsing function (passing the already read code)
+        tags_data = list(self.get_tags_raw(fname, rel_fname, code))  # Pass code here
+
+        # Update the cache with new mtime, md5, and data
+        try:
+            self.TAGS_CACHE[cache_key] = {
+                "mtime": file_mtime,
+                "md5": current_md5,
+                "data": tags_data,
+            }
+            # self.save_tags_cache() # save_tags_cache might be a no-op now with diskcache
+        except SQLITE_ERRORS as e:
+            self.tags_cache_error(e)
+            # Try saving to in-memory dict if disk cache failed
+            if isinstance(self.TAGS_CACHE, dict):
+                self.TAGS_CACHE[cache_key] = {
+                    "mtime": file_mtime,
+                    "md5": current_md5,
+                    "data": tags_data,
+                }
+
+        return tags_data
 
-    def get_tags_raw(self, fname, rel_fname):
+    def get_tags_raw(self, fname, rel_fname, code):
         lang = filename_to_lang(fname)
         if not lang:
             return
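The control flow above reduces to a two-tier check: the cache is trusted only when both mtime and MD5 match, so a rewrite that preserves the timestamp (editors and build tools can do this via os.utime) is still caught, at the cost of reading and hashing the file on every call. A condensed sketch of the same idea against a plain dict cache; get_cached and parse_fn are illustrative names, not part of the commit:

import hashlib
import os
from pathlib import Path

def get_cached(cache: dict, fname: str, parse_fn):
    mtime = os.path.getmtime(fname)
    content = Path(fname).read_bytes()
    md5 = hashlib.md5(content).hexdigest()

    entry = cache.get(fname)
    if entry and entry["mtime"] == mtime and entry["md5"] == md5:
        return entry["data"]  # cache hit: both checks passed

    data = parse_fn(content)  # cache miss: re-parse
    cache[fname] = {"mtime": mtime, "md5": md5, "data": data}
    return data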
@@ -271,22 +343,40 @@ class RepoMap:
             language = get_language(lang)
             parser = get_parser(lang)
         except Exception as err:
-            print(f"Skipping file {fname}: {err}")
+            # Use io object for output
+            self.io.tool_warning(f"Skipping file {fname} for tags: {err}")
             return
 
-        query_scm = get_scm_fname(lang)
-        if not query_scm.exists():
+        query_scm_path = get_scm_fname(lang)  # Renamed variable for clarity
+        if not query_scm_path or not query_scm_path.exists():
+            # self.io.tool_warning(f"No tags query file found for language: {lang}") # Optional: more verbose logging
+            return
+        try:
+            query_scm = query_scm_path.read_text()
+        except Exception as e:
+            self.io.tool_warning(f"Error reading tags query file {query_scm_path}: {e}")
             return
-        query_scm = query_scm.read_text()
 
-        code = self.io.read_text(fname)
-        if not code:
+        # code = self.io.read_text(fname) # <-- REMOVE THIS LINE
+        if not code:  # Check the passed code content
             return
-        tree = parser.parse(bytes(code, "utf-8"))
+        try:
+            tree = parser.parse(bytes(code, "utf-8"))
+        except Exception as e:
+            self.io.tool_warning(f"Error parsing {fname} with tree-sitter: {e}")
+            return
 
         # Run the tags queries
-        query = language.query(query_scm)
-        captures = query.captures(tree.root_node)
+        try:
+            query = language.query(query_scm)
+            captures = query.captures(tree.root_node)
+        except Exception as e:
+            # Catch errors during query execution as well
+            self.io.tool_warning(f"Error running tags query on {fname}: {e}")
+            return
 
         saw = set()
         if USING_TSL_PACK:
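The hardening above wraps three distinct failure points separately: grammar lookup, parsing, and query execution. A sketch of the happy path, assuming the tree-sitter language pack and grep_ast APIs that this module already imports; note that query.captures() returns (node, name) tuples in older py-tree-sitter releases and a dict of name to nodes in newer ones, which is presumably why the surrounding code branches on USING_TSL_PACK:

from grep_ast import filename_to_lang
from tree_sitter_language_pack import get_language, get_parser

code = "def func_a():\n    pass\n"
lang = filename_to_lang("example.py")  # "python"
language = get_language(lang)
parser = get_parser(lang)

tree = parser.parse(bytes(code, "utf-8"))
query = language.query("(function_definition name: (identifier) @name)")
captures = query.captures(tree.root_node)  # shape varies by tree-sitter version
print(captures)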

aider/resources/model-metadata.json

@@ -204,6 +204,42 @@
     "supports_tool_choice": true,
     "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
   },
+  "gemini/gemini-2.5-pro-preview-05-06": {
+    "max_tokens": 8192,
+    "max_input_tokens": 1048576,
+    "max_output_tokens": 64000,
+    "max_images_per_prompt": 3000,
+    "max_videos_per_prompt": 10,
+    "max_video_length": 1,
+    "max_audio_length_hours": 8.4,
+    "max_audio_per_prompt": 1,
+    "max_pdf_size_mb": 30,
+    "input_cost_per_image": 0,
+    "input_cost_per_video_per_second": 0,
+    "input_cost_per_audio_per_second": 0,
+    "input_cost_per_token": 0.00000125,
+    "input_cost_per_character": 0,
+    "input_cost_per_token_above_128k_tokens": 0,
+    "input_cost_per_character_above_128k_tokens": 0,
+    "input_cost_per_image_above_128k_tokens": 0,
+    "input_cost_per_video_per_second_above_128k_tokens": 0,
+    "input_cost_per_audio_per_second_above_128k_tokens": 0,
+    "output_cost_per_token": 0.000010,
+    "output_cost_per_character": 0,
+    "output_cost_per_token_above_128k_tokens": 0,
+    "output_cost_per_character_above_128k_tokens": 0,
+    "litellm_provider": "gemini",
+    "mode": "chat",
+    "supports_system_messages": true,
+    "supports_function_calling": true,
+    "supports_vision": true,
+    "supports_audio_input": true,
+    "supports_video_input": true,
+    "supports_pdf_input": true,
+    "supports_response_schema": true,
+    "supports_tool_choice": true,
+    "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
+  },
   "vertex_ai/gemini-2.5-pro-exp-03-25": {
     "max_tokens": 8192,
     "max_input_tokens": 1048576,

aider/resources/model-settings.yml

@@ -963,6 +963,12 @@
   use_repo_map: true
   weak_model_name: gemini/gemini-2.0-flash
 
+- name: gemini/gemini-2.5-pro-preview-05-06
+  overeager: true
+  edit_format: diff-fenced
+  use_repo_map: true
+  weak_model_name: gemini/gemini-2.0-flash
+
 - name: gemini/gemini-2.5-pro-exp-03-25
   edit_format: diff-fenced
   use_repo_map: true
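These YAML entries carry per-model defaults: the edit format the model should emit, whether to build a repo map, and the weaker model used for cheap auxiliary tasks like commit messages. A sketch of how such a list can be loaded and matched by model name, assuming a local copy of the settings file; PyYAML is the only dependency:

import yaml

settings = yaml.safe_load(open("model-settings.yml"))  # a list of dicts
entry = next(s for s in settings
             if s["name"] == "gemini/gemini-2.5-pro-preview-05-06")
print(entry["edit_format"])  # diff-fenced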

tests/basic/test_repomap.py

@@ -4,6 +4,7 @@ import re
 import time
 import unittest
 from pathlib import Path
+from unittest.mock import patch
 
 import git
@@ -18,7 +19,167 @@ class TestRepoMap(unittest.TestCase):
     def setUp(self):
         self.GPT35 = Model("gpt-3.5-turbo")
 
+    # Helper function to calculate MD5 hash of a file
+    def _calculate_md5_for_file(self, file_path):
+        import hashlib
+
+        hasher = hashlib.md5()
+        with open(file_path, 'rb') as f:
+            while True:
+                chunk = f.read(8192)
+                if not chunk:
+                    break
+                hasher.update(chunk)
+        return hasher.hexdigest()
+
     def test_get_repo_map(self):
+        pass
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_md5_change_same_mtime(self, mock_tool_warning):
+        """Verify MD5 detection when mtime is unchanged."""
+        with GitTemporaryDirectory() as temp_dir:
+            # Create a test file
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            # Initialize RepoMap and populate cache
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            self.assertTrue(any(tag.name == "func_a" for tag in initial_tags))
+            initial_mtime = os.path.getmtime(abs_path)
+
+            # Modify content, reset mtime
+            new_content = "def func_b(): pass\n"
+            test_file.write_text(new_content)
+            os.utime(abs_path, (initial_mtime, initial_mtime))  # Reset mtime
+
+            # Call get_tags again
+            new_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called_once()
+            self.assertIn("MD5 mismatch", mock_tool_warning.call_args[0][0])
+            self.assertTrue(any(tag.name == "func_b" for tag in new_tags))
+            self.assertFalse(any(tag.name == "func_a" for tag in new_tags))
+
+            # Check cache update
+            cached_data = repo_map.TAGS_CACHE.get(abs_path)
+            self.assertIsNotNone(cached_data)
+            expected_md5 = self._calculate_md5_for_file(abs_path)
+            self.assertEqual(cached_data.get("md5"), expected_md5)
+            self.assertEqual(cached_data.get("mtime"), initial_mtime)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    @patch("aider.repomap.RepoMap.get_tags_raw")
+    def test_get_tags_no_change(self, mock_get_tags_raw, mock_tool_warning):
+        """Verify cache is used when file is unchanged."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Initial call to populate cache
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            mock_get_tags_raw.assert_called_once()  # Called once initially
+            mock_get_tags_raw.reset_mock()  # Reset for the next check
+
+            # Call get_tags again without changes
+            second_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_not_called()
+            mock_get_tags_raw.assert_not_called()  # Should not be called again
+            self.assertEqual(initial_tags, second_tags)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_mtime_change(self, mock_tool_warning):
+        """Verify standard mtime-based change detection still works."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            initial_content = "def func_a(): pass\n"
+            test_file.write_text(initial_content)
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Initial call
+            initial_tags = repo_map.get_tags(abs_path, rel_path)
+            self.assertTrue(any(tag.name == "func_a" for tag in initial_tags))
+
+            # Modify content (mtime will change naturally)
+            time.sleep(0.01)  # Ensure mtime is different
+            new_content = "def func_b(): pass\n"
+            test_file.write_text(new_content)
+            new_mtime = os.path.getmtime(abs_path)
+
+            # Call get_tags again
+            new_tags = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called_once()
+            self.assertIn("mtime mismatch", mock_tool_warning.call_args[0][0])
+            self.assertTrue(any(tag.name == "func_b" for tag in new_tags))
+            self.assertFalse(any(tag.name == "func_a" for tag in new_tags))
+
+            # Check cache update
+            cached_data = repo_map.TAGS_CACHE.get(abs_path)
+            self.assertIsNotNone(cached_data)
+            expected_md5 = self._calculate_md5_for_file(abs_path)
+            self.assertEqual(cached_data.get("md5"), expected_md5)
+            self.assertEqual(cached_data.get("mtime"), new_mtime)
+
+            del repo_map  # Close cache
+
+    @patch("aider.io.InputOutput.tool_warning")
+    def test_get_tags_file_not_found_after_cache(self, mock_tool_warning):
+        """Verify graceful handling if a cached file becomes inaccessible."""
+        with GitTemporaryDirectory() as temp_dir:
+            test_file = Path(temp_dir) / "test.py"
+            test_file.write_text("def func_a(): pass\n")
+            abs_path = str(test_file.resolve())
+            rel_path = "test.py"
+
+            io = InputOutput()
+            repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
+
+            # Populate cache
+            repo_map.get_tags(abs_path, rel_path)
+            self.assertIn(abs_path, repo_map.TAGS_CACHE)
+
+            # Delete the file
+            os.remove(abs_path)
+
+            # Call get_tags again
+            result = repo_map.get_tags(abs_path, rel_path)
+
+            # Assertions
+            mock_tool_warning.assert_called()
+            # Check if any call contains "Error accessing file" or "FileNotFoundError"
+            warning_found = any(
+                "Error accessing file" in call[0][0] or "FileNotFoundError" in call[0][0]
+                for call in mock_tool_warning.call_args_list
+            )
+            self.assertTrue(warning_found, "Expected file access error warning not found")
+            self.assertEqual(result, [])
+            self.assertNotIn(abs_path, repo_map.TAGS_CACHE)
+
+            del repo_map  # Close cache
+
         # Create a temporary directory with sample files for testing
         test_files = [
             "test_file1.py",