Gracefully handle ^C and let gpt see that the user interrupted it

This commit is contained in:
Paul Gauthier 2023-05-08 08:49:28 -07:00
parent e1fd8b0edc
commit 8ac27839df

View file

@ -89,7 +89,6 @@ class Coder:
print() print()
inp = "" inp = ""
num_control_c = 0
if self.pretty: if self.pretty:
print(Fore.GREEN, end="\r") print(Fore.GREEN, end="\r")
else: else:
@ -100,12 +99,6 @@ class Coder:
inp = input("> ") inp = input("> ")
except EOFError: except EOFError:
return return
except KeyboardInterrupt:
num_control_c += 1
self.console.print()
if num_control_c >= 2:
return
self.console.print("[bold red]^C again to quit")
### ###
if self.pretty: if self.pretty:
@ -145,11 +138,27 @@ class Coder:
self.done_messages = [] self.done_messages = []
self.cur_messages = [] self.cur_messages = []
self.num_control_c = 0
while True: while True:
try:
self.run_loop()
except KeyboardInterrupt:
self.num_control_c += 1
if self.num_control_c >= 2:
break
self.console.print("[bold red]^C again to quit")
if self.pretty:
print(Style.RESET_ALL)
def run_loop(self):
inp = self.get_input() inp = self.get_input()
if inp is None: if inp is None:
return return
self.num_control_c = 0
if self.check_for_local_edits(): if self.check_for_local_edits():
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
self.done_messages += self.cur_messages self.done_messages += self.cur_messages
@ -163,28 +172,26 @@ class Coder:
dict(role="user", content=inp), dict(role="user", content=inp),
] ]
# self.show_messages(self.done_messages, "done")
# self.show_messages(self.files_messages, "files")
# self.show_messages(self.cur_messages, "cur")
messages = [ messages = [
dict( dict(role="system", content=prompts.main_system + prompts.system_reminder),
role="system", content=prompts.main_system + prompts.system_reminder
),
] ]
messages += self.done_messages messages += self.done_messages
messages += self.get_files_messages() messages += self.get_files_messages()
messages += self.cur_messages messages += self.cur_messages
# self.show_messages(messages, "all") self.show_messages(messages, "all")
content = self.send(messages) content, interrupted = self.send(messages)
if interrupted:
content += "\n^C KeyboardInterrupt"
self.cur_messages += [ self.cur_messages += [
dict(role="assistant", content=content), dict(role="assistant", content=content),
] ]
self.console.print() self.console.print()
if interrupted:
return True
try: try:
edited = self.update_files(content, inp) edited = self.update_files(content, inp)
@ -195,7 +202,7 @@ class Coder:
edited = None edited = None
if not edited: if not edited:
continue return True
self.check_for_local_edits(True) self.check_for_local_edits(True)
self.done_messages += self.cur_messages self.done_messages += self.cur_messages
@ -204,6 +211,7 @@ class Coder:
dict(role="assistant", content="Ok."), dict(role="assistant", content="Ok."),
] ]
self.cur_messages = [] self.cur_messages = []
return True
def show_messages(self, messages, title): def show_messages(self, messages, title):
print(title.upper(), "*" * 50) print(title.upper(), "*" * 50)
@ -229,20 +237,26 @@ class Coder:
stream=True, stream=True,
) )
interrupted = False
try:
if show_progress: if show_progress:
return self.show_send_progress(completion, show_progress) self.show_send_progress(completion, show_progress)
elif self.pretty: elif self.pretty:
return self.show_send_output_color(completion) self.show_send_output_color(completion)
else: else:
return self.show_send_output_plain(completion) self.show_send_output_plain(completion)
except KeyboardInterrupt:
interrupted = True
return self.resp, interrupted
def show_send_progress(self, completion, show_progress): def show_send_progress(self, completion, show_progress):
resp = [] self.resp = ""
pbar = tqdm(total=show_progress) pbar = tqdm(total=show_progress)
for chunk in completion: for chunk in completion:
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
resp.append(text) self.resp += text
except AttributeError: except AttributeError:
continue continue
@ -251,50 +265,43 @@ class Coder:
pbar.update(show_progress) pbar.update(show_progress)
pbar.close() pbar.close()
resp = "".join(resp)
return resp
def show_send_output_plain(self, completion): def show_send_output_plain(self, completion):
resp = "" self.resp = ""
for chunk in completion: for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"): if chunk.choices[0].finish_reason not in (None, "stop"):
dump(chunk.choices[0].finish_reason) dump(chunk.choices[0].finish_reason)
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
resp += text self.resp += text
except AttributeError: except AttributeError:
continue continue
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
return resp
def show_send_output_color(self, completion): def show_send_output_color(self, completion):
resp = "" self.resp = ""
with Live(vertical_overflow="scroll") as live: with Live(vertical_overflow="scroll") as live:
for chunk in completion: for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"): if chunk.choices[0].finish_reason not in (None, "stop"):
dump(chunk.choices[0].finish_reason) assert False, "Exceeded context window!"
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
resp += text self.resp += text
except AttributeError: except AttributeError:
continue continue
md = Markdown(resp, style="blue", code_theme="default") md = Markdown(self.resp, style="blue", code_theme="default")
live.update(md) live.update(md)
live.update(Text("")) live.update(Text(""))
live.stop() live.stop()
md = Markdown(resp, style="blue", code_theme="default") md = Markdown(self.resp, style="blue", code_theme="default")
self.console.print(md) self.console.print(md)
return resp
pattern = re.compile( pattern = re.compile(
r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED", r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
re.MULTILINE | re.DOTALL, re.MULTILINE | re.DOTALL,
@ -363,9 +370,12 @@ class Coder:
dict(role="system", content=prompts.editor_system), dict(role="system", content=prompts.editor_system),
dict(role="user", content=prompt), dict(role="user", content=prompt),
] ]
res = self.send( res, interrupted = self.send(
messages, show_progress=len(content) + len(edit) / 2, model=model messages, show_progress=len(content) + len(edit) / 2, model=model
) )
if interrupted:
return
res = self.strip_quoted_wrapping(res, fname) res = self.strip_quoted_wrapping(res, fname)
fname.write_text(res) fname.write_text(res)