Mirror of https://github.com/Aider-AI/aider.git (synced 2025-05-21 21:04:59 +00:00) — 346 lines, 10 KiB, Python.
import os
|
|
import time
|
|
from pathlib import Path, PurePosixPath
|
|
|
|
import git
|
|
import pathspec
|
|
|
|
from aider import prompts, utils
|
|
from aider.sendchat import send_with_retries
|
|
|
|
from .dump import dump # noqa: F401
|
|
|
|
|
|
class GitRepo:
    """Wrapper around a GitPython ``Repo`` used by aider.

    Locates the single git repository containing the requested paths, commits
    changes (optionally attributing them to aider), produces diffs, and filters
    tracked files through an optional gitignore-syntax ignore file and an
    optional "subtree only" restriction to the current working directory.
    """

    repo = None
    aider_ignore_file = None
    aider_ignore_spec = None
    aider_ignore_ts = 0
    aider_ignore_last_check = 0
    subtree_only = False
    ignore_file_cache = {}

    def __init__(
        self,
        io,
        fnames,
        git_dname,
        aider_ignore_file=None,
        models=None,
        attribute_author=True,
        attribute_committer=True,
        attribute_commit_message=False,
        commit_prompt=None,
        subtree_only=False,
    ):
        """Find the git repo that contains ``git_dname`` (or ``fnames``) and open it.

        Args:
            io: Object providing ``tool_error``/``tool_output`` for user messages.
            fnames: Paths used to locate the repo when ``git_dname`` is not given.
            git_dname: Directory used to locate the repo; takes precedence over ``fnames``.
            aider_ignore_file: Optional path to a gitignore-syntax file of paths to hide.
            models: Models tried in order when generating a commit message.
            attribute_author: Append " (aider)" to the author name on aider edits.
            attribute_committer: Append " (aider)" to the committer name on every commit.
            attribute_commit_message: Prefix aider-edit commit messages with "aider: ".
            commit_prompt: System prompt overriding the default commit-message prompt.
            subtree_only: Treat files outside the current working directory as ignored.

        Raises:
            FileNotFoundError: if no repo is found, or the paths span several repos.
        """
        self.io = io
        self.models = models

        self.attribute_author = attribute_author
        self.attribute_committer = attribute_committer
        self.attribute_commit_message = attribute_commit_message
        self.commit_prompt = commit_prompt
        self.subtree_only = subtree_only
        self.ignore_file_cache = {}

        if git_dname:
            check_fnames = [git_dname]
        elif fnames:
            check_fnames = fnames
        else:
            check_fnames = ["."]

        repo_paths = []
        for fname in check_fnames:
            fname = Path(fname).resolve()

            # A file that does not exist yet: search from its existing parent directory.
            if not fname.exists() and fname.parent.exists():
                fname = fname.parent

            try:
                repo_path = git.Repo(fname, search_parent_directories=True).working_dir
                repo_paths.append(utils.safe_abs_path(repo_path))
            except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
                # Not inside any repo (or the path is gone); other paths may still be.
                continue

        unique_repo_paths = set(repo_paths)
        if not unique_repo_paths:
            raise FileNotFoundError
        if len(unique_repo_paths) > 1:
            self.io.tool_error("Files are in different git repos.")
            raise FileNotFoundError

        # odbt=git.GitDB relates to this GitPython issue:
        # https://github.com/gitpython-developers/GitPython/issues/427
        self.repo = git.Repo(unique_repo_paths.pop(), odbt=git.GitDB)
        self.root = utils.safe_abs_path(self.repo.working_tree_dir)

        if aider_ignore_file:
            self.aider_ignore_file = Path(aider_ignore_file)

    def commit(self, fnames=None, context=None, message=None, aider_edits=False):
        """Commit ``fnames`` (or every dirty tracked file) and report the new commit.

        Args:
            fnames: Repo-root-relative paths to stage and commit. When empty or None,
                ``git commit -a`` is used instead.
            context: Extra text handed to the model when a message is generated.
            message: Commit message to use as-is; generated from the diff when omitted.
            aider_edits: True when aider made the changes. This enables author
                attribution and the "aider: " message prefix.

        Returns:
            ``(short_hash, commit_message)``, or None when there is nothing to commit.
        """
        if not fnames and not self.repo.is_dirty():
            return

        diffs = self.get_diffs(fnames)
        if not diffs:
            return

        commit_message = message or self.get_commit_message(diffs, context)

        # get_commit_message() returns None on failure. Fill in the placeholder
        # *before* prefixing, since "aider: " + None would raise TypeError.
        if not commit_message:
            commit_message = "(no commit message provided)"

        if aider_edits and self.attribute_commit_message:
            commit_message = "aider: " + commit_message

        cmd = ["-m", commit_message, "--no-verify"]
        if fnames:
            fnames = [str(self.abs_root_path(fn)) for fn in fnames]
            for fname in fnames:
                self.repo.git.add(fname)
            cmd += ["--"] + fnames
        else:
            cmd += ["-a"]

        original_user_name = self.repo.config_reader().get_value("user", "name")
        committer_name = f"{original_user_name} (aider)"

        env_overrides = {}
        if self.attribute_committer:
            env_overrides["GIT_COMMITTER_NAME"] = committer_name
        if aider_edits and self.attribute_author:
            env_overrides["GIT_AUTHOR_NAME"] = committer_name

        saved_env = {name: os.environ.get(name) for name in env_overrides}
        os.environ.update(env_overrides)
        try:
            self.repo.git.commit(cmd)
        finally:
            # Restore the process environment even when the commit fails.
            for name, old_value in saved_env.items():
                if old_value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = old_value

        commit_hash = self.repo.head.commit.hexsha[:7]
        self.io.tool_output(f"Commit {commit_hash} {commit_message}")
        return commit_hash, commit_message

    def get_rel_repo_dir(self):
        """Return the ``.git`` directory relative to the CWD.

        Falls back to the absolute path when no relative path exists
        (``os.path.relpath`` raised ValueError).
        """
        try:
            return os.path.relpath(self.repo.git_dir, os.getcwd())
        except ValueError:
            return self.repo.git_dir

    def get_commit_message(self, diffs, context):
        """Ask the configured models, in order, for a commit message describing ``diffs``.

        Returns the first non-empty reply. Surrounding whitespace and one enclosing
        pair of double quotes are stripped. Returns None, after reporting an error
        through ``self.io``, when the diff is too large or no model produced a message.
        """
        # 4 * 1024 * 4 == 16384 characters.
        if len(diffs) >= 4 * 1024 * 4:
            self.io.tool_error("Diff is too large to generate a commit message.")
            return

        content = ""
        if context:
            content += context + "\n"
        content += "# Diffs:\n" + diffs

        system_content = self.commit_prompt or prompts.commit_system
        messages = [
            dict(role="system", content=system_content),
            dict(role="user", content=content),
        ]

        # Initialise first: with no models the loop body never runs.
        # self.models may be None (the constructor default).
        commit_message = None
        for model in self.models or []:
            commit_message = send_with_retries(model.name, messages)
            if commit_message:
                break

        if not commit_message:
            self.io.tool_error("Failed to generate commit message!")
            return

        commit_message = commit_message.strip()
        # Only unwrap a genuine pair of quotes; a lone '"' is left as-is.
        if len(commit_message) >= 2 and commit_message[0] == '"' and commit_message[-1] == '"':
            commit_message = commit_message[1:-1].strip()

        return commit_message

    def get_diffs(self, fnames=None):
        """Return a textual diff of staged and unstaged changes, optionally limited to ``fnames``.

        Paths not in ``get_tracked_files()`` are listed first as ``Added <fname>``
        lines. When the current branch has commits, the diff is taken against HEAD.
        Otherwise the index diff (``--cached``) and the working-tree diff are
        concatenated.
        """
        # We always want diffs of index and working dir
        current_branch_has_commits = False
        try:
            active_branch = self.repo.active_branch
            try:
                commits = self.repo.iter_commits(active_branch)
                current_branch_has_commits = any(commits)
            except git.exc.GitCommandError:
                pass
        except TypeError:
            # Presumably a detached HEAD, where GitPython's active_branch raises
            # TypeError. Treat it as "no commits".
            pass

        # Materialise once so a generator argument is not consumed twice.
        fnames = list(fnames or [])

        diffs = ""
        if fnames:
            # Build the tracked-file set once instead of once per file (path_in_repo).
            tracked_files = set(self.get_tracked_files())
            for fname in fnames:
                if self.normalize_path(fname) not in tracked_files:
                    diffs += f"Added {fname}\n"

        if current_branch_has_commits:
            args = ["HEAD", "--"] + fnames
            diffs += self.repo.git.diff(*args)
            return diffs

        wd_args = ["--"] + fnames
        index_args = ["--cached"] + wd_args

        diffs += self.repo.git.diff(*index_args)
        diffs += self.repo.git.diff(*wd_args)
        return diffs

    def diff_commits(self, pretty, from_commit, to_commit):
        """Return ``git diff from_commit to_commit``. Adds ``--color`` when ``pretty``."""
        args = []
        if pretty:
            args += ["--color"]

        args += [from_commit, to_commit]
        return self.repo.git.diff(*args)

    def get_tracked_files(self):
        """Return committed and staged file paths, excluding any that ``ignored_file`` rejects.

        Paths are made repo-relative with OS separators by ``normalize_path``.
        The result is a list with no duplicates and no guaranteed order.
        """
        if not self.repo:
            return []

        try:
            commit = self.repo.head.commit
        except ValueError:
            # Presumably a repo with no commits yet; only staged files can be listed.
            commit = None

        files = set()
        if commit:
            files.update(blob.path for blob in commit.tree.traverse() if blob.type == "blob")

        # Add staged files; index entry keys are (path, stage) pairs.
        files.update(path for path, _ in self.repo.index.entries.keys())

        # convert to appropriate os.sep, since git always normalizes to /
        res = set(self.normalize_path(path) for path in files)

        return [fname for fname in res if not self.ignored_file(fname)]

    def normalize_path(self, path):
        """Return ``path`` relative to the repo root, as a string with OS separators.

        Raises ValueError when ``path`` is absolute and not under the repo root.
        """
        return str(Path(PurePosixPath((Path(self.root) / path).relative_to(self.root))))

    def refresh_aider_ignore(self):
        """Reload the ignore spec when the ignore file changes (checked at most once per second).

        ``ignore_file_cache`` is cleared whenever the spec changes, including
        when the ignore file disappears, so cached answers never outlive the
        rules they were computed from.
        """
        if not self.aider_ignore_file:
            return

        current_time = time.time()
        if current_time - self.aider_ignore_last_check < 1:
            return
        self.aider_ignore_last_check = current_time

        if not self.aider_ignore_file.is_file():
            # The file was removed: forget its rules and the answers cached from them.
            if self.aider_ignore_spec is not None:
                self.aider_ignore_spec = None
                self.aider_ignore_ts = 0
                self.ignore_file_cache = {}
            return

        mtime = self.aider_ignore_file.stat().st_mtime
        if mtime == self.aider_ignore_ts:
            return

        self.aider_ignore_ts = mtime
        self.ignore_file_cache = {}
        lines = self.aider_ignore_file.read_text().splitlines()
        self.aider_ignore_spec = pathspec.PathSpec.from_lines(
            pathspec.patterns.GitWildMatchPattern,
            lines,
        )

    def ignored_file(self, fname):
        """Cached wrapper around ``ignored_file_raw``. Refreshes the ignore spec first."""
        self.refresh_aider_ignore()

        if fname in self.ignore_file_cache:
            return self.ignore_file_cache[fname]

        result = self.ignored_file_raw(fname)
        self.ignore_file_cache[fname] = result
        return result

    def ignored_file_raw(self, fname):
        """Uncached check: return True if ``fname`` should be hidden from aider.

        With ``subtree_only``, anything not below the current working directory
        is ignored. Otherwise the path is matched against the loaded ignore spec.
        Paths outside the repo root count as ignored.
        """
        if self.subtree_only:
            fname_path = Path(self.normalize_path(fname))
            cwd_path = Path(self.normalize_path(Path.cwd().relative_to(self.root)))

            if cwd_path not in fname_path.parents:
                return True

        if not self.aider_ignore_file or not self.aider_ignore_file.is_file():
            return False

        # refresh_aider_ignore() is throttled, so the file can exist before it has
        # been parsed. The cache is cleared once the spec actually loads.
        if self.aider_ignore_spec is None:
            return False

        try:
            fname = self.normalize_path(fname)
        except ValueError:
            return True

        return self.aider_ignore_spec.match_file(fname)

    def path_in_repo(self, path):
        """Return True if ``path`` is one of ``get_tracked_files()`` (which excludes ignored files).

        Returns None when no repo is open.
        """
        if not self.repo:
            return

        tracked_files = set(self.get_tracked_files())
        return self.normalize_path(path) in tracked_files

    def abs_root_path(self, path):
        """Return ``path`` joined onto the repo root and passed through ``utils.safe_abs_path``."""
        res = Path(self.root) / path
        return utils.safe_abs_path(res)

    def get_dirty_files(self):
        """
        Returns a list of all files which are dirty (not committed), either staged or in the working
        directory.
        """
        dirty_files = set()

        # Get staged files
        staged_files = self.repo.git.diff("--name-only", "--cached").splitlines()
        dirty_files.update(staged_files)

        # Get unstaged files
        unstaged_files = self.repo.git.diff("--name-only").splitlines()
        dirty_files.update(unstaged_files)

        return list(dirty_files)

    def is_dirty(self, path=None):
        """Return whether the repo (or ``path``) has uncommitted changes.

        A ``path`` that ``path_in_repo`` does not recognise is reported as dirty.
        """
        if path and not self.path_in_repo(path):
            return True

        return self.repo.is_dirty(path=path)