This commit is contained in:
Paul Gauthier 2023-07-21 12:06:32 -03:00
parent 23beb7cb5d
commit 296e7614c4
3 changed files with 47 additions and 82 deletions

View file

@ -379,12 +379,7 @@ class Coder:
) )
if self.should_dirty_commit(inp): if self.should_dirty_commit(inp):
self.io.tool_output("Git repo has uncommitted changes, preparing commit...") self.dirty_commit()
self.commit(ask=True, which="repo_files", pretty=self.pretty)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
if inp.strip(): if inp.strip():
self.io.tool_output("Use up-arrow to retry previous command:", inp) self.io.tool_output("Use up-arrow to retry previous command:", inp)
return return
@ -399,6 +394,27 @@ class Coder:
return self.send_new_user_message(inp) return self.send_new_user_message(inp)
def dirty_commit(self):
self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified()
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default="y",
).strip()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() in ["y", "yes"]:
message = None
else:
message = res.strip()
self.commit(message=message)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
def fmt_system_reminder(self): def fmt_system_reminder(self):
prompt = self.gpt_prompts.system_reminder prompt = self.gpt_prompts.system_reminder
prompt = prompt.format(fence=self.fence) prompt = prompt.format(fence=self.fence)
@ -508,7 +524,7 @@ class Coder:
def auto_commit(self): def auto_commit(self):
context = self.get_context_from_history(self.cur_messages) context = self.get_context_from_history(self.cur_messages)
res = self.commit(context=context, prefix="aider: ", pretty=self.pretty) res = self.commit(context=context, prefix="aider: ")
if res: if res:
commit_hash, commit_message = res commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash self.last_aider_commit_hash = commit_hash

View file

@ -216,7 +216,9 @@ class Commands:
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.get_diffs( diff = self.coder.get_diffs(
commits, self.coder.last_aider_commit_hash, pretty=self.coder.pretty self.coder.pretty,
commits,
self.coder.last_aider_commit_hash,
) )
# don't use io.tool_output() because we don't want to log or further colorize # don't use io.tool_output() because we don't want to log or further colorize

View file

@ -50,7 +50,6 @@ class AiderRepo:
self.root = utils.safe_abs_path(self.repo.working_tree_dir) self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def ___(self, fnames): def ___(self, fnames):
# TODO! # TODO!
self.abs_fnames.add(str(fname)) self.abs_fnames.add(str(fname))
@ -81,91 +80,27 @@ class AiderRepo:
else: else:
self.io.tool_error("Skipped adding new files to the git repo.") self.io.tool_error("Skipped adding new files to the git repo.")
def commit( def commit(self, context=None, prefix=None, message=None):
self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False if not self.repo.is_dirty():
):
## TODO!
repo = self.repo
if not repo:
return return
if not repo.is_dirty():
return
def get_dirty_files_and_diffs(file_list):
diffs = ""
relative_dirty_files = []
for fname in file_list:
relative_fname = self.get_rel_fname(fname)
relative_dirty_files.append(relative_fname)
try:
current_branch_commit_count = len(
list(self.repo.iter_commits(self.repo.active_branch))
)
except git.exc.GitCommandError:
current_branch_commit_count = None
if not current_branch_commit_count:
continue
these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname)
if these_diffs:
diffs += these_diffs + "\n"
return relative_dirty_files, diffs
if which == "repo_files":
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(all_files)
elif which == "chat_files":
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(self.abs_fnames)
else:
raise ValueError(f"Invalid value for 'which': {which}")
if self.show_diffs or ask:
# don't use io.tool_output() because we don't want to log or further colorize
print(diffs)
if message: if message:
commit_message = message commit_message = message
else: else:
diffs = self.get_diffs(False)
commit_message = self.get_commit_message(diffs, context) commit_message = self.get_commit_message(diffs, context)
if not commit_message: if not commit_message:
commit_message = "work in progress" commit_message = "(no commit message provided)"
if prefix: if prefix:
commit_message = prefix + commit_message commit_message = prefix + commit_message
if ask: if context:
if which == "repo_files": commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context
self.io.tool_output("Git repo has uncommitted changes.")
else:
self.io.tool_output("Files have uncommitted changes.")
res = self.io.prompt_ask( self.repo.git.commit("-a", "-m", commit_message, "--no-verify")
"Commit before the chat proceeds [y/n/commit message]?", commit_hash = self.repo.head.commit.hexsha[:7]
default=commit_message,
).strip()
self.last_asked_for_commit_time = self.get_last_modified()
self.io.tool_output()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() not in ["y", "yes"] and res:
commit_message = res
repo.git.add(*relative_dirty_fnames)
full_commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context
repo.git.commit("-m", full_commit_message, "--no-verify")
commit_hash = repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}") self.io.tool_output(f"Commit {commit_hash} {commit_message}")
return commit_hash, commit_message return commit_hash, commit_message
@ -197,7 +132,7 @@ class AiderRepo:
functions=None, functions=None,
stream=False, stream=False,
) )
commit_message = completion.choices[0].message.content commit_message = response.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError): except (AttributeError, openai.error.InvalidRequestError):
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}") self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}")
return return
@ -215,6 +150,18 @@ class AiderRepo:
diffs = self.repo.git.diff(*args) diffs = self.repo.git.diff(*args)
return diffs return diffs
def show_diffs(self, pretty):
try:
current_branch_has_commits = any(self.repo.iter_commits(self.repo.active_branch))
except git.exc.GitCommandError:
current_branch_has_commits = False
if not current_branch_has_commits:
return
diffs = self.get_diffs(pretty, "HEAD")
print(diffs)
def get_tracked_files(self): def get_tracked_files(self):
if not self.repo: if not self.repo:
return [] return []