aider/tests/basic/test_sanity_check_repo.py
2024-10-30 11:59:52 -07:00

184 lines
5.8 KiB
Python

import os
import shutil
import struct
import webbrowser
from unittest import mock
import pytest
from git import GitError, Repo
from aider import urls
from aider.main import sanity_check_repo
@pytest.fixture
def mock_io():
    """Provide a fresh ``unittest.mock.Mock`` that stands in for the io object."""
    io_double = mock.Mock()
    return io_double
@pytest.fixture
def create_repo(tmp_path):
    """
    Fixture that creates a standard (non-bare) Git repository with one commit.

    The repository is created at ``tmp_path / "test_repo"`` and contains a
    single committed ``README.md``.

    Yields:
        tuple: ``(repo_path, repo)`` -- the repository path and its GitPython
        ``Repo`` object.

    The ``Repo`` is closed on teardown so GitPython's cached git helper
    processes/handles are released instead of leaking for the rest of the
    test session.
    """
    repo_path = tmp_path / "test_repo"
    repo = Repo.init(repo_path)
    try:
        # Create an initial commit so HEAD points at a real commit.
        file_path = repo_path / "README.md"
        file_path.write_text("# Test Repository")
        repo.index.add([str(file_path.relative_to(repo_path))])
        repo.index.commit("Initial commit")
        yield repo_path, repo
    finally:
        # Runs on teardown, and also if setup above fails before the yield.
        repo.close()
def set_git_index_version(repo_path, version):
    """
    Overwrite the version field in the header of ``<repo_path>/.git/index``.

    The index header is the 4-byte signature ``b"DIRC"`` followed by a 4-byte
    version number stored in network byte order (big-endian). The version
    therefore occupies bytes 4-7 of the file. Only that field is rewritten;
    the rest of the index is left untouched.

    Args:
        repo_path: Path (str or path-like) to the repository's working tree.
        version: Index version to write; must fit in an unsigned 32-bit int.

    Raises:
        ValueError: If the file does not start with the ``DIRC`` signature.
        struct.error: If ``version`` does not fit in an unsigned 32-bit int.
    """
    index_path = os.path.join(repo_path, ".git", "index")
    with open(index_path, "r+b") as f:
        # Validate the 4-byte signature before modifying anything.
        signature = f.read(4)
        if signature != b"DIRC":
            raise ValueError("Invalid git index file signature.")
        # The version immediately follows the signature. Like every binary
        # number in the index format, it is big-endian ("network byte order").
        f.seek(4)
        f.write(struct.pack(">I", version))
def detach_head(repo):
    """
    Put ``repo`` into a detached-HEAD state.

    Checks out the SHA of the commit HEAD currently resolves to, rather than a
    branch name, so HEAD no longer follows a branch.
    """
    sha = repo.head.commit.hexsha
    repo.git.checkout(sha)
def mock_repo_wrapper(repo_obj, git_repo_error=None):
    """
    Build a stand-in for the repo wrapper that sanity_check_repo receives.

    The returned Mock exposes:
    - ``repo``: the ``repo_obj`` passed in
    - ``get_tracked_files()``: raises ``git_repo_error`` when one is given,
      otherwise returns the lines of ``repo_obj.git.ls_files()``
    - ``git_repo_error``: the given error, or None
    """
    wrapper = mock.Mock()
    wrapper.repo = repo_obj

    if not git_repo_error:
        # ls_files is run once, here; later calls return this snapshot.
        listing = repo_obj.git.ls_files().splitlines()
        wrapper.get_tracked_files.return_value = list(map(str, listing))
        wrapper.git_repo_error = None
        return wrapper

    def _raise_configured_error():
        raise git_repo_error

    wrapper.get_tracked_files.side_effect = _raise_configured_error
    wrapper.git_repo_error = git_repo_error
    return wrapper
