color wip

This commit is contained in:
Paul Gauthier 2023-04-08 14:52:03 -07:00
parent 837efaa20f
commit 92e2f8aa72

View file

@ -13,6 +13,7 @@ from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
from pygments import highlight, lexers, formatters
import os
import openai
@ -25,6 +26,8 @@ try:
except FileNotFoundError:
pass
formatter = formatters.TerminalFormatter()
openai.api_key = os.getenv("OPENAI_API_KEY")
prompt_webdev = '''
@ -260,6 +263,19 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
def show_send_output(self, completion):
resp = []
in_diff = False
diff_lines = []
lexer = None
def print_lines():
if not diff_lines:
return
code = '\n'.join(diff_lines)
lexer = lexers.guess_lexer(code)
code = highlight(code, lexer, formatter)
print(code, end='')
partial_line = ''
for chunk in completion:
try:
text = chunk.choices[0].delta.content
@ -267,11 +283,45 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
except AttributeError:
continue
sys.stdout.write(text)
sys.stdout.flush()
lines = (partial_line + text)
lines = lines.split('\n')
partial_line = lines.pop()
resp = ''.join(resp)
return resp
for line in lines:
check = line.rstrip()
if check == '>>>>>>> UPDATED':
print_lines()
in_diff = False
lexer = None
diff_lines = []
if check == '=======':
if len(diff_lines) >= 3:
print_lines()
diff_lines = []
print(line)
elif in_diff:
if lexer is None:
diff_lines.append(line)
if len(diff_lines) >= 3:
print_lines()
diff_lines = []
else:
code = highlight(line, lexer, formatter)
print(code, end='')
else:
print(line)
if line.strip() == '<<<<<<< ORIGINAL':
in_diff = True
lexer = None
diff_lines = []
if partial_line:
print(partial_line)
return ''.join(resp)
pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.+?)\n=======\n(.+?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL)