Properly handle all diff cases

This commit is contained in:
Paul Gauthier 2023-08-18 10:07:47 -07:00
parent 7af82ba185
commit 285536105e
3 changed files with 24 additions and 33 deletions

View file

@ -230,7 +230,7 @@ class Commands:
return
commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.repo.get_diffs(
diff = self.coder.repo.diff_commits(
self.coder.pretty,
commits,
self.coder.last_aider_commit_hash,

View file

@ -56,12 +56,7 @@ class GitRepo:
if message:
commit_message = message
else:
diff_args = []
if fnames:
diff_args += ["--"] + list(fnames)
dump(diff_args)
diffs = self.get_diffs(False, *diff_args)
dump(diffs)
diffs = self.get_diffs(fnames)
commit_message = self.get_commit_message(diffs, context)
if not commit_message:
@ -129,42 +124,38 @@ class GitRepo:
return commit_message
def get_diffs(self, pretty, *args):
args = list(args)
# if args are specified, just add --pretty if needed
if args:
if pretty:
args = ["--color"] + args
dump(args)
return self.repo.git.diff(*args)
# otherwise, we always want diffs of index and working dir
def get_diffs(self, fnames=None):
# We always want diffs of index and working dir
try:
commits = self.repo.iter_commits(self.repo.active_branch)
current_branch_has_commits = any(commits)
except git.exc.GitCommandError:
current_branch_has_commits = False
if pretty:
args = ["--color"]
if not fnames:
fnames = []
if current_branch_has_commits:
# if there is a HEAD, just diff against it to pick up index + working
args += ["HEAD"]
args = ["HEAD", "--"] + list(fnames)
return self.repo.git.diff(*args)
# diffs in the index
diffs = self.repo.git.diff(*(args + ["--cached"]))
# plus, diffs in the working dir
diffs += self.repo.git.diff(*args)
wd_args = ["--"] + list(fnames)
index_args = ["--cached"] + wd_args
diffs = self.repo.git.diff(*index_args)
diffs += self.repo.git.diff(*wd_args)
return diffs
def show_diffs(self, pretty):
diffs = self.get_diffs(pretty)
print(diffs)
def diff_commits(self, pretty, from_commit, to_commit):
args = []
if pretty:
args += ["--color"]
args += [from_commit, to_commit]
diffs = self.repo.git.diff(*args)
return diffs
def get_tracked_files(self):
if not self.repo:

View file

@ -26,7 +26,7 @@ class TestRepo(unittest.TestCase):
fname.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False)
diffs = git_repo.get_diffs()
self.assertIn("index", diffs)
self.assertIn("workingdir", diffs)
@ -49,7 +49,7 @@ class TestRepo(unittest.TestCase):
fname2.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False)
diffs = git_repo.get_diffs()
self.assertIn("index", diffs)
self.assertIn("workingdir", diffs)
@ -67,7 +67,7 @@ class TestRepo(unittest.TestCase):
repo.git.commit("-m", "second")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"])
diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD")
dump(diffs)
self.assertIn("two", diffs)