diff --git a/coder.py b/coder.py index b781448ae..e00eb02c3 100755 --- a/coder.py +++ b/coder.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# This is a Python script that uses OpenAI's GPT-3 to modify code based on user requests. - import os import sys import copy @@ -25,7 +23,7 @@ from dump import dump import prompts -history_file = '.coder.history' +history_file = ".coder.history" try: readline.read_history_file(history_file) except FileNotFoundError: @@ -38,25 +36,26 @@ openai.api_key = os.getenv("OPENAI_API_KEY") def find_index(list1, list2): for i in range(len(list1)): - if list1[i:i+len(list2)] == list2: + if list1[i : i + len(list2)] == list2: return i return -1 + class Coder: fnames = dict() last_modified = 0 def __init__(self, use_gpt_4): if use_gpt_4: - self.main_model = 'gpt-4' + self.main_model = "gpt-4" else: - self.main_model = 'gpt-3.5-turbo' + self.main_model = "gpt-3.5-turbo" def add_file(self, fname): self.fnames[fname] = Path(fname).stat().st_mtime def files_modified(self): - for fname,mtime in self.fnames.items(): + for fname, mtime in self.fnames.items(): if Path(fname).stat().st_mtime != mtime: return True @@ -64,28 +63,27 @@ class Coder: self.request_prompt = prompt def quoted_file(self, fname): - prompt = '\n' + prompt = "\n" prompt += fname - prompt += '\n```\n' + prompt += "\n```\n" prompt += Path(fname).read_text() - prompt += '\n```\n' + prompt += "\n```\n" return prompt def get_files_content(self): - prompt = '' + prompt = "" for fname in self.fnames: prompt += self.quoted_file(fname) return prompt def get_input(self): - print() - print('='*60) - inp = '' + print("=" * 60) + inp = "" num_control_c = 0 while not inp.strip(): try: - inp = input('> ') + inp = input("> ") except EOFError: return except KeyboardInterrupt: @@ -93,22 +91,21 @@ class Coder: print() if num_control_c >= 2: return - print('^C again to quit') + print("^C again to quit") print() readline.write_history_file(history_file) return inp - def set_files_messages(self, did_edits = False): - + def set_files_messages(self, did_edits=False): last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames) if last_modified <= self.last_modified: return - did_edits = (self.last_modified > 0) + did_edits = self.last_modified > 0 self.last_modified = last_modified - print('Reloading files...') + print("Reloading files...") if did_edits: files_content = prompts.files_content_prefix_edited @@ -119,8 +116,8 @@ class Coder: files_content += prompts.files_content_suffix self.files_messages = [ - dict(role = 'user', content = files_content), - dict(role = 'assistant', content = "Ok."), + dict(role="user", content=files_content), + dict(role="assistant", content="Ok."), ] return True @@ -141,26 +138,26 @@ class Coder: self.cur_messages = [] self.cur_messages += [ - dict(role = 'user', content = inp), + dict(role="user", content=inp), ] - #self.show_messages(self.done_messages, "done") - #self.show_messages(self.files_messages, "files") - #self.show_messages(self.cur_messages, "cur") + # self.show_messages(self.done_messages, "done") + # self.show_messages(self.files_messages, "files") + # self.show_messages(self.cur_messages, "cur") messages = [ - dict(role = 'system', content = prompts.main_system), + dict(role="system", content=prompts.main_system), ] messages += self.done_messages messages += self.files_messages messages += self.cur_messages - self.show_messages(messages, 'all') + self.show_messages(messages, "all") content = self.send(messages) self.cur_messages += [ - dict(role = 'assistant', content = content), + dict(role="assistant", content=content), ] print() @@ -181,18 +178,18 @@ class Coder: self.cur_messages = [] def show_messages(self, messages, title): - print(title.upper(), '*' * 50) + print(title.upper(), "*" * 50) for msg in messages: print() - print('-' * 50) - role = msg['role'].upper() - content = msg['content'].splitlines() + print("-" * 50) + role = msg["role"].upper() + content = msg["content"].splitlines() for line in content: print(role, line) - def send(self, messages, model=None, show_progress = 0): - #self.show_messages(messages, "all") + def send(self, messages, model=None, show_progress=0): + # self.show_messages(messages, "all") if not model: model = self.main_model @@ -201,7 +198,7 @@ class Coder: model=model, messages=messages, temperature=0, - stream = True, + stream=True, ) if show_progress: @@ -211,7 +208,7 @@ class Coder: def show_send_progress(self, completion, show_progress): resp = [] - pbar = tqdm(total = show_progress) + pbar = tqdm(total=show_progress) for chunk in completion: try: text = chunk.choices[0].delta.content @@ -224,18 +221,18 @@ class Coder: pbar.update(show_progress) pbar.close() - resp = ''.join(resp) + resp = "".join(resp) return resp def show_send_output_plain(self, completion): - resp = '' + resp = "" in_diff = False diff_lines = [] - partial_line = '' + partial_line = "" for chunk in completion: - if chunk.choices[0].finish_reason not in (None, 'stop'): + if chunk.choices[0].finish_reason not in (None, "stop"): dump(chunk.choices[0].finish_reason) try: text = chunk.choices[0].delta.content @@ -247,7 +244,7 @@ class Coder: sys.stdout.flush() # disabled - if False and '```' in resp: + if False and "```" in resp: return resp return resp @@ -261,12 +258,12 @@ class Coder: def print_lines(): if not diff_lines: return - code = '\n'.join(diff_lines) + code = "\n".join(diff_lines) lexer = lexers.guess_lexer(code) code = highlight(code, lexer, formatter) - print(code, end='') + print(code, end="") - partial_line = '' + partial_line = "" for chunk in completion: try: text = chunk.choices[0].delta.content @@ -274,18 +271,18 @@ class Coder: except AttributeError: continue - lines = (partial_line + text) - lines = lines.split('\n') + lines = partial_line + text + lines = lines.split("\n") partial_line = lines.pop() for line in lines: check = line.rstrip() - if check == '>>>>>>> UPDATED': + if check == ">>>>>>> UPDATED": print_lines() in_diff = False diff_lines = [] - if check == '=======': + if check == "=======": print_lines() diff_lines = [] print(line) @@ -294,7 +291,7 @@ class Coder: else: print(line) - if line.strip() == '<<<<<<< ORIGINAL': + if line.strip() == "<<<<<<< ORIGINAL": in_diff = True diff_lines = [] @@ -302,13 +299,14 @@ class Coder: if partial_line: print(partial_line) - return ''.join(resp) + return "".join(resp) - - pattern = re.compile(r'(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED', re.MULTILINE | re.DOTALL) + pattern = re.compile( + r"(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED", + re.MULTILINE | re.DOTALL, + ) def update_files(self, content, inp): - edited = set() for match in self.pattern.finditer(content): path, _, original, updated = match.groups() @@ -329,7 +327,7 @@ class Coder: # does it want to make a new file? if not fname.exists() and not before_text: - print('Creating empty file:', fname) + print("Creating empty file:", fname) fname.touch() content = fname.read_text().splitlines() @@ -347,31 +345,33 @@ class Coder: new_content = content[:where] new_content += after_text.splitlines() - new_content += content[where+len(before_lines):] - new_content = '\n'.join(new_content) + '\n' + new_content += content[where + len(before_lines) :] + new_content = "\n".join(new_content) + "\n" fname.write_text(new_content) - print('Applied edit to', fname) + print("Applied edit to", fname) return True def do_gpt_powered_replace(self, fname, edit, request): - model = 'gpt-3.5-turbo' - print(f'Asking {model} to apply ambiguous edit to {fname}...') + model = "gpt-3.5-turbo" + print(f"Asking {model} to apply ambiguous edit to {fname}...") fname = Path(fname) content = fname.read_text() prompt = prompts.editor_user.format( - request = request, - edit = edit, - fname = fname, - content = content, - ) + request=request, + edit=edit, + fname=fname, + content=content, + ) messages = [ - dict(role = 'system', content = prompts.editor_system), - dict(role = 'user', content = prompt), + dict(role="system", content=prompts.editor_system), + dict(role="user", content=prompt), ] - res = self.send(messages, show_progress = len(content) + len(edit)/2, model=model) + res = self.send( + messages, show_progress=len(content) + len(edit) / 2, model=model + ) res = self.strip_quoted_wrapping(res, fname) fname.write_text(res) @@ -384,21 +384,27 @@ class Coder: if fname and res[0].strip().endswith(Path(fname).name): res = res[1:] - if res[0].startswith('```') and res[-1].startswith('```'): + if res[0].startswith("```") and res[-1].startswith("```"): res = res[1:-1] - res = '\n'.join(res) - if res and res[-1] != '\n': - res += '\n' + res = "\n".join(res) + if res and res[-1] != "\n": + res += "\n" return res def main(): - - parser = argparse.ArgumentParser(description='Chat with GPT about code') - parser.add_argument('files', metavar='FILE', nargs='+', help='a list of source code files') - parser.add_argument('-3', '--gpt-3-5-turbo', action='store_true', help='Only use gpt-3.5-turbo, not gpt-4') + parser = argparse.ArgumentParser(description="Chat with GPT about code") + parser.add_argument( + "files", metavar="FILE", nargs="+", help="a list of source code files" + ) + parser.add_argument( + "-3", + "--gpt-3-5-turbo", + action="store_true", + help="Only use gpt-3.5-turbo, not gpt-4", + ) args = parser.parse_args() @@ -412,6 +418,7 @@ def main(): coder.run() -if __name__ == '__main__': + +if __name__ == "__main__": status = main() sys.exit(status)