This commit is contained in:
Paul Gauthier 2023-07-21 12:57:01 -03:00
parent 296e7614c4
commit 14b24dc2fd
3 changed files with 74 additions and 70 deletions

View file

@ -148,14 +148,21 @@ class Coder:
self.commands = Commands(self.io, self) self.commands = Commands(self.io, self)
for fname in fnames:
fname = Path(fname)
if not fname.exists():
self.io.tool_output(f"Creating empty file {fname}")
fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch()
self.abs_fnames.add(str(fname.resolve()))
if use_git: if use_git:
try: try:
self.repo = AiderRepo(fnames) self.repo = AiderRepo(self.io, fnames)
self.root = self.repo.root self.root = self.repo.root
except FileNotFoundError: except FileNotFoundError:
self.repo = None self.repo = None
else:
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
if self.repo: if self.repo:
rel_repo_dir = self.repo.get_rel_repo_dir() rel_repo_dir = self.repo.get_rel_repo_dir()
@ -189,6 +196,8 @@ class Coder:
for fname in self.get_inchat_relative_files(): for fname in self.get_inchat_relative_files():
self.io.tool_output(f"Added {fname} to the chat.") self.io.tool_output(f"Added {fname} to the chat.")
self.repo.add_new_files(fnames)
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:
for function in self.functions: for function in self.functions:
@ -351,12 +360,6 @@ class Coder:
if cmd in "add clear commit diff drop exit help ls tokens".split(): if cmd in "add clear commit diff drop exit help ls tokens".split():
return return
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
if self.last_asked_for_commit_time >= self.get_last_modified(): if self.last_asked_for_commit_time >= self.get_last_modified():
return return
return True return True
@ -395,6 +398,13 @@ class Coder:
return self.send_new_user_message(inp) return self.send_new_user_message(inp)
def dirty_commit(self): def dirty_commit(self):
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
self.io.tool_output("Git repo has uncommitted changes.") self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty) self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified() self.last_asked_for_commit_time = self.get_last_modified()
@ -410,7 +420,7 @@ class Coder:
else: else:
message = res.strip() message = res.strip()
self.commit(message=message) self.repo.commit(message=message)
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
@ -524,7 +534,7 @@ class Coder:
def auto_commit(self): def auto_commit(self):
context = self.get_context_from_history(self.cur_messages) context = self.get_context_from_history(self.cur_messages)
res = self.commit(context=context, prefix="aider: ") res = self.repo.commit(context=context, prefix="aider: ")
if res: if res:
commit_hash, commit_message = res commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash self.last_aider_commit_hash = commit_hash

View file

@ -440,8 +440,7 @@ def main(args=None, input=None, output=None):
io.tool_output(repo_map) io.tool_output(repo_map)
return return
if args.dirty_commits: coder.dirty_commit()
coder.commit(ask=True, which="repo_files")
if args.apply: if args.apply:
content = io.read_text(args.apply) content = io.read_text(args.apply)

View file

