This commit is contained in:
Paul Gauthier 2023-04-08 19:30:30 -07:00
parent 1c9025716d
commit 5e1769b040

View file

@ -9,6 +9,8 @@ import random
import json import json
import re import re
import readline import readline
import traceback
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
@ -42,16 +44,28 @@ FOR EACH CHANGE TO THE CODE, DESCRIBE IT USING THIS FORMAT:
path/to/filename.ext path/to/filename.ext
<<<<<<< ORIGINAL <<<<<<< ORIGINAL
a chunk of the **exact** lines original lines
from the current file that needs to be changed to search for
MUST BE THE EXACT LINES FROM THE CURRENT FILE
======= =======
new lines to replace new lines to replace
the original chunk the original chunk
>>>>>>> UPDATED >>>>>>> UPDATED
ONLY USE THIS ORIGINAL/UPDATED FORMAT TO DESCRIBE CODE CHANGES! ONLY USE THIS ORIGINAL/UPDATED FORMAT TO DESCRIBE CODE CHANGES!
DO NOT USE ``` DELIMITERS!
Example:
foo.py
<<<<<<< ORIGINAL
print(1+1)
=======
print(2+2)
>>>>>>> UPDATED
To add new code, anchor it by including 2-3 lines in the ORIGINAL and UPDATED portions of the diff.
Don't just output the ENTIRE file. Turn it into an edit.
''' '''
prompt_comments = ''' prompt_comments = '''
@ -229,12 +243,13 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
print() print()
print() print()
try: try:
did_edits = self.update_files(content) did_edits = self.update_files(content, inp)
if did_edits: if did_edits:
print() print()
except Exception as err: except Exception as err:
print(err) print(err)
print() print()
traceback.print_exc()
def send(self, messages, show_progress = 0): def send(self, messages, show_progress = 0):
for msg in messages: for msg in messages:
@ -296,6 +311,8 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
return ''.join(resp)
def show_send_output_color(self, completion): def show_send_output_color(self, completion):
resp = [] resp = []
@ -348,9 +365,9 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
return ''.join(resp) return ''.join(resp)
pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.+?)\n=======\n(.+?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL) pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.*?)\n=======\n(.*?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL)
def update_files(self, content): def update_files(self, content, inp):
did_edits = False did_edits = False
for match in self.pattern.finditer(content): for match in self.pattern.finditer(content):
@ -359,13 +376,13 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
if self.do_replace(path, original, updated): if self.do_replace(path, original, updated):
continue continue
edit = match.group() edit = match.group()
self.do_gpt_powered_replace(path, edit) self.do_gpt_powered_replace(path, edit, inp)
return did_edits return did_edits
def do_replace(self, fname, before_text, after_text): def do_replace(self, fname, before_text, after_text):
before_text = self.strip_quoted_wrapping(before_text, fname) before_text = self.strip_quoted_wrapping(before_text, fname)
dump(repr(before_text)) after_text = self.strip_quoted_wrapping(after_text, fname)
fname = Path(fname) fname = Path(fname)
content = fname.read_text().splitlines() content = fname.read_text().splitlines()
@ -385,13 +402,17 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
print('Applied edit to', fname) print('Applied edit to', fname)
return True return True
def do_gpt_powered_replace(self, fname, edit): def do_gpt_powered_replace(self, fname, edit, request):
print(f'Asking GPT to apply ambiguous edit to {fname}...') print(f'Asking GPT to apply ambiguous edit to {fname}...')
print(repr(edit)) print(repr(edit))
fname = Path(fname) fname = Path(fname)
content = fname.read_text() content = fname.read_text()
prompt = f''' prompt = f'''
Apply this change: To complete this request:
{request}
You need to apply this change:
{edit} {edit}
@ -422,6 +443,9 @@ Just the content of the file!
fname.write_text(res) fname.write_text(res)
def strip_quoted_wrapping(self, res, fname=None): def strip_quoted_wrapping(self, res, fname=None):
if not res:
return res
res = res.splitlines() res = res.splitlines()
if fname and res[0].strip().endswith(Path(fname).name): if fname and res[0].strip().endswith(Path(fname).name):