From 8ac27839dff069328f513ae5b48d8efbe79d1289 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Mon, 8 May 2023 08:49:28 -0700 Subject: [PATCH] Gracefully handle ^C and let gpt see that the user interrupted it --- coder.py | 170 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 90 insertions(+), 80 deletions(-) diff --git a/coder.py b/coder.py index c97c4ecac..70cb84324 100755 --- a/coder.py +++ b/coder.py @@ -89,7 +89,6 @@ class Coder: print() inp = "" - num_control_c = 0 if self.pretty: print(Fore.GREEN, end="\r") else: @@ -100,12 +99,6 @@ class Coder: inp = input("> ") except EOFError: return - except KeyboardInterrupt: - num_control_c += 1 - self.console.print() - if num_control_c >= 2: - return - self.console.print("[bold red]^C again to quit") ### if self.pretty: @@ -145,66 +138,81 @@ class Coder: self.done_messages = [] self.cur_messages = [] + self.num_control_c = 0 + while True: - inp = self.get_input() - if inp is None: - return - - if self.check_for_local_edits(): - # files changed, move cur messages back behind the files messages - self.done_messages += self.cur_messages - self.done_messages += [ - dict(role="user", content=prompts.files_content_local_edits), - dict(role="assistant", content="Ok."), - ] - self.cur_messages = [] - - self.cur_messages += [ - dict(role="user", content=inp), - ] - - # self.show_messages(self.done_messages, "done") - # self.show_messages(self.files_messages, "files") - # self.show_messages(self.cur_messages, "cur") - - messages = [ - dict( - role="system", content=prompts.main_system + prompts.system_reminder - ), - ] - messages += self.done_messages - messages += self.get_files_messages() - messages += self.cur_messages - - # self.show_messages(messages, "all") - - content = self.send(messages) - - self.cur_messages += [ - dict(role="assistant", content=content), - ] - - self.console.print() - try: - edited = self.update_files(content, inp) - except Exception as err: - print(err) - print() - traceback.print_exc() - edited = None + self.run_loop() + except KeyboardInterrupt: + self.num_control_c += 1 + if self.num_control_c >= 2: + break + self.console.print("[bold red]^C again to quit") - if not edited: - continue + if self.pretty: + print(Style.RESET_ALL) - self.check_for_local_edits(True) + def run_loop(self): + inp = self.get_input() + if inp is None: + return + + self.num_control_c = 0 + + if self.check_for_local_edits(): + # files changed, move cur messages back behind the files messages self.done_messages += self.cur_messages self.done_messages += [ - dict(role="user", content=prompts.files_content_gpt_edits), + dict(role="user", content=prompts.files_content_local_edits), dict(role="assistant", content="Ok."), ] self.cur_messages = [] + self.cur_messages += [ + dict(role="user", content=inp), + ] + + messages = [ + dict(role="system", content=prompts.main_system + prompts.system_reminder), + ] + messages += self.done_messages + messages += self.get_files_messages() + messages += self.cur_messages + + self.show_messages(messages, "all") + + content, interrupted = self.send(messages) + if interrupted: + content += "\n^C KeyboardInterrupt" + + self.cur_messages += [ + dict(role="assistant", content=content), + ] + + self.console.print() + if interrupted: + return True + + try: + edited = self.update_files(content, inp) + except Exception as err: + print(err) + print() + traceback.print_exc() + edited = None + + if not edited: + return True + + self.check_for_local_edits(True) + self.done_messages += self.cur_messages + self.done_messages += [ + dict(role="user", content=prompts.files_content_gpt_edits), + dict(role="assistant", content="Ok."), + ] + self.cur_messages = [] + return True + def show_messages(self, messages, title): print(title.upper(), "*" * 50) @@ -229,20 +237,26 @@ class Coder: stream=True, ) - if show_progress: - return self.show_send_progress(completion, show_progress) - elif self.pretty: - return self.show_send_output_color(completion) - else: - return self.show_send_output_plain(completion) + interrupted = False + try: + if show_progress: + self.show_send_progress(completion, show_progress) + elif self.pretty: + self.show_send_output_color(completion) + else: + self.show_send_output_plain(completion) + except KeyboardInterrupt: + interrupted = True + + return self.resp, interrupted def show_send_progress(self, completion, show_progress): - resp = [] + self.resp = "" pbar = tqdm(total=show_progress) for chunk in completion: try: text = chunk.choices[0].delta.content - resp.append(text) + self.resp += text except AttributeError: continue @@ -251,50 +265,43 @@ class Coder: pbar.update(show_progress) pbar.close() - resp = "".join(resp) - return resp - def show_send_output_plain(self, completion): - resp = "" + self.resp = "" for chunk in completion: if chunk.choices[0].finish_reason not in (None, "stop"): dump(chunk.choices[0].finish_reason) try: text = chunk.choices[0].delta.content - resp += text + self.resp += text except AttributeError: continue sys.stdout.write(text) sys.stdout.flush() - return resp - def show_send_output_color(self, completion): - resp = "" + self.resp = "" with Live(vertical_overflow="scroll") as live: for chunk in completion: if chunk.choices[0].finish_reason not in (None, "stop"): - dump(chunk.choices[0].finish_reason) + assert False, "Exceeded context window!" try: text = chunk.choices[0].delta.content - resp += text + self.resp += text except AttributeError: continue - md = Markdown(resp, style="blue", code_theme="default") + md = Markdown(self.resp, style="blue", code_theme="default") live.update(md) live.update(Text("")) live.stop() - md = Markdown(resp, style="blue", code_theme="default") + md = Markdown(self.resp, style="blue", code_theme="default") self.console.print(md) - return resp - pattern = re.compile( r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED", re.MULTILINE | re.DOTALL, @@ -363,9 +370,12 @@ class Coder: dict(role="system", content=prompts.editor_system), dict(role="user", content=prompt), ] - res = self.send( + res, interrupted = self.send( messages, show_progress=len(content) + len(edit) / 2, model=model ) + if interrupted: + return + res = self.strip_quoted_wrapping(res, fname) fname.write_text(res)