wip: Added methods to get context from history and commit message.

This commit is contained in:
Paul Gauthier 2023-05-11 09:36:34 -07:00
parent adbdbbcb74
commit db2e062634

View file

@ -414,69 +414,22 @@ class Coder:
return edited
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
repo = self.repo
if not repo:
return
if not repo.is_dirty():
return
def get_dirty_files(file_list):
dirty_files = []
relative_dirty_files = []
for fname in file_list:
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
if self.pretty:
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
else:
these_diffs = repo.git.diff("HEAD", relative_fname)
if these_diffs:
dirty_files.append(fname)
relative_dirty_files.append(relative_fname)
return dirty_files, relative_dirty_files
if which == "repo_files":
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
dirty_fnames, relative_dirty_fnames = get_dirty_files(all_files)
elif which == "chat_files":
dirty_fnames, relative_dirty_fnames = get_dirty_files(self.abs_fnames)
else:
raise ValueError(f"Invalid value for 'which': {which}")
diffs = ""
for (abs_fname,relative_fname) in zip(dirty_fnames, relative_dirty_fnames):
if self.pretty:
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
else:
these_diffs = repo.git.diff("HEAD", relative_fname)
diffs += these_diffs + "\n"
if self.show_diffs or ask:
self.console.print(Text(diffs))
diffs = "# Diffs:\n" + diffs
# for fname in dirty_fnames:
# self.console.print(f"[red] {fname}")
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def get_commit_message(self, diffs, context):
diffs = "# Diffs:\n" + diffs
messages = [
dict(role="system", content=prompts.commit_system),
dict(role="user", content=context + diffs),
]
# if history:
# self.show_messages(messages, "commit")
commit_message, interrupted = self.send(
messages,
model="gpt-3.5-turbo",
@ -491,6 +444,46 @@ class Coder:
)
return
return commit_message
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
repo = self.repo
if not repo:
return
if not repo.is_dirty():
return
def get_dirty_files(file_list):
diffs = ""
relative_dirty_files = []
for fname in file_list:
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
if self.pretty:
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
else:
these_diffs = repo.git.diff("HEAD", relative_fname)
if these_diffs:
relative_dirty_files.append(relative_fname)
diffs += these_diffs + "\n"
return relative_dirty_files, diffs
if which == "repo_files":
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
relative_dirty_fnames,diffs = get_dirty_files(all_files)
elif which == "chat_files":
relative_dirty_fnames,diffs = get_dirty_files(self.abs_fnames)
else:
raise ValueError(f"Invalid value for 'which': {which}")
if self.show_diffs or ask:
self.console.print(Text(diffs))
context = self.get_context_from_history(history)
commit_message = self.get_commit_message(diffs, context)
if prefix:
commit_message = prefix + commit_message