mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-24 21:45:00 +00:00

The GitRepo class in the aider/repo.py file has been updated to support a new `commit_prompt` argument. This allows overriding the default `prompts.commit_system` when generating commit messages. The changes include: 1. Adding the `commit_prompt` parameter to the `__init__` method of the `GitRepo` class. 2. Storing the `commit_prompt` value in the `self.commit_prompt` attribute. 3. Modifying the `get_commit_message` method to use the `self.commit_prompt` value if it's provided, otherwise falling back to the default `prompts.commit_system`. This change provides more flexibility in customizing the commit message generation process, allowing users to provide their own custom prompts if needed.
308 lines
9.2 KiB
Python
308 lines
9.2 KiB
Python
import os
|
|
from pathlib import Path, PurePosixPath
|
|
|
|
import git
|
|
import pathspec
|
|
|
|
from aider import prompts, utils
|
|
from aider.sendchat import simple_send_with_retries
|
|
|
|
from .dump import dump # noqa: F401
|
|
|
|
|
|
class GitRepo:
    """Wrapper around a single git repository (a GitPython ``git.Repo``).

    Locates the repository containing the given files, creates commits
    (optionally asking the configured models to write the commit message),
    and answers questions about tracked, ignored and dirty files.
    """

    # The underlying ``git.Repo``; stays None until ``__init__`` succeeds.
    repo = None
    # Path to the ignore file, plus the parsed spec and the mtime it was parsed at.
    aider_ignore_file = None
    aider_ignore_spec = None
    aider_ignore_ts = 0

    def __init__(
        self,
        io,
        fnames,
        git_dname,
        aider_ignore_file=None,
        models=None,
        attribute_author=True,
        attribute_committer=True,
        attribute_commit_message=False,
        commit_prompt=None,
    ):
        """Locate the git repository and remember the commit settings.

        Args:
            io: Object used for user-facing output; this class calls its
                ``tool_error`` and ``tool_output`` methods.
            fnames: Paths used to locate the repo when ``git_dname`` is falsy.
            git_dname: Directory used to locate the repo; takes precedence
                over ``fnames``.
            aider_ignore_file: Optional path to a file of gitwildmatch
                patterns naming files to hide from ``get_tracked_files``.
            models: Iterable of objects with a ``name`` attribute, tried in
                order when generating commit messages.
            attribute_author: Set ``GIT_AUTHOR_NAME`` to "<user> (aider)" on
                commits made with ``aider_edits=True``.
            attribute_committer: Set ``GIT_COMMITTER_NAME`` to
                "<user> (aider)" on every commit.
            attribute_commit_message: Prefix the message of commits made with
                ``aider_edits=True`` with "aider: ".
            commit_prompt: Optional system prompt that overrides
                ``prompts.commit_system`` when generating commit messages.

        Raises:
            FileNotFoundError: If no git repo is found, or the paths belong
                to more than one repo.
        """
        self.io = io
        self.models = models

        self.attribute_author = attribute_author
        self.attribute_committer = attribute_committer
        self.attribute_commit_message = attribute_commit_message
        self.commit_prompt = commit_prompt

        if git_dname:
            check_fnames = [git_dname]
        elif fnames:
            check_fnames = fnames
        else:
            check_fnames = ["."]

        repo_paths = []
        for fname in check_fnames:
            fname = Path(fname).resolve()

            # A file that does not exist yet is located via its parent directory.
            if not fname.exists() and fname.parent.exists():
                fname = fname.parent

            try:
                repo_path = git.Repo(fname, search_parent_directories=True).working_dir
            except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
                # Paths outside any repo are skipped; the count check below
                # reports the case where nothing was found at all.
                continue
            repo_paths.append(utils.safe_abs_path(repo_path))

        num_repos = len(set(repo_paths))

        if num_repos == 0:
            raise FileNotFoundError
        if num_repos > 1:
            self.io.tool_error("Files are in different git repos.")
            raise FileNotFoundError

        # https://github.com/gitpython-developers/GitPython/issues/427
        self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
        self.root = utils.safe_abs_path(self.repo.working_tree_dir)

        if aider_ignore_file:
            self.aider_ignore_file = Path(aider_ignore_file)

    def commit(self, fnames=None, context=None, message=None, aider_edits=False):
        """Commit ``fnames`` (or all modified tracked files) and report the result.

        Args:
            fnames: Paths relative to the repo root to add and commit. When
                falsy, ``git commit -a`` is used instead.
            context: Optional text passed to ``get_commit_message``.
            message: Commit message to use; when falsy one is generated.
            aider_edits: Whether the changes were made by aider; gates the
                author attribution and the "aider: " message prefix.

        Returns:
            ``(commit_hash, commit_message)`` with a 7-character hash, or
            ``None`` when there is nothing to commit.
        """
        if not fnames and not self.repo.is_dirty():
            return

        diffs = self.get_diffs(fnames)
        if not diffs:
            return

        if message:
            commit_message = message
        else:
            commit_message = self.get_commit_message(diffs, context)

        # Apply the fallback before prefixing: message generation may return
        # None, and "aider: " + None would raise TypeError.
        if not commit_message:
            commit_message = "(no commit message provided)"

        if aider_edits and self.attribute_commit_message:
            commit_message = "aider: " + commit_message

        cmd = ["-m", commit_message, "--no-verify"]
        if fnames:
            fnames = [str(self.abs_root_path(fn)) for fn in fnames]
            for fname in fnames:
                self.repo.git.add(fname)
            cmd += ["--"] + fnames
        else:
            cmd += ["-a"]

        original_user_name = self.repo.config_reader().get_value("user", "name")
        committer_name = f"{original_user_name} (aider)"

        env_overrides = {}
        if self.attribute_committer:
            env_overrides["GIT_COMMITTER_NAME"] = committer_name
        if aider_edits and self.attribute_author:
            env_overrides["GIT_AUTHOR_NAME"] = committer_name

        # Remember the previous values (None means "was unset") so the
        # process environment can be put back exactly as it was.
        saved_env = {name: os.environ.get(name) for name in env_overrides}
        os.environ.update(env_overrides)
        try:
            self.repo.git.commit(cmd)
            commit_hash = self.repo.head.commit.hexsha[:7]
            self.io.tool_output(f"Commit {commit_hash} {commit_message}")
        finally:
            # Restore the env even when the commit fails, so the attribution
            # overrides never leak into later git invocations.
            for name, previous in saved_env.items():
                if previous is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = previous

        return commit_hash, commit_message

    def get_rel_repo_dir(self):
        """Return the repo's git dir relative to the current working directory.

        Falls back to the unmodified ``git_dir`` when ``os.path.relpath``
        raises ``ValueError``.
        """
        try:
            return os.path.relpath(self.repo.git_dir, os.getcwd())
        except ValueError:
            return self.repo.git_dir

    def get_commit_message(self, diffs, context):
        """Ask the configured models for a commit message describing ``diffs``.

        Args:
            diffs: Diff text to describe.
            context: Optional text placed before the diffs in the user message.

        Returns:
            The stripped message with one pair of surrounding double quotes
            removed, or ``None`` when the diff is 16384 characters or longer
            or no model produced a message (an error is reported via ``io``).
        """
        if len(diffs) >= 4 * 1024 * 4:
            self.io.tool_error("Diff is too large to generate a commit message.")
            return

        diffs = "# Diffs:\n" + diffs

        content = ""
        if context:
            content += context + "\n"
        content += diffs

        system_content = self.commit_prompt or prompts.commit_system
        messages = [
            dict(role="system", content=system_content),
            dict(role="user", content=content),
        ]

        # Start from None and tolerate a missing/empty model list so that
        # case is reported as a failure below instead of raising.
        commit_message = None
        for model in self.models or []:
            commit_message = simple_send_with_retries(model.name, messages)
            if commit_message:
                break

        if not commit_message:
            self.io.tool_error("Failed to generate commit message!")
            return

        commit_message = commit_message.strip()
        if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
            commit_message = commit_message[1:-1].strip()

        return commit_message

    def get_diffs(self, fnames=None):
        """Return diff text for ``fnames`` (or the whole repo when falsy).

        Files not yet tracked are listed as "Added <fname>" lines. When the
        active branch has commits the diff is taken against HEAD; otherwise
        the index diff and the working-directory diff are concatenated.
        """
        # We always want diffs of index and working dir

        current_branch_has_commits = False
        try:
            active_branch = self.repo.active_branch
            try:
                commits = self.repo.iter_commits(active_branch)
                current_branch_has_commits = any(commits)
            except git.exc.GitCommandError:
                pass
        except TypeError:
            pass

        if not fnames:
            fnames = []

        diffs = ""
        for fname in fnames:
            if not self.path_in_repo(fname):
                diffs += f"Added {fname}\n"

        if current_branch_has_commits:
            args = ["HEAD", "--"] + list(fnames)
            diffs += self.repo.git.diff(*args)
            return diffs

        wd_args = ["--"] + list(fnames)
        index_args = ["--cached"] + wd_args

        diffs += self.repo.git.diff(*index_args)
        diffs += self.repo.git.diff(*wd_args)

        return diffs

    def diff_commits(self, pretty, from_commit, to_commit):
        """Return the diff between two commits, adding ``--color`` when ``pretty``."""
        args = []
        if pretty:
            args += ["--color"]

        args += [from_commit, to_commit]
        return self.repo.git.diff(*args)

    def get_tracked_files(self):
        """Return the files in the HEAD commit or the index, minus ignored ones.

        Paths are relative to the repo root and use the native separator.
        Returns an empty list when no repo is attached.
        """
        if not self.repo:
            return []

        try:
            commit = self.repo.head.commit
        except ValueError:
            commit = None

        files = []
        if commit:
            for blob in commit.tree.traverse():
                if blob.type == "blob":  # blob is a file
                    files.append(blob.path)

        # Add staged files
        index = self.repo.index
        staged_files = [path for path, _ in index.entries.keys()]

        files.extend(staged_files)

        # convert to appropriate os.sep, since git always normalizes to /
        res = set(self.normalize_path(path) for path in files)

        return [fname for fname in res if not self.ignored_file(fname)]

    def normalize_path(self, path):
        """Return ``path`` relative to the repo root, using native separators.

        Raises:
            ValueError: If ``path`` resolves to a location outside the root.
        """
        return str(Path(PurePosixPath((Path(self.root) / path).relative_to(self.root))))

    def ignored_file(self, fname):
        """Return whether ``fname`` matches a pattern in the aider ignore file.

        Returns ``None`` (falsy) when no ignore file is configured, the file
        does not exist, or ``fname`` lies outside the repo root.
        """
        if not self.aider_ignore_file or not self.aider_ignore_file.is_file():
            return

        try:
            fname = self.normalize_path(fname)
        except ValueError:
            return

        # Re-parse the patterns only when the ignore file's mtime has changed.
        mtime = self.aider_ignore_file.stat().st_mtime
        if mtime != self.aider_ignore_ts:
            self.aider_ignore_ts = mtime
            lines = self.aider_ignore_file.read_text().splitlines()
            self.aider_ignore_spec = pathspec.PathSpec.from_lines(
                pathspec.patterns.GitWildMatchPattern,
                lines,
            )

        return self.aider_ignore_spec.match_file(fname)

    def path_in_repo(self, path):
        """Return whether ``path`` is among the tracked files (``None`` without a repo)."""
        if not self.repo:
            return

        tracked_files = set(self.get_tracked_files())
        return self.normalize_path(path) in tracked_files

    def abs_root_path(self, path):
        """Return ``path`` joined onto the repo root, via ``utils.safe_abs_path``."""
        res = Path(self.root) / path
        return utils.safe_abs_path(res)

    def get_dirty_files(self):
        """
        Returns a list of all files which are dirty (not committed), either staged or in the working
        directory.
        """
        dirty_files = set()

        # Get staged files
        staged_files = self.repo.git.diff("--name-only", "--cached").splitlines()
        dirty_files.update(staged_files)

        # Get unstaged files
        unstaged_files = self.repo.git.diff("--name-only").splitlines()
        dirty_files.update(unstaged_files)

        return list(dirty_files)

    def is_dirty(self, path=None):
        """Return whether the repo, or just ``path``, has uncommitted changes.

        A ``path`` that is not tracked is always reported as dirty.
        """
        if path and not self.path_in_repo(path):
            return True

        return self.repo.is_dirty(path=path)
|