refactor: simplify error handling in get_tracked_files method

This commit is contained in:
Paul Gauthier 2024-09-02 16:03:54 -07:00 committed by Paul Gauthier (aider)
parent 387df7f1db
commit c8b2024f8b
2 changed files with 21 additions and 29 deletions

View file

@ -17,7 +17,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info
from aider.history import ChatSummary from aider.history import ChatSummary
from aider.io import InputOutput from aider.io import InputOutput
from aider.llm import litellm # noqa: F401; properly init litellm on launch from aider.llm import litellm # noqa: F401; properly init litellm on launch
from aider.repo import ANY_GIT_ERROR, GitRepo, UnableToCountRepoFiles from aider.repo import ANY_GIT_ERROR, GitRepo
from aider.report import report_uncaught_exceptions from aider.report import report_uncaught_exceptions
from aider.versioncheck import check_version, install_from_main_branch, install_upgrade from aider.versioncheck import check_version, install_from_main_branch, install_upgrade
@ -301,7 +301,7 @@ def sanity_check_repo(repo, io):
try: try:
repo.get_tracked_files() repo.get_tracked_files()
return True return True
except UnableToCountRepoFiles as exc: except ANY_GIT_ERROR as exc:
error_msg = str(exc) error_msg = str(exc)
if "version in (1, 2)" in error_msg: if "version in (1, 2)" in error_msg:

View file

@ -10,11 +10,6 @@ from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
class UnableToCountRepoFiles(Exception):
pass
ANY_GIT_ERROR = (git.exc.ODBError, git.exc.GitError) ANY_GIT_ERROR = (git.exc.ODBError, git.exc.GitError)
@ -259,32 +254,29 @@ class GitRepo:
return [] return []
try: try:
try: commit = self.repo.head.commit
commit = self.repo.head.commit except ValueError:
except ValueError: commit = None
commit = None
files = set() files = set()
if commit: if commit:
if commit in self.tree_files: if commit in self.tree_files:
files = self.tree_files[commit] files = self.tree_files[commit]
else: else:
for blob in commit.tree.traverse(): for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file if blob.type == "blob": # blob is a file
files.add(blob.path) files.add(blob.path)
files = set(self.normalize_path(path) for path in files) files = set(self.normalize_path(path) for path in files)
self.tree_files[commit] = set(files) self.tree_files[commit] = set(files)
# Add staged files # Add staged files
index = self.repo.index index = self.repo.index
staged_files = [path for path, _ in index.entries.keys()] staged_files = [path for path, _ in index.entries.keys()]
files.update(self.normalize_path(path) for path in staged_files) files.update(self.normalize_path(path) for path in staged_files)
res = [fname for fname in files if not self.ignored_file(fname)] res = [fname for fname in files if not self.ignored_file(fname)]
return res return res
except Exception as e:
raise UnableToCountRepoFiles(f"Error getting tracked files: {str(e)}")
def normalize_path(self, path): def normalize_path(self, path):
orig_path = path orig_path = path