Properly handle all diff cases

This commit is contained in:
Paul Gauthier 2023-08-18 10:07:47 -07:00
parent 7af82ba185
commit 285536105e
3 changed files with 24 additions and 33 deletions

View file

@ -230,7 +230,7 @@ class Commands:
return return
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.repo.get_diffs( diff = self.coder.repo.diff_commits(
self.coder.pretty, self.coder.pretty,
commits, commits,
self.coder.last_aider_commit_hash, self.coder.last_aider_commit_hash,

View file

@ -56,12 +56,7 @@ class GitRepo:
if message: if message:
commit_message = message commit_message = message
else: else:
diff_args = [] diffs = self.get_diffs(fnames)
if fnames:
diff_args += ["--"] + list(fnames)
dump(diff_args)
diffs = self.get_diffs(False, *diff_args)
dump(diffs)
commit_message = self.get_commit_message(diffs, context) commit_message = self.get_commit_message(diffs, context)
if not commit_message: if not commit_message:
@ -129,42 +124,38 @@ class GitRepo:
return commit_message return commit_message
def get_diffs(self, pretty, *args): def get_diffs(self, fnames=None):
args = list(args) # We always want diffs of index and working dir
# if args are specified, just add --pretty if needed
if args:
if pretty:
args = ["--color"] + args
dump(args)
return self.repo.git.diff(*args)
# otherwise, we always want diffs of index and working dir
try: try:
commits = self.repo.iter_commits(self.repo.active_branch) commits = self.repo.iter_commits(self.repo.active_branch)
current_branch_has_commits = any(commits) current_branch_has_commits = any(commits)
except git.exc.GitCommandError: except git.exc.GitCommandError:
current_branch_has_commits = False current_branch_has_commits = False
if pretty: if not fnames:
args = ["--color"] fnames = []
if current_branch_has_commits: if current_branch_has_commits:
# if there is a HEAD, just diff against it to pick up index + working args = ["HEAD", "--"] + list(fnames)
args += ["HEAD"]
return self.repo.git.diff(*args) return self.repo.git.diff(*args)
# diffs in the index wd_args = ["--"] + list(fnames)
diffs = self.repo.git.diff(*(args + ["--cached"])) index_args = ["--cached"] + wd_args
# plus, diffs in the working dir
diffs += self.repo.git.diff(*args) diffs = self.repo.git.diff(*index_args)
diffs += self.repo.git.diff(*wd_args)
return diffs return diffs
def show_diffs(self, pretty): def diff_commits(self, pretty, from_commit, to_commit):
diffs = self.get_diffs(pretty) args = []
print(diffs) if pretty:
args += ["--color"]
args += [from_commit, to_commit]
diffs = self.repo.git.diff(*args)
return diffs
def get_tracked_files(self): def get_tracked_files(self):
if not self.repo: if not self.repo:

View file

@ -26,7 +26,7 @@ class TestRepo(unittest.TestCase):
fname.write_text("workingdir\n") fname.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False) diffs = git_repo.get_diffs()
self.assertIn("index", diffs) self.assertIn("index", diffs)
self.assertIn("workingdir", diffs) self.assertIn("workingdir", diffs)
@ -49,7 +49,7 @@ class TestRepo(unittest.TestCase):
fname2.write_text("workingdir\n") fname2.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False) diffs = git_repo.get_diffs()
self.assertIn("index", diffs) self.assertIn("index", diffs)
self.assertIn("workingdir", diffs) self.assertIn("workingdir", diffs)
@ -67,7 +67,7 @@ class TestRepo(unittest.TestCase):
repo.git.commit("-m", "second") repo.git.commit("-m", "second")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"]) diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD")
dump(diffs) dump(diffs)
self.assertIn("two", diffs) self.assertIn("two", diffs)