@ -7,24 +7,23 @@ import openai
from aider import models, prompts, utils from aider import models, prompts, utils
from aider.sendchat import send_with_retries from aider.sendchat import send_with_retries
from .dump import dump # noqa: F401
class AiderRepo: class AiderRepo:
repo = None repo = None
def __init__(self, io, cmd_line_fnames): def __init__(self, io, fnames):
self.io = io self.io = io
if not cmd_line_fnames: if fnames:
cmd_line_fnames = ["."] check_fnames = fnames
else:
check_fnames = ["."]
repo_paths = [] repo_paths = []
for fname in cmd_line_fnames: for fname in check_fnames:
fname = Path(fname) fname = Path(fname)
if not fname.exists():
self.io.tool_output(f"Creating empty file {fname}")
fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch()
fname = fname.resolve() fname = fname.resolve()
try: try:
@ -34,9 +33,6 @@ class AiderRepo:
except git.exc.InvalidGitRepositoryError: except git.exc.InvalidGitRepositoryError:
pass pass
if fname.is_dir():
continue
num_repos = len(set(repo_paths)) num_repos = len(set(repo_paths))
if num_repos == 0: if num_repos == 0:
@ -49,36 +45,13 @@ class AiderRepo:
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir) self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def ___(self, fnames): def add_new_files(self, fnames):
# TODO! cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()]
self.abs_fnames.add(str(fname))
new_files = []
for fname in fnames: for fname in fnames:
relative_fname = self.get_rel_fname(fname) if Path(fname).resolve() in cur_files:
continue
tracked_files = set(self.get_tracked_files()) self.io.tool_output(f"Adding {fname} to git")
if relative_fname not in tracked_files: self.repo.git.add(fname)
new_files.append(relative_fname)
if new_files:
rel_repo_dir = self.get_rel_repo_dir()
self.io.tool_output(f"Files not tracked in {rel_repo_dir}:")
for fn in new_files:
self.io.tool_output(f" - {fn}")
if self.io.confirm_ask("Add them?"):
for relative_fname in new_files:
self.repo.git.add(relative_fname)
self.io.tool_output(f"Added {relative_fname} to the git repo")
show_files = ", ".join(new_files)
commit_message = f"Added new files to the git repo: {show_files}"
self.repo.git.commit("-m", commit_message, "--no-verify")
commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
else:
self.io.tool_error("Skipped adding new files to the git repo.")
def commit(self, context=None, prefix=None, message=None): def commit(self, context=None, prefix=None, message=None):
if not self.repo.is_dirty(): if not self.repo.is_dirty():
@ -88,6 +61,7 @@ class AiderRepo:
commit_message = message commit_message = message
else: else:
diffs = self.get_diffs(False) diffs = self.get_diffs(False)
dump(diffs)
commit_message = self.get_commit_message(diffs, context) commit_message = self.get_commit_message(diffs, context)
if not commit_message: if not commit_message:
@ -96,10 +70,11 @@ class AiderRepo:
if prefix: if prefix:
commit_message = prefix + commit_message commit_message = prefix + commit_message
full_commit_message = commit_message
if context: if context:
commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
self.repo.git.commit("-a", "-m", commit_message, "--no-verify") self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
commit_hash = self.repo.head.commit.hexsha[:7] commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}") self.io.tool_output(f"Commit {commit_hash} {commit_message}")
@ -120,21 +95,34 @@ class AiderRepo:
diffs = "# Diffs:\n" + diffs diffs = "# Diffs:\n" + diffs
content = ""
if context:
content += context + "\n"
content += diffs
dump(content)
messages = [ messages = [
dict(role="system", content=prompts.commit_system), dict(role="system", content=prompts.commit_system),
dict(role="user", content=context + diffs), dict(role="user", content=content),
] ]
try: commit_message = None
_hash, response = send_with_retries( for model in [models.GPT35.name, models.GPT35_16k.name]:
model=models.GPT35.name, try:
messages=messages, _hash, response = send_with_retries(
functions=None, model=models.GPT35.name,
stream=False, messages=messages,
) functions=None,
commit_message = response.choices[0].message.content stream=False,
except (AttributeError, openai.error.InvalidRequestError): )
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}") commit_message = response.choices[0].message.content
break
except (AttributeError, openai.error.InvalidRequestError):
pass
if not commit_message:
self.io.tool_error("Failed to generate commit message!")
return return
commit_message = commit_message.strip() commit_message = commit_message.strip()
@ -146,6 +134,8 @@ class AiderRepo:
def get_diffs(self, pretty, *args): def get_diffs(self, pretty, *args):
if pretty: if pretty:
args = ["--color"] + list(args) args = ["--color"] + list(args)
if not args:
args = ["HEAD"]
diffs = self.repo.git.diff(*args) diffs = self.repo.git.diff(*args)
return diffs return diffs
@ -156,10 +146,12 @@ class AiderRepo:
except git.exc.GitCommandError: except git.exc.GitCommandError:
current_branch_has_commits = False current_branch_has_commits = False
if not current_branch_has_commits: dump(current_branch_has_commits)
return
diffs = self.get_diffs(pretty, "HEAD") if not current_branch_has_commits:
return ""
diffs = self.get_diffs(pretty)
print(diffs) print(diffs)
def get_tracked_files(self): def get_tracked_files(self):
@ -180,3 +172,6 @@ class AiderRepo:
res = set(str(Path(PurePosixPath(path))) for path in files) res = set(str(Path(PurePosixPath(path))) for path in files)
return res return res
def is_dirty(self):
return self.repo.is_dirty()