diff --git a/coder.py b/coder.py index 14761787f..e88e5cb7d 100755 --- a/coder.py +++ b/coder.py @@ -13,6 +13,7 @@ from tqdm import tqdm from pathlib import Path from collections import defaultdict +from pygments import highlight, lexers, formatters import os import openai @@ -25,6 +26,8 @@ try: except FileNotFoundError: pass +formatter = formatters.TerminalFormatter() + openai.api_key = os.getenv("OPENAI_API_KEY") prompt_webdev = ''' @@ -260,6 +263,19 @@ MAKE ANY CHANGES BASED OFF THESE FILES! def show_send_output(self, completion): resp = [] + in_diff = False + diff_lines = [] + lexer = None + + def print_lines(): + if not diff_lines: + return + code = '\n'.join(diff_lines) + lexer = lexers.guess_lexer(code) + code = highlight(code, lexer, formatter) + print(code, end='') + + partial_line = '' for chunk in completion: try: text = chunk.choices[0].delta.content @@ -267,11 +283,45 @@ MAKE ANY CHANGES BASED OFF THESE FILES! except AttributeError: continue - sys.stdout.write(text) - sys.stdout.flush() + lines = (partial_line + text) + lines = lines.split('\n') + partial_line = lines.pop() - resp = ''.join(resp) - return resp + for line in lines: + check = line.rstrip() + if check == '>>>>>>> UPDATED': + print_lines() + in_diff = False + lexer = None + diff_lines = [] + + if check == '=======': + if len(diff_lines) >= 3: + print_lines() + diff_lines = [] + + print(line) + elif in_diff: + if lexer is None: + diff_lines.append(line) + if len(diff_lines) >= 3: + print_lines() + diff_lines = [] + else: + code = highlight(line, lexer, formatter) + print(code, end='') + else: + print(line) + + if line.strip() == '<<<<<<< ORIGINAL': + in_diff = True + lexer = None + diff_lines = [] + + if partial_line: + print(partial_line) + + return ''.join(resp) pattern = re.compile(r'^(\S+)\n<<<<<<< ORIGINAL\n(.+?)\n=======\n(.+?)\n>>>>>>> UPDATED$', re.MULTILINE | re.DOTALL)