commit wip

This commit is contained in:
Paul Gauthier 2023-05-08 11:41:51 -07:00
parent b0d2fafdd5
commit 0f040949c4
2 changed files with 67 additions and 12 deletions

View file

@ -17,6 +17,7 @@ from tqdm import tqdm
from pathlib import Path
import os
import pygit2
import openai
from dump import dump
@ -233,12 +234,12 @@ class Coder:
interrupted = False
try:
if show_progress:
if show_progress is not None:
self.show_send_progress(completion, show_progress)
elif self.pretty:
elif self.pretty and show_progress:
self.show_send_output_color(completion)
else:
self.show_send_output_plain(completion)
self.show_send_output_plain(completion, False)
except KeyboardInterrupt:
interrupted = True
@ -259,7 +260,7 @@ class Coder:
pbar.update(show_progress)
pbar.close()
def show_send_output_plain(self, completion):
def show_send_output_plain(self, completion, show_output=True):
self.resp = ""
for chunk in completion:
@ -271,8 +272,9 @@ class Coder:
except AttributeError:
continue
sys.stdout.write(text)
sys.stdout.flush()
if show_output:
sys.stdout.write(text)
sys.stdout.flush()
def show_send_output_color(self, completion):
self.resp = ""
@ -402,11 +404,64 @@ class Coder:
return res
def commit(self, message_history):
# _messages = [
# dict(role="system", content=prompts.commit_system),
# ]
pass
def commit(self, message_history, prefix=None):
repo_paths = set(pygit2.discover_repository(fname) for fname in self.fnames)
if len(repo_paths) > 1:
repo_paths = " ".join(repo_paths)
raise ValueError(f"Files must all be in one git repo, not: {repo_paths}")
repo = pygit2.Repository(repo_paths.pop())
"""
index = repo.index
index.read()
for patch in index.diff_to_workdir():
print(patch.text)
"""
patches = repo.diff("HEAD")
if not len(patches):
return
context = ""
if message_history:
context += "# Context:\n"
for msg in message_history:
context += msg["role"].upper() + ": " + msg["content"] + "\n"
diffs = "# Diffs:\n"
for diff in patches:
dump(dir(diff))
dump(diff.delta.new_file.path)
diffs += diff.text
diffs += "\n\n"
if not diffs:
return
messages = [
dict(role="system", content=prompts.commit_system),
dict(role="user", content=context + diffs),
]
self.show_messages(messages, "commit")
commit_message, interrupted = self.send(
messages,
model="gpt-3.5-turbo",
show_progress=None,
)
commit_message = commit_message.strip()
if interrupted:
raise KeyboardInterrupt
if prefix:
commit_message = prefix + commit_message
print(commit_message)
def main():