color wip

This commit is contained in:
Paul Gauthier 2023-04-08 14:52:03 -07:00
parent 837efaa20f
commit 92e2f8aa72

View file

@ -13,6 +13,7 @@ from tqdm import tqdm
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
from pygments import highlight, lexers, formatters
import os import os
import openai import openai
@ -25,6 +26,8 @@ try:
except FileNotFoundError: except FileNotFoundError:
pass pass
formatter = formatters.TerminalFormatter()
openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_key = os.getenv("OPENAI_API_KEY")
prompt_webdev = ''' prompt_webdev = '''
@ -260,6 +263,19 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
def show_send_output(self, completion): def show_send_output(self, completion):
resp = [] resp = []
in_diff = False
diff_lines = []
lexer = None
def print_lines():
if not diff_lines:
return
code = '\n'.join(diff_lines)
lexer = lexers.guess_lexer(code)
code = highlight(code, lexer, formatter)
print(code, end='')
partial_line = ''
for chunk in completion: for chunk in completion:
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
@ -267,11 +283,45 @@ MAKE ANY CHANGES BASED OFF THESE FILES!
except AttributeError: except AttributeError:
continue continue
sys.stdout.write(text) lines = (partial_line + text)
sys.stdout.flush() lines = lines.split('\n')
partial_line = lines.pop()
resp = ''.join(resp) for line in lines:
return resp check = line.rstrip()
if check == '>>>>>>> UPDATED':
print_lines()
in_diff = False
lexer = None
diff_lines = []
if check == '=======':
if len(diff_lines) >= 3:
print_lines()
diff_lines = []
print(line)
elif in_diff:
if lexer is None:
diff_lines.append(line)
if len(diff_lines) >= 3:
print_lines()
diff_lines = []
else:
code = highlight(line, lexer, formatter)
print(code, end='')
else:
print(line)
if line.strip() == '<<<<<<< ORIGINAL':
in_diff = True
lexer = None
diff_lines = []
if partial_line:
print(partial_line)
return ''.join(resp)
pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.+?)\n=======\n(.+?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL) pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.+?)\n=======\n(.+?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL)