This commit is contained in:
Paul Gauthier 2023-05-07 20:24:59 -07:00
parent eb79dd2760
commit d61ab51a74

129
coder.py
View file

@ -1,7 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
# This is a Python script that uses OpenAI's GPT-3 to modify code based on user requests.
import os import os
import sys import sys
import copy import copy
@ -25,7 +23,7 @@ from dump import dump
import prompts import prompts
history_file = '.coder.history' history_file = ".coder.history"
try: try:
readline.read_history_file(history_file) readline.read_history_file(history_file)
except FileNotFoundError: except FileNotFoundError:
@ -42,15 +40,16 @@ def find_index(list1, list2):
return i return i
return -1 return -1
class Coder: class Coder:
fnames = dict() fnames = dict()
last_modified = 0 last_modified = 0
def __init__(self, use_gpt_4): def __init__(self, use_gpt_4):
if use_gpt_4: if use_gpt_4:
self.main_model = 'gpt-4' self.main_model = "gpt-4"
else: else:
self.main_model = 'gpt-3.5-turbo' self.main_model = "gpt-3.5-turbo"
def add_file(self, fname): def add_file(self, fname):
self.fnames[fname] = Path(fname).stat().st_mtime self.fnames[fname] = Path(fname).stat().st_mtime
@ -64,28 +63,27 @@ class Coder:
self.request_prompt = prompt self.request_prompt = prompt
def quoted_file(self, fname): def quoted_file(self, fname):
prompt = '\n' prompt = "\n"
prompt += fname prompt += fname
prompt += '\n```\n' prompt += "\n```\n"
prompt += Path(fname).read_text() prompt += Path(fname).read_text()
prompt += '\n```\n' prompt += "\n```\n"
return prompt return prompt
def get_files_content(self): def get_files_content(self):
prompt = '' prompt = ""
for fname in self.fnames: for fname in self.fnames:
prompt += self.quoted_file(fname) prompt += self.quoted_file(fname)
return prompt return prompt
def get_input(self): def get_input(self):
print() print()
print('='*60) print("=" * 60)
inp = '' inp = ""
num_control_c = 0 num_control_c = 0
while not inp.strip(): while not inp.strip():
try: try:
inp = input('> ') inp = input("> ")
except EOFError: except EOFError:
return return
except KeyboardInterrupt: except KeyboardInterrupt:
@ -93,7 +91,7 @@ class Coder:
print() print()
if num_control_c >= 2: if num_control_c >= 2:
return return
print('^C again to quit') print("^C again to quit")
print() print()
@ -101,14 +99,13 @@ class Coder:
return inp return inp
def set_files_messages(self, did_edits=False): def set_files_messages(self, did_edits=False):
last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames) last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames)
if last_modified <= self.last_modified: if last_modified <= self.last_modified:
return return
did_edits = (self.last_modified > 0) did_edits = self.last_modified > 0
self.last_modified = last_modified self.last_modified = last_modified
print('Reloading files...') print("Reloading files...")
if did_edits: if did_edits:
files_content = prompts.files_content_prefix_edited files_content = prompts.files_content_prefix_edited
@ -119,8 +116,8 @@ class Coder:
files_content += prompts.files_content_suffix files_content += prompts.files_content_suffix
self.files_messages = [ self.files_messages = [
dict(role = 'user', content = files_content), dict(role="user", content=files_content),
dict(role = 'assistant', content = "Ok."), dict(role="assistant", content="Ok."),
] ]
return True return True
@ -141,7 +138,7 @@ class Coder:
self.cur_messages = [] self.cur_messages = []
self.cur_messages += [ self.cur_messages += [
dict(role = 'user', content = inp), dict(role="user", content=inp),
] ]
# self.show_messages(self.done_messages, "done") # self.show_messages(self.done_messages, "done")
@ -149,18 +146,18 @@ class Coder:
# self.show_messages(self.cur_messages, "cur") # self.show_messages(self.cur_messages, "cur")
messages = [ messages = [
dict(role = 'system', content = prompts.main_system), dict(role="system", content=prompts.main_system),
] ]
messages += self.done_messages messages += self.done_messages
messages += self.files_messages messages += self.files_messages
messages += self.cur_messages messages += self.cur_messages
self.show_messages(messages, 'all') self.show_messages(messages, "all")
content = self.send(messages) content = self.send(messages)
self.cur_messages += [ self.cur_messages += [
dict(role = 'assistant', content = content), dict(role="assistant", content=content),
] ]
print() print()
@ -181,13 +178,13 @@ class Coder:
self.cur_messages = [] self.cur_messages = []
def show_messages(self, messages, title): def show_messages(self, messages, title):
print(title.upper(), '*' * 50) print(title.upper(), "*" * 50)
for msg in messages: for msg in messages:
print() print()
print('-' * 50) print("-" * 50)
role = msg['role'].upper() role = msg["role"].upper()
content = msg['content'].splitlines() content = msg["content"].splitlines()
for line in content: for line in content:
print(role, line) print(role, line)
@ -224,18 +221,18 @@ class Coder:
pbar.update(show_progress) pbar.update(show_progress)
pbar.close() pbar.close()
resp = ''.join(resp) resp = "".join(resp)
return resp return resp
def show_send_output_plain(self, completion): def show_send_output_plain(self, completion):
resp = '' resp = ""
in_diff = False in_diff = False
diff_lines = [] diff_lines = []
partial_line = '' partial_line = ""
for chunk in completion: for chunk in completion:
if chunk.choices[0].finish_reason not in (None, 'stop'): if chunk.choices[0].finish_reason not in (None, "stop"):
dump(chunk.choices[0].finish_reason) dump(chunk.choices[0].finish_reason)
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
@ -247,7 +244,7 @@ class Coder:
sys.stdout.flush() sys.stdout.flush()
# disabled # disabled
if False and '```' in resp: if False and "```" in resp:
return resp return resp
return resp return resp
@ -261,12 +258,12 @@ class Coder:
def print_lines(): def print_lines():
if not diff_lines: if not diff_lines:
return return
code = '\n'.join(diff_lines) code = "\n".join(diff_lines)
lexer = lexers.guess_lexer(code) lexer = lexers.guess_lexer(code)
code = highlight(code, lexer, formatter) code = highlight(code, lexer, formatter)
print(code, end='') print(code, end="")
partial_line = '' partial_line = ""
for chunk in completion: for chunk in completion:
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
@ -274,18 +271,18 @@ class Coder:
except AttributeError: except AttributeError:
continue continue
lines = (partial_line + text) lines = partial_line + text
lines = lines.split('\n') lines = lines.split("\n")
partial_line = lines.pop() partial_line = lines.pop()
for line in lines: for line in lines:
check = line.rstrip() check = line.rstrip()
if check == '>>>>>>> UPDATED': if check == ">>>>>>> UPDATED":
print_lines() print_lines()
in_diff = False in_diff = False
diff_lines = [] diff_lines = []
if check == '=======': if check == "=======":
print_lines() print_lines()
diff_lines = [] diff_lines = []
print(line) print(line)
@ -294,7 +291,7 @@ class Coder:
else: else:
print(line) print(line)
if line.strip() == '<<<<<<< ORIGINAL': if line.strip() == "<<<<<<< ORIGINAL":
in_diff = True in_diff = True
diff_lines = [] diff_lines = []
@ -302,13 +299,14 @@ class Coder:
if partial_line: if partial_line:
print(partial_line) print(partial_line)
return ''.join(resp) return "".join(resp)
pattern = re.compile(
pattern = re.compile(r'(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED', re.MULTILINE | re.DOTALL) r"(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
re.MULTILINE | re.DOTALL,
)
def update_files(self, content, inp): def update_files(self, content, inp):
edited = set() edited = set()
for match in self.pattern.finditer(content): for match in self.pattern.finditer(content):
path, _, original, updated = match.groups() path, _, original, updated = match.groups()
@ -329,7 +327,7 @@ class Coder:
# does it want to make a new file? # does it want to make a new file?
if not fname.exists() and not before_text: if not fname.exists() and not before_text:
print('Creating empty file:', fname) print("Creating empty file:", fname)
fname.touch() fname.touch()
content = fname.read_text().splitlines() content = fname.read_text().splitlines()
@ -348,15 +346,15 @@ class Coder:
new_content = content[:where] new_content = content[:where]
new_content += after_text.splitlines() new_content += after_text.splitlines()
new_content += content[where + len(before_lines) :] new_content += content[where + len(before_lines) :]
new_content = '\n'.join(new_content) + '\n' new_content = "\n".join(new_content) + "\n"
fname.write_text(new_content) fname.write_text(new_content)
print('Applied edit to', fname) print("Applied edit to", fname)
return True return True
def do_gpt_powered_replace(self, fname, edit, request): def do_gpt_powered_replace(self, fname, edit, request):
model = 'gpt-3.5-turbo' model = "gpt-3.5-turbo"
print(f'Asking {model} to apply ambiguous edit to {fname}...') print(f"Asking {model} to apply ambiguous edit to {fname}...")
fname = Path(fname) fname = Path(fname)
content = fname.read_text() content = fname.read_text()
@ -368,10 +366,12 @@ class Coder:
) )
messages = [ messages = [
dict(role = 'system', content = prompts.editor_system), dict(role="system", content=prompts.editor_system),
dict(role = 'user', content = prompt), dict(role="user", content=prompt),
] ]
res = self.send(messages, show_progress = len(content) + len(edit)/2, model=model) res = self.send(
messages, show_progress=len(content) + len(edit) / 2, model=model
)
res = self.strip_quoted_wrapping(res, fname) res = self.strip_quoted_wrapping(res, fname)
fname.write_text(res) fname.write_text(res)
@ -384,21 +384,27 @@ class Coder:
if fname and res[0].strip().endswith(Path(fname).name): if fname and res[0].strip().endswith(Path(fname).name):
res = res[1:] res = res[1:]
if res[0].startswith('```') and res[-1].startswith('```'): if res[0].startswith("```") and res[-1].startswith("```"):
res = res[1:-1] res = res[1:-1]
res = '\n'.join(res) res = "\n".join(res)
if res and res[-1] != '\n': if res and res[-1] != "\n":
res += '\n' res += "\n"
return res return res
def main(): def main():
parser = argparse.ArgumentParser(description="Chat with GPT about code")
parser = argparse.ArgumentParser(description='Chat with GPT about code') parser.add_argument(
parser.add_argument('files', metavar='FILE', nargs='+', help='a list of source code files') "files", metavar="FILE", nargs="+", help="a list of source code files"
parser.add_argument('-3', '--gpt-3-5-turbo', action='store_true', help='Only use gpt-3.5-turbo, not gpt-4') )
parser.add_argument(
"-3",
"--gpt-3-5-turbo",
action="store_true",
help="Only use gpt-3.5-turbo, not gpt-4",
)
args = parser.parse_args() args = parser.parse_args()
@ -412,6 +418,7 @@ def main():
coder.run() coder.run()
if __name__ == '__main__':
if __name__ == "__main__":
status = main() status = main()
sys.exit(status) sys.exit(status)