mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-22 13:25:00 +00:00
191 lines
5.6 KiB
Python
191 lines
5.6 KiB
Python
import os
|
|
from pathlib import Path, PurePosixPath
|
|
|
|
import git
|
|
|
|
from aider import models, prompts, utils
|
|
from aider.sendchat import simple_send_with_retries
|
|
|
|
from .dump import dump # noqa: F401
|
|
|
|
|
|
class GitRepo:
|
|
repo = None
|
|
|
|
def __init__(self, io, fnames, git_dname):
|
|
self.io = io
|
|
|
|
if git_dname:
|
|
check_fnames = [git_dname]
|
|
elif fnames:
|
|
check_fnames = fnames
|
|
else:
|
|
check_fnames = ["."]
|
|
|
|
repo_paths = []
|
|
for fname in check_fnames:
|
|
fname = Path(fname)
|
|
fname = fname.resolve()
|
|
|
|
if not fname.exists() and fname.parent.exists():
|
|
fname = fname.parent
|
|
|
|
try:
|
|
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
|
|
repo_path = utils.safe_abs_path(repo_path)
|
|
repo_paths.append(repo_path)
|
|
except git.exc.InvalidGitRepositoryError:
|
|
pass
|
|
|
|
num_repos = len(set(repo_paths))
|
|
|
|
if num_repos == 0:
|
|
raise FileNotFoundError
|
|
if num_repos > 1:
|
|
self.io.tool_error("Files are in different git repos.")
|
|
raise FileNotFoundError
|
|
|
|
# https://github.com/gitpython-developers/GitPython/issues/427
|
|
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
|
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
|
|
|
def add_new_files(self, fnames):
|
|
cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()]
|
|
for fname in fnames:
|
|
if Path(fname).resolve() in cur_files:
|
|
continue
|
|
if not Path(fname).exists():
|
|
continue
|
|
self.io.tool_output(f"Adding {fname} to git")
|
|
self.repo.git.add(fname)
|
|
|
|
def commit(self, context=None, prefix=None, message=None):
|
|
if not self.repo.is_dirty():
|
|
return
|
|
|
|
if message:
|
|
commit_message = message
|
|
else:
|
|
diffs = self.get_diffs(False)
|
|
commit_message = self.get_commit_message(diffs, context)
|
|
|
|
if not commit_message:
|
|
commit_message = "(no commit message provided)"
|
|
|
|
if prefix:
|
|
commit_message = prefix + commit_message
|
|
|
|
full_commit_message = commit_message
|
|
if context:
|
|
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
|
|
|
|
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
|
|
commit_hash = self.repo.head.commit.hexsha[:7]
|
|
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
|
|
|
return commit_hash, commit_message
|
|
|
|
def get_rel_repo_dir(self):
|
|
try:
|
|
return os.path.relpath(self.repo.git_dir, os.getcwd())
|
|
except ValueError:
|
|
return self.repo.git_dir
|
|
|
|
def get_commit_message(self, diffs, context):
|
|
if len(diffs) >= 4 * 1024 * 4:
|
|
self.io.tool_error(
|
|
f"Diff is too large for {models.GPT35.name} to generate a commit message."
|
|
)
|
|
return
|
|
|
|
diffs = "# Diffs:\n" + diffs
|
|
|
|
content = ""
|
|
if context:
|
|
content += context + "\n"
|
|
content += diffs
|
|
|
|
messages = [
|
|
dict(role="system", content=prompts.commit_system),
|
|
dict(role="user", content=content),
|
|
]
|
|
|
|
for model in [models.GPT35.name, models.GPT35_16k.name]:
|
|
commit_message = simple_send_with_retries(model, messages)
|
|
if commit_message:
|
|
break
|
|
|
|
if not commit_message:
|
|
self.io.tool_error("Failed to generate commit message!")
|
|
return
|
|
|
|
commit_message = commit_message.strip()
|
|
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
|
commit_message = commit_message[1:-1].strip()
|
|
|
|
return commit_message
|
|
|
|
def get_diffs(self, pretty, *args):
|
|
args = list(args)
|
|
|
|
# if args are specified, just add --pretty if needed
|
|
if args:
|
|
if pretty:
|
|
args = ["--color"] + args
|
|
return self.repo.git.diff(*args)
|
|
|
|
# otherwise, we always want diffs of index and working dir
|
|
|
|
try:
|
|
commits = self.repo.iter_commits(self.repo.active_branch)
|
|
current_branch_has_commits = any(commits)
|
|
except git.exc.GitCommandError:
|
|
current_branch_has_commits = False
|
|
|
|
if pretty:
|
|
args = ["--color"]
|
|
|
|
if current_branch_has_commits:
|
|
# if there is a HEAD, just diff against it to pick up index + working
|
|
args += ["HEAD"]
|
|
return self.repo.git.diff(*args)
|
|
|
|
# diffs in the index
|
|
diffs = self.repo.git.diff(*(args + ["--cached"]))
|
|
# plus, diffs in the working dir
|
|
diffs += self.repo.git.diff(*args)
|
|
|
|
return diffs
|
|
|
|
def show_diffs(self, pretty):
|
|
diffs = self.get_diffs(pretty)
|
|
print(diffs)
|
|
|
|
def get_tracked_files(self):
|
|
if not self.repo:
|
|
return []
|
|
|
|
try:
|
|
commit = self.repo.head.commit
|
|
except ValueError:
|
|
commit = None
|
|
|
|
files = []
|
|
if commit:
|
|
for blob in commit.tree.traverse():
|
|
if blob.type == "blob": # blob is a file
|
|
files.append(blob.path)
|
|
|
|
# Add staged files
|
|
index = self.repo.index
|
|
staged_files = [path for path, _ in index.entries.keys()]
|
|
|
|
files.extend(staged_files)
|
|
|
|
# convert to appropriate os.sep, since git always normalizes to /
|
|
res = set(str(Path(PurePosixPath(path))) for path in files)
|
|
|
|
return res
|
|
|
|
def is_dirty(self):
|
|
return self.repo.is_dirty()
|