Gracefully handle ^C and let gpt see that the user interrupted it

This commit is contained in:
Paul Gauthier 2023-05-08 08:49:28 -07:00
parent e1fd8b0edc
commit 8ac27839df

170
coder.py
View file

@ -89,7 +89,6 @@ class Coder:
print()
inp = ""
num_control_c = 0
if self.pretty:
print(Fore.GREEN, end="\r")
else:
@ -100,12 +99,6 @@ class Coder:
inp = input("> ")
except EOFError:
return
except KeyboardInterrupt:
num_control_c += 1
self.console.print()
if num_control_c >= 2:
return
self.console.print("[bold red]^C again to quit")
###
if self.pretty:
@ -145,66 +138,81 @@ class Coder:
self.done_messages = []
self.cur_messages = []
self.num_control_c = 0
while True:
inp = self.get_input()
if inp is None:
return
if self.check_for_local_edits():
# files changed, move cur messages back behind the files messages
self.done_messages += self.cur_messages
self.done_messages += [
dict(role="user", content=prompts.files_content_local_edits),
dict(role="assistant", content="Ok."),
]
self.cur_messages = []
self.cur_messages += [
dict(role="user", content=inp),
]
# self.show_messages(self.done_messages, "done")
# self.show_messages(self.files_messages, "files")
# self.show_messages(self.cur_messages, "cur")
messages = [
dict(
role="system", content=prompts.main_system + prompts.system_reminder
),
]
messages += self.done_messages
messages += self.get_files_messages()
messages += self.cur_messages
# self.show_messages(messages, "all")
content = self.send(messages)
self.cur_messages += [
dict(role="assistant", content=content),
]
self.console.print()
try:
edited = self.update_files(content, inp)
except Exception as err:
print(err)
print()
traceback.print_exc()
edited = None
self.run_loop()
except KeyboardInterrupt:
self.num_control_c += 1
if self.num_control_c >= 2:
break
self.console.print("[bold red]^C again to quit")
if not edited:
continue
if self.pretty:
print(Style.RESET_ALL)
self.check_for_local_edits(True)
def run_loop(self):
inp = self.get_input()
if inp is None:
return
self.num_control_c = 0
if self.check_for_local_edits():
# files changed, move cur messages back behind the files messages
self.done_messages += self.cur_messages
self.done_messages += [
dict(role="user", content=prompts.files_content_gpt_edits),
dict(role="user", content=prompts.files_content_local_edits),
dict(role="assistant", content="Ok."),
]
self.cur_messages = []
self.cur_messages += [
dict(role="user", content=inp),
]
messages = [
dict(role="system", content=prompts.main_system + prompts.system_reminder),
]
messages += self.done_messages
messages += self.get_files_messages()
messages += self.cur_messages
self.show_messages(messages, "all")
content, interrupted = self.send(messages)
if interrupted:
content += "\n^C KeyboardInterrupt"
self.cur_messages += [
dict(role="assistant", content=content),
]
self.console.print()
if interrupted:
return True
try:
edited = self.update_files(content, inp)
except Exception as err:
print(err)
print()
traceback.print_exc()
edited = None
if not edited:
return True
self.check_for_local_edits(True)
self.done_messages += self.cur_messages
self.done_messages += [
dict(role="user", content=prompts.files_content_gpt_edits),
dict(role="assistant", content="Ok."),
]
self.cur_messages = []
return True
def show_messages(self, messages, title):
print(title.upper(), "*" * 50)
@ -229,20 +237,26 @@ class Coder:
stream=True,
)
if show_progress:
return self.show_send_progress(completion, show_progress)
elif self.pretty:
return self.show_send_output_color(completion)
else:
return self.show_send_output_plain(completion)
interrupted = False
try:
if show_progress:
self.show_send_progress(completion, show_progress)
elif self.pretty:
self.show_send_output_color(completion)
else:
self.show_send_output_plain(completion)
except KeyboardInterrupt:
interrupted = True
return self.resp, interrupted
def show_send_progress(self, completion, show_progress):
resp = []
self.resp = ""
pbar = tqdm(total=show_progress)
for chunk in completion:
try:
text = chunk.choices[0].delta.content
resp.append(text)
self.resp += text
except AttributeError:
continue
@ -251,50 +265,43 @@ class Coder:
pbar.update(show_progress)
pbar.close()
resp = "".join(resp)
return resp
def show_send_output_plain(self, completion):
resp = ""
self.resp = ""
for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"):
dump(chunk.choices[0].finish_reason)
try:
text = chunk.choices[0].delta.content
resp += text
self.resp += text
except AttributeError:
continue
sys.stdout.write(text)
sys.stdout.flush()
return resp
def show_send_output_color(self, completion):
resp = ""
self.resp = ""
with Live(vertical_overflow="scroll") as live:
for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"):
dump(chunk.choices[0].finish_reason)
assert False, "Exceeded context window!"
try:
text = chunk.choices[0].delta.content
resp += text
self.resp += text
except AttributeError:
continue
md = Markdown(resp, style="blue", code_theme="default")
md = Markdown(self.resp, style="blue", code_theme="default")
live.update(md)
live.update(Text(""))
live.stop()
md = Markdown(resp, style="blue", code_theme="default")
md = Markdown(self.resp, style="blue", code_theme="default")
self.console.print(md)
return resp
pattern = re.compile(
r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
re.MULTILINE | re.DOTALL,
@ -363,9 +370,12 @@ class Coder:
dict(role="system", content=prompts.editor_system),
dict(role="user", content=prompt),
]
res = self.send(
res, interrupted = self.send(
messages, show_progress=len(content) + len(edit) / 2, model=model
)
if interrupted:
return
res = self.strip_quoted_wrapping(res, fname)
fname.write_text(res)