From c8b2024f8ba8967a0f22e8229cd70e7ba37e30b1 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Mon, 2 Sep 2024 16:03:54 -0700 Subject: [PATCH] refactor: simplify error handling in get_tracked_files method --- aider/main.py | 4 ++-- aider/repo.py | 46 +++++++++++++++++++--------------------------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/aider/main.py b/aider/main.py index 9e9378323..fa02e28ed 100644 --- a/aider/main.py +++ b/aider/main.py @@ -17,7 +17,7 @@ from aider.format_settings import format_settings, scrub_sensitive_info from aider.history import ChatSummary from aider.io import InputOutput from aider.llm import litellm # noqa: F401; properly init litellm on launch -from aider.repo import ANY_GIT_ERROR, GitRepo, UnableToCountRepoFiles +from aider.repo import ANY_GIT_ERROR, GitRepo from aider.report import report_uncaught_exceptions from aider.versioncheck import check_version, install_from_main_branch, install_upgrade @@ -301,7 +301,7 @@ def sanity_check_repo(repo, io): try: repo.get_tracked_files() return True - except UnableToCountRepoFiles as exc: + except ANY_GIT_ERROR as exc: error_msg = str(exc) if "version in (1, 2)" in error_msg: diff --git a/aider/repo.py b/aider/repo.py index fc3548e4f..96f0639f1 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -10,11 +10,6 @@ from aider.sendchat import simple_send_with_retries from .dump import dump # noqa: F401 - -class UnableToCountRepoFiles(Exception): - pass - - ANY_GIT_ERROR = (git.exc.ODBError, git.exc.GitError) @@ -259,32 +254,29 @@ class GitRepo: return [] try: - try: - commit = self.repo.head.commit - except ValueError: - commit = None + commit = self.repo.head.commit + except ValueError: + commit = None - files = set() - if commit: - if commit in self.tree_files: - files = self.tree_files[commit] - else: - for blob in commit.tree.traverse(): - if blob.type == "blob": # blob is a file - files.add(blob.path) - files = set(self.normalize_path(path) for path in files) - self.tree_files[commit] = set(files) + files = set() + if commit: + if commit in self.tree_files: + files = self.tree_files[commit] + else: + for blob in commit.tree.traverse(): + if blob.type == "blob": # blob is a file + files.add(blob.path) + files = set(self.normalize_path(path) for path in files) + self.tree_files[commit] = set(files) - # Add staged files - index = self.repo.index - staged_files = [path for path, _ in index.entries.keys()] - files.update(self.normalize_path(path) for path in staged_files) + # Add staged files + index = self.repo.index + staged_files = [path for path, _ in index.entries.keys()] + files.update(self.normalize_path(path) for path in staged_files) - res = [fname for fname in files if not self.ignored_file(fname)] + res = [fname for fname in files if not self.ignored_file(fname)] - return res - except Exception as e: - raise UnableToCountRepoFiles(f"Error getting tracked files: {str(e)}") + return res def normalize_path(self, path): orig_path = path