refactor: simplify error handling in get_tracked_files method

This commit is contained in:
Paul Gauthier 2024-09-02 16:03:54 -07:00 committed by Paul Gauthier (aider)
parent 387df7f1db
commit c8b2024f8b
2 changed files with 21 additions and 29 deletions

View file

@ -17,7 +17,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info
from aider.history import ChatSummary
from aider.io import InputOutput
from aider.llm import litellm # noqa: F401; properly init litellm on launch
from aider.repo import ANY_GIT_ERROR, GitRepo, UnableToCountRepoFiles
from aider.repo import ANY_GIT_ERROR, GitRepo
from aider.report import report_uncaught_exceptions
from aider.versioncheck import check_version, install_from_main_branch, install_upgrade
@ -301,7 +301,7 @@ def sanity_check_repo(repo, io):
try:
repo.get_tracked_files()
return True
except UnableToCountRepoFiles as exc:
except ANY_GIT_ERROR as exc:
error_msg = str(exc)
if "version in (1, 2)" in error_msg:

View file

@ -10,11 +10,6 @@ from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401
class UnableToCountRepoFiles(Exception):
pass
ANY_GIT_ERROR = (git.exc.ODBError, git.exc.GitError)
@ -259,32 +254,29 @@ class GitRepo:
return []
try:
try:
commit = self.repo.head.commit
except ValueError:
commit = None
commit = self.repo.head.commit
except ValueError:
commit = None
files = set()
if commit:
if commit in self.tree_files:
files = self.tree_files[commit]
else:
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.add(blob.path)
files = set(self.normalize_path(path) for path in files)
self.tree_files[commit] = set(files)
files = set()
if commit:
if commit in self.tree_files:
files = self.tree_files[commit]
else:
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.add(blob.path)
files = set(self.normalize_path(path) for path in files)
self.tree_files[commit] = set(files)
# Add staged files
index = self.repo.index
staged_files = [path for path, _ in index.entries.keys()]
files.update(self.normalize_path(path) for path in staged_files)
# Add staged files
index = self.repo.index
staged_files = [path for path, _ in index.entries.keys()]
files.update(self.normalize_path(path) for path in staged_files)
res = [fname for fname in files if not self.ignored_file(fname)]
res = [fname for fname in files if not self.ignored_file(fname)]
return res
except Exception as e:
raise UnableToCountRepoFiles(f"Error getting tracked files: {str(e)}")
return res
def normalize_path(self, path):
orig_path = path