From 285536105ebb5df5b581568848241f8cdb02ebaa Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Fri, 18 Aug 2023 10:07:47 -0700 Subject: [PATCH] Properly handle all diff cases --- aider/commands.py | 2 +- aider/repo.py | 49 +++++++++++++++++++--------------------------- tests/test_repo.py | 6 +++--- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/aider/commands.py b/aider/commands.py index af0fa332e..6047e9438 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -230,7 +230,7 @@ class Commands: return commits = f"{self.coder.last_aider_commit_hash}~1" - diff = self.coder.repo.get_diffs( + diff = self.coder.repo.diff_commits( self.coder.pretty, commits, self.coder.last_aider_commit_hash, diff --git a/aider/repo.py b/aider/repo.py index 4e6f73e4b..92cfb1cb2 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -56,12 +56,7 @@ class GitRepo: if message: commit_message = message else: - diff_args = [] - if fnames: - diff_args += ["--"] + list(fnames) - dump(diff_args) - diffs = self.get_diffs(False, *diff_args) - dump(diffs) + diffs = self.get_diffs(fnames) commit_message = self.get_commit_message(diffs, context) if not commit_message: @@ -129,42 +124,38 @@ class GitRepo: return commit_message - def get_diffs(self, pretty, *args): - args = list(args) - - # if args are specified, just add --pretty if needed - if args: - if pretty: - args = ["--color"] + args - dump(args) - return self.repo.git.diff(*args) - - # otherwise, we always want diffs of index and working dir - + def get_diffs(self, fnames=None): + # We always want diffs of index and working dir try: commits = self.repo.iter_commits(self.repo.active_branch) current_branch_has_commits = any(commits) except git.exc.GitCommandError: current_branch_has_commits = False - if pretty: - args = ["--color"] + if not fnames: + fnames = [] if current_branch_has_commits: - # if there is a HEAD, just diff against it to pick up index + working - args += ["HEAD"] + args = ["HEAD", "--"] + list(fnames) return self.repo.git.diff(*args) - # diffs in the index - diffs = self.repo.git.diff(*(args + ["--cached"])) - # plus, diffs in the working dir - diffs += self.repo.git.diff(*args) + wd_args = ["--"] + list(fnames) + index_args = ["--cached"] + wd_args + + diffs = self.repo.git.diff(*index_args) + diffs += self.repo.git.diff(*wd_args) return diffs - def show_diffs(self, pretty): - diffs = self.get_diffs(pretty) - print(diffs) + def diff_commits(self, pretty, from_commit, to_commit): + args = [] + if pretty: + args += ["--color"] + + args += [from_commit, to_commit] + diffs = self.repo.git.diff(*args) + + return diffs def get_tracked_files(self): if not self.repo: diff --git a/tests/test_repo.py b/tests/test_repo.py index aee78abdf..77092c2da 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -26,7 +26,7 @@ class TestRepo(unittest.TestCase): fname.write_text("workingdir\n") git_repo = GitRepo(InputOutput(), None, ".") - diffs = git_repo.get_diffs(False) + diffs = git_repo.get_diffs() self.assertIn("index", diffs) self.assertIn("workingdir", diffs) @@ -49,7 +49,7 @@ class TestRepo(unittest.TestCase): fname2.write_text("workingdir\n") git_repo = GitRepo(InputOutput(), None, ".") - diffs = git_repo.get_diffs(False) + diffs = git_repo.get_diffs() self.assertIn("index", diffs) self.assertIn("workingdir", diffs) @@ -67,7 +67,7 @@ class TestRepo(unittest.TestCase): repo.git.commit("-m", "second") git_repo = GitRepo(InputOutput(), None, ".") - diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"]) + diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD") dump(diffs) self.assertIn("two", diffs)