aider/aider/repo.py

191 lines
5.6 KiB
Python

import os
from pathlib import Path, PurePosixPath
import git
from aider import models, prompts, utils
from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401
class GitRepo:
    """Wrapper around a GitPython ``Repo`` used to stage, diff and commit files.

    The instance reports to the user through ``io`` (``tool_output`` and
    ``tool_error``) and exposes the repository root as ``self.root``.
    """

    # Default until __init__ assigns the real repository; get_tracked_files
    # checks this attribute for falsiness before touching it.
    repo = None

    def __init__(self, io, fnames, git_dname):
        """Locate the single git repository containing the given paths.

        Args:
            io: Object providing ``tool_error`` / ``tool_output``.
            fnames: Paths used to discover the repo when ``git_dname`` is falsy.
            git_dname: Directory to search from; takes precedence over ``fnames``.
                When both are falsy the current directory (``"."``) is used.

        Raises:
            FileNotFoundError: If no repository is found, or if the paths
                belong to more than one repository.
        """
        self.io = io

        if git_dname:
            check_fnames = [git_dname]
        elif fnames:
            check_fnames = fnames
        else:
            check_fnames = ["."]

        repo_paths = []
        for fname in check_fnames:
            fname = Path(fname).resolve()

            # A file that does not exist yet is looked up via its parent dir.
            if not fname.exists() and fname.parent.exists():
                fname = fname.parent

            try:
                repo_path = git.Repo(fname, search_parent_directories=True).working_dir
            except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
                # Not inside a repo, or the path (and its parent) is missing:
                # either way this path contributes no repository.
                continue
            repo_paths.append(utils.safe_abs_path(repo_path))

        num_repos = len(set(repo_paths))

        if num_repos == 0:
            raise FileNotFoundError
        if num_repos > 1:
            self.io.tool_error("Files are in different git repos.")
            raise FileNotFoundError

        # https://github.com/gitpython-developers/GitPython/issues/427
        self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
        self.root = utils.safe_abs_path(self.repo.working_tree_dir)

    def add_new_files(self, fnames):
        """Run ``git add`` on every existing file in ``fnames`` that is not yet tracked.

        Files that are already tracked, or that do not exist on disk, are skipped.
        """
        # A set keeps the per-file membership test O(1).
        tracked = {Path(fn).resolve() for fn in self.get_tracked_files()}
        for fname in fnames:
            path = Path(fname)
            if path.resolve() in tracked:
                continue
            if not path.exists():
                continue
            self.io.tool_output(f"Adding {fname} to git")
            self.repo.git.add(fname)

    def commit(self, context=None, prefix=None, message=None):
        """Commit all changes to tracked files (``git commit -a --no-verify``).

        Args:
            context: Optional chat transcript; appended to the commit body and
                passed to the commit-message generator.
            prefix: Optional string prepended to the commit message.
            message: Commit message to use; when falsy one is generated from
                the current diffs.

        Returns:
            ``(commit_hash, commit_message)`` where ``commit_hash`` is the
            7-character abbreviated SHA, or ``None`` if the repo is not dirty.
        """
        if not self.repo.is_dirty():
            return

        if message:
            commit_message = message
        else:
            diffs = self.get_diffs(False)
            commit_message = self.get_commit_message(diffs, context)

        if not commit_message:
            commit_message = "(no commit message provided)"

        if prefix:
            commit_message = prefix + commit_message

        # The chat context goes into the commit body only; the message that is
        # reported and returned stays the short one.
        full_commit_message = commit_message
        if context:
            full_commit_message += "\n\n# Aider chat conversation:\n\n" + context

        self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
        commit_hash = self.repo.head.commit.hexsha[:7]
        self.io.tool_output(f"Commit {commit_hash} {commit_message}")

        return commit_hash, commit_message

    def get_rel_repo_dir(self):
        """Return the repo's git dir relative to the current working directory.

        Falls back to the git dir as reported by the repo when
        ``os.path.relpath`` raises ``ValueError``.
        """
        try:
            return os.path.relpath(self.repo.git_dir, os.getcwd())
        except ValueError:
            return self.repo.git_dir

    def get_commit_message(self, diffs, context):
        """Ask the chat model to write a commit message for ``diffs``.

        Args:
            diffs: Diff text to describe.
            context: Optional text placed before the diffs in the request.

        Returns:
            The stripped message (surrounding double quotes removed), or
            ``None`` when the diff is too large or no model produced a message.
        """
        # Size cap of 16 * 1024 characters.
        # NOTE(review): presumably ~4 chars/token for a 4k-token model - confirm.
        if len(diffs) >= 4 * 1024 * 4:
            self.io.tool_error(
                f"Diff is too large for {models.GPT35.name} to generate a commit message."
            )
            return

        diffs = "# Diffs:\n" + diffs

        content = ""
        if context:
            content += context + "\n"
        content += diffs

        messages = [
            dict(role="system", content=prompts.commit_system),
            dict(role="user", content=content),
        ]

        # Try each model in turn and keep the first non-empty reply.
        commit_message = None
        for model in [models.GPT35.name, models.GPT35_16k.name]:
            commit_message = simple_send_with_retries(model, messages)
            if commit_message:
                break

        if not commit_message:
            self.io.tool_error("Failed to generate commit message!")
            return

        commit_message = commit_message.strip()
        if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
            commit_message = commit_message[1:-1].strip()

        return commit_message

    def get_diffs(self, pretty, *args):
        """Return ``git diff`` output as a string.

        Args:
            pretty: When true, ``--color`` is passed to ``git diff``.
            *args: Extra ``git diff`` arguments. When given they are passed
                through as-is (after ``--color``); when omitted the result
                covers both the index and the working directory.
        """
        args = list(args)

        # If args are specified, just add --color if needed.
        if args:
            if pretty:
                args = ["--color"] + args
            return self.repo.git.diff(*args)

        # Otherwise, we always want diffs of index and working dir.
        try:
            commits = self.repo.iter_commits(self.repo.active_branch)
            current_branch_has_commits = any(commits)
        except TypeError:
            # GitPython raises TypeError from ``active_branch`` when HEAD is
            # detached; fall back to the separate index + working-dir diffs.
            current_branch_has_commits = False
        except git.exc.GitCommandError:
            current_branch_has_commits = False

        if pretty:
            args = ["--color"]

        if current_branch_has_commits:
            # If there is a HEAD, just diff against it to pick up index + working.
            args += ["HEAD"]
            return self.repo.git.diff(*args)

        # Diffs in the index, plus diffs in the working dir.
        staged = self.repo.git.diff(*(args + ["--cached"]))
        unstaged = self.repo.git.diff(*args)

        # Command output comes back without a trailing newline, so join the
        # parts explicitly to keep the second diff header on its own line.
        return "\n".join(part for part in (staged, unstaged) if part)

    def show_diffs(self, pretty):
        """Print the index + working-directory diffs to stdout."""
        diffs = self.get_diffs(pretty)
        print(diffs)

    def get_tracked_files(self):
        """Return the paths in the HEAD commit plus those in the index.

        Returns:
            A set of path strings using the OS path separator, or an empty
            list when no repository is attached.
        """
        if not self.repo:
            return []

        try:
            commit = self.repo.head.commit
        except ValueError:
            # No commit could be resolved for HEAD; only the index is listed.
            commit = None

        files = []
        if commit:
            for blob in commit.tree.traverse():
                if blob.type == "blob":  # blob is a file
                    files.append(blob.path)

        # Add staged files; index entries are keyed by (path, stage).
        index = self.repo.index
        files.extend(path for path, _ in index.entries)

        # Convert to appropriate os.sep, since git always normalizes to /
        return {str(Path(PurePosixPath(path))) for path in files}

    def is_dirty(self):
        """Return whether the underlying repo reports itself as dirty."""
        return self.repo.is_dirty()