Refactor commit method to accept history parameter.

This commit is contained in:
Paul Gauthier 2023-05-08 17:21:36 -07:00
parent 8b920f778e
commit e5cbc5cc4c

View file

@ -32,7 +32,6 @@ except FileNotFoundError:
openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_key = os.getenv("OPENAI_API_KEY")
def find_index(list1, list2): def find_index(list1, list2):
for i in range(len(list1)): for i in range(len(list1)):
if list1[i : i + len(list2)] == list2: if list1[i : i + len(list2)] == list2:
@ -154,6 +153,8 @@ class Coder:
self.num_control_c = 0 self.num_control_c = 0
if self.check_for_local_edits(): if self.check_for_local_edits():
self.commit(ask=True)
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
self.done_messages += self.cur_messages self.done_messages += self.cur_messages
self.done_messages += [ self.done_messages += [
@ -206,7 +207,7 @@ class Coder:
dict(role="user", content=prompts.files_content_gpt_edits), dict(role="user", content=prompts.files_content_gpt_edits),
dict(role="assistant", content="Ok."), dict(role="assistant", content="Ok."),
] ]
self.commit(self.cur_messages) self.commit(history=self.cur_messages)
self.cur_messages = [] self.cur_messages = []
return True return True
@ -406,7 +407,7 @@ class Coder:
return res return res
def commit(self, message_history, prefix=None, ask=False): def commit(self, history=None, prefix=None, ask=False):
repo_paths = set( repo_paths = set(
git.Repo(fname, search_parent_directories=True).git_dir git.Repo(fname, search_parent_directories=True).git_dir
for fname in self.fnames for fname in self.fnames
@ -442,9 +443,9 @@ class Coder:
# self.console.print(f"[red] {fname}") # self.console.print(f"[red] {fname}")
context = "" context = ""
if message_history: if history:
context += "# Context:\n" context += "# Context:\n"
for msg in message_history: for msg in history:
context += msg["role"].upper() + ": " + msg["content"] + "\n" context += msg["role"].upper() + ": " + msg["content"] + "\n"
messages = [ messages = [
@ -452,7 +453,8 @@ class Coder:
dict(role="user", content=context + diffs), dict(role="user", content=context + diffs),
] ]
# self.show_messages(messages, "commit") if history:
self.show_messages(messages, "commit")
commit_message, interrupted = self.send( commit_message, interrupted = self.send(
messages, messages,
@ -523,7 +525,7 @@ def main():
pretty = args.pretty pretty = args.pretty
coder = Coder(use_gpt_4, fnames, pretty) coder = Coder(use_gpt_4, fnames, pretty)
coder.commit("", ask=not args.commit) coder.commit(ask=not args.commit, prefix="WIP: ")
if args.apply: if args.apply:
with open(args.apply, "r") as f: with open(args.apply, "r") as f: