diff --git a/aider/repo.py b/aider/repo.py index 28b3c700d..040f8156b 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -126,13 +126,35 @@ class GitRepo: return commit_message def get_diffs(self, pretty, *args): - # we always want diffs of working-dir + index versus repo - args = ["--cached"] + list(args) + args = list(args) + + # if args are specified, just add --pretty if needed + if args: + if pretty: + args = ["--color"] + args + return self.repo.git.diff(*args) + + # otherwise, we always want diffs of index and working dir + + try: + commits = self.repo.iter_commits(self.repo.active_branch) + current_branch_has_commits = any(commits) + except git.exc.GitCommandError: + current_branch_has_commits = False if pretty: - args = ["--color"] + list(args) + args = ["--color"] + + if current_branch_has_commits: + # if there is a HEAD, just diff against it to pick up index + working + args += ["HEAD"] + return self.repo.git.diff(*args) + + # diffs in the index + diffs = self.repo.git.diff(*(args + ["--cached"])) + # plus, diffs in the working dir + diffs += self.repo.git.diff(*args) - diffs = self.repo.git.diff(*args) return diffs def show_diffs(self, pretty): diff --git a/tests/test_repo.py b/tests/test_repo.py index df2fdceaa..8451ec309 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -16,15 +16,60 @@ class TestRepo(unittest.TestCase): def test_diffs_empty_repo(self): with GitTemporaryDirectory(): repo = git.Repo() - fname = Path("foo.txt") - fname.touch() + # Add a change to the index + fname = Path("foo.txt") + fname.write_text("index\n") repo.git.add(str(fname)) + # Make a change in the working dir + fname.write_text("workingdir\n") + git_repo = GitRepo(InputOutput(), None, ".") diffs = git_repo.get_diffs(False) - self.assertNotEqual(diffs, "") - self.assertIsNotNone(diffs) + self.assertIn("index", diffs) + self.assertIn("workingdir", diffs) + + def test_diffs_nonempty_repo(self): + with GitTemporaryDirectory(): + repo = git.Repo() + fname = Path("foo.txt") + fname.touch() + repo.git.add(str(fname)) + + fname2 = Path("bar.txt") + fname2.touch() + repo.git.add(str(fname2)) + + repo.git.commit("-m", "initial") + + fname.write_text("index\n") + repo.git.add(str(fname)) + + fname2.write_text("workingdir\n") + + git_repo = GitRepo(InputOutput(), None, ".") + diffs = git_repo.get_diffs(False) + self.assertIn("index", diffs) + self.assertIn("workingdir", diffs) + + def test_diffs_between_commits(self): + with GitTemporaryDirectory(): + repo = git.Repo() + fname = Path("foo.txt") + + fname.write_text("one\n") + repo.git.add(str(fname)) + repo.git.commit("-m", "initial") + + fname.write_text("two\n") + repo.git.add(str(fname)) + repo.git.commit("-m", "second") + + git_repo = GitRepo(InputOutput(), None, ".") + diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"]) + dump(diffs) + self.assertIn("two", diffs) @patch("aider.repo.simple_send_with_retries") def test_get_commit_message(self, mock_send):