use safe repo.get_head methods

This commit is contained in:
Paul Gauthier 2024-08-31 07:29:51 -07:00
parent 8678a6455f
commit d2acb9c3b0
3 changed files with 28 additions and 12 deletions

View file

@ -713,7 +713,7 @@ class Coder:
self.shell_commands = [] self.shell_commands = []
if self.repo: if self.repo:
self.commit_before_message.append(self.repo.get_head()) self.commit_before_message.append(self.repo.get_head_sha())
def run(self, with_message=None, preproc=True): def run(self, with_message=None, preproc=True):
try: try:
@ -1867,7 +1867,7 @@ class Coder:
def show_undo_hint(self): def show_undo_hint(self):
if not self.commit_before_message: if not self.commit_before_message:
return return
if self.commit_before_message[-1] != self.repo.get_head(): if self.commit_before_message[-1] != self.repo.get_head_sha():
self.io.tool_output("You can use /undo to undo and discard each aider commit.") self.io.tool_output("You can use /undo to undo and discard each aider commit.")
def dirty_commit(self): def dirty_commit(self):

View file

@ -420,8 +420,8 @@ class Commands:
self.io.tool_error("No git repository found.") self.io.tool_error("No git repository found.")
return return
last_commit = self.coder.repo.repo.head.commit last_commit = self.coder.repo.get_head()
if not last_commit.parents: if last_commit and not last_commit.parents:
self.io.tool_error("This is the first commit in the repository. Cannot undo.") self.io.tool_error("This is the first commit in the repository. Cannot undo.")
return return
@ -461,8 +461,9 @@ class Commands:
) )
return return
last_commit_hash = self.coder.repo.repo.head.commit.hexsha[:7] last_commit_hash = self.coder.repo.get_head_sha(short=True)
last_commit_message = self.coder.repo.repo.head.commit.message.strip() last_commit_message = self.coder.repo.get_head_message("(unknown)")
if last_commit_hash not in self.coder.aider_commit_hashes: if last_commit_hash not in self.coder.aider_commit_hashes:
self.io.tool_error("The last commit was not made by aider in this chat session.") self.io.tool_error("The last commit was not made by aider in this chat session.")
self.io.tool_error( self.io.tool_error(
@ -481,8 +482,8 @@ class Commands:
self.io.tool_output(f"Removed: {last_commit_hash} {last_commit_message}") self.io.tool_output(f"Removed: {last_commit_hash} {last_commit_message}")
# Get the current HEAD after undo # Get the current HEAD after undo
current_head_hash = self.coder.repo.repo.head.commit.hexsha[:7] current_head_hash = self.coder.repo.get_head_sha(short=True)
current_head_message = self.coder.repo.repo.head.commit.message.strip() current_head_message = self.coder.repo.get_head_message("(unknown)")
self.io.tool_output(f"Now at: {current_head_hash} {current_head_message}") self.io.tool_output(f"Now at: {current_head_hash} {current_head_message}")
if self.coder.main_model.send_undo_reply: if self.coder.main_model.send_undo_reply:
@ -494,7 +495,7 @@ class Commands:
self.io.tool_error("No git repository found.") self.io.tool_error("No git repository found.")
return return
current_head = self.coder.repo.get_head() current_head = self.coder.repo.get_head_sha()
if current_head is None: if current_head is None:
self.io.tool_error("Unable to get current commit. The repository might be empty.") self.io.tool_error("Unable to get current commit. The repository might be empty.")
return return

View file

@ -3,6 +3,7 @@ import time
from pathlib import Path, PurePosixPath from pathlib import Path, PurePosixPath
import git import git
import gitdb
import pathspec import pathspec
from aider import prompts, utils from aider import prompts, utils
@ -137,7 +138,7 @@ class GitRepo:
os.environ["GIT_AUTHOR_NAME"] = committer_name os.environ["GIT_AUTHOR_NAME"] = committer_name
self.repo.git.commit(cmd) self.repo.git.commit(cmd)
commit_hash = self.repo.head.commit.hexsha[:7] commit_hash = self.get_head_sha(short=True)
self.io.tool_output(f"Commit {commit_hash} {commit_message}", bold=True) self.io.tool_output(f"Commit {commit_hash} {commit_message}", bold=True)
# Restore the env # Restore the env
@ -374,6 +375,20 @@ class GitRepo:
def get_head(self): def get_head(self):
try: try:
return self.repo.head.commit.hexsha return self.repo.head.commit
except ValueError: except (ValueError, gitdb.exc.ODBError):
return None return None
def get_head_sha(self, short=False):
commit = self.get_head()
if not commit:
return
if short:
return commit.hexsha[:7]
return commit.hexsha
def get_head_message(self, default=None):
commit = self.get_head()
if not commit:
return default
return commit.message