def test_detached_head_state(create_repo, mock_io):
    """A repo with a detached HEAD must pass the sanity check without any output."""
    _, repo = create_repo
    detach_head(repo)

    outcome = sanity_check_repo(mock_repo_wrapper(repo), mock_io)

    assert outcome is True
    # Nothing should have been reported to the user.
    for reporter in (mock_io.tool_error, mock_io.tool_output):
        reporter.assert_not_called()
@mock.patch("webbrowser.open")
def test_git_index_version_greater_than_2(mock_browser, create_repo, mock_io):
    """
    An index version above 2 must make sanity_check_repo fail and print the
    conversion hints.

    ``webbrowser.open`` is patched, presumably so that the "open documentation"
    prompt cannot launch a real browser. The Mock io answers ``confirm_ask``
    with a truthy Mock, so the prompt would be accepted. TODO confirm against
    sanity_check_repo.
    """
    repo_path, repo = create_repo

    # Set the on-disk git index version to 3.
    set_git_index_version(str(repo_path), 3)

    # Simulate get_tracked_files failing because of the index version.
    git_error = GitError("index version in (1, 2) is required")
    mock_repo_obj = mock_repo_wrapper(repo, git_repo_error=git_error)

    result = sanity_check_repo(mock_repo_obj, mock_io)

    assert result is False

    # assert_called_with pins this as the most recent tool_error call, which
    # already implies the message was emitted at all.
    mock_io.tool_error.assert_called_with(
        "Aider only works with git repos with version number 1 or 2."
    )
    mock_io.tool_output.assert_any_call(
        "You may be able to convert your repo: git update-index --index-version=2"
    )
    mock_io.tool_output.assert_any_call("Or run aider --no-git to proceed without using git.")
    mock_io.confirm_ask.assert_any_call(
        "Open documentation url for more info?", subject=urls.git_index_version
    )
def test_bare_repository(mock_io, tmp_path):
    """
    A bare repository has no working tree, so the sanity check must fail with
    the working-tree error and print nothing else.

    This test builds its own bare repo; it does not need the ``create_repo``
    fixture, whose repo init and commit would be wasted work.
    """
    bare_repo_path = tmp_path / "bare_repo.git"
    bare_repo = Repo.init(bare_repo_path, bare=True)
    try:
        mock_repo_obj = mock_repo_wrapper(bare_repo)

        result = sanity_check_repo(mock_repo_obj, mock_io)

        assert result is False
        mock_io.tool_error.assert_called_with(
            "The git repo does not seem to have a working tree?"
        )
        mock_io.tool_output.assert_not_called()
    finally:
        # Release GitPython's cached git processes/handles for this repo.
        bare_repo.close()
def test_sanity_check_repo_with_corrupt_repo(create_repo, mock_io):
    """
    When reading tracked files fails with a git error, the sanity check must
    fail. It must show the generic "corrupt" diagnostic via tool_error and the
    underlying error text via tool_output.
    """
    repo_path, repo = create_repo

    # Simulate a corrupt repository by removing the .git directory.
    shutil.rmtree(os.path.join(repo_path, ".git"))

    # The injected error text deliberately differs from the expected
    # diagnostic. Otherwise the two assertions below could not tell the fixed
    # message apart from the raw error being echoed.
    git_error = GitError("fatal: bad object HEAD")
    mock_repo_obj = mock_repo_wrapper(repo, git_repo_error=git_error)

    result = sanity_check_repo(mock_repo_obj, mock_io)

    assert result is False
    mock_io.tool_error.assert_called_with("Unable to read git repository, it may be corrupt?")
    mock_io.tool_output.assert_called_with(str(git_error))
def test_sanity_check_repo_with_no_repo(mock_io):
    """With no repo at all (``None``), the check passes and reports nothing."""
    assert sanity_check_repo(None, mock_io) is True

    for reporter in (mock_io.tool_error, mock_io.tool_output):
        reporter.assert_not_called()