mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 14:25:00 +00:00
Gracefully handle ^C and let gpt see that the user interrupted it
This commit is contained in:
parent
e1fd8b0edc
commit
8ac27839df
1 changed files with 90 additions and 80 deletions
170
coder.py
170
coder.py
|
@ -89,7 +89,6 @@ class Coder:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
inp = ""
|
inp = ""
|
||||||
num_control_c = 0
|
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
print(Fore.GREEN, end="\r")
|
print(Fore.GREEN, end="\r")
|
||||||
else:
|
else:
|
||||||
|
@ -100,12 +99,6 @@ class Coder:
|
||||||
inp = input("> ")
|
inp = input("> ")
|
||||||
except EOFError:
|
except EOFError:
|
||||||
return
|
return
|
||||||
except KeyboardInterrupt:
|
|
||||||
num_control_c += 1
|
|
||||||
self.console.print()
|
|
||||||
if num_control_c >= 2:
|
|
||||||
return
|
|
||||||
self.console.print("[bold red]^C again to quit")
|
|
||||||
|
|
||||||
###
|
###
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
|
@ -145,66 +138,81 @@ class Coder:
|
||||||
self.done_messages = []
|
self.done_messages = []
|
||||||
self.cur_messages = []
|
self.cur_messages = []
|
||||||
|
|
||||||
|
self.num_control_c = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
inp = self.get_input()
|
|
||||||
if inp is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.check_for_local_edits():
|
|
||||||
# files changed, move cur messages back behind the files messages
|
|
||||||
self.done_messages += self.cur_messages
|
|
||||||
self.done_messages += [
|
|
||||||
dict(role="user", content=prompts.files_content_local_edits),
|
|
||||||
dict(role="assistant", content="Ok."),
|
|
||||||
]
|
|
||||||
self.cur_messages = []
|
|
||||||
|
|
||||||
self.cur_messages += [
|
|
||||||
dict(role="user", content=inp),
|
|
||||||
]
|
|
||||||
|
|
||||||
# self.show_messages(self.done_messages, "done")
|
|
||||||
# self.show_messages(self.files_messages, "files")
|
|
||||||
# self.show_messages(self.cur_messages, "cur")
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
dict(
|
|
||||||
role="system", content=prompts.main_system + prompts.system_reminder
|
|
||||||
),
|
|
||||||
]
|
|
||||||
messages += self.done_messages
|
|
||||||
messages += self.get_files_messages()
|
|
||||||
messages += self.cur_messages
|
|
||||||
|
|
||||||
# self.show_messages(messages, "all")
|
|
||||||
|
|
||||||
content = self.send(messages)
|
|
||||||
|
|
||||||
self.cur_messages += [
|
|
||||||
dict(role="assistant", content=content),
|
|
||||||
]
|
|
||||||
|
|
||||||
self.console.print()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
edited = self.update_files(content, inp)
|
self.run_loop()
|
||||||
except Exception as err:
|
except KeyboardInterrupt:
|
||||||
print(err)
|
self.num_control_c += 1
|
||||||
print()
|
if self.num_control_c >= 2:
|
||||||
traceback.print_exc()
|
break
|
||||||
edited = None
|
self.console.print("[bold red]^C again to quit")
|
||||||
|
|
||||||
if not edited:
|
if self.pretty:
|
||||||
continue
|
print(Style.RESET_ALL)
|
||||||
|
|
||||||
self.check_for_local_edits(True)
|
def run_loop(self):
|
||||||
|
inp = self.get_input()
|
||||||
|
if inp is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.num_control_c = 0
|
||||||
|
|
||||||
|
if self.check_for_local_edits():
|
||||||
|
# files changed, move cur messages back behind the files messages
|
||||||
self.done_messages += self.cur_messages
|
self.done_messages += self.cur_messages
|
||||||
self.done_messages += [
|
self.done_messages += [
|
||||||
dict(role="user", content=prompts.files_content_gpt_edits),
|
dict(role="user", content=prompts.files_content_local_edits),
|
||||||
dict(role="assistant", content="Ok."),
|
dict(role="assistant", content="Ok."),
|
||||||
]
|
]
|
||||||
self.cur_messages = []
|
self.cur_messages = []
|
||||||
|
|
||||||
|
self.cur_messages += [
|
||||||
|
dict(role="user", content=inp),
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
dict(role="system", content=prompts.main_system + prompts.system_reminder),
|
||||||
|
]
|
||||||
|
messages += self.done_messages
|
||||||
|
messages += self.get_files_messages()
|
||||||
|
messages += self.cur_messages
|
||||||
|
|
||||||
|
self.show_messages(messages, "all")
|
||||||
|
|
||||||
|
content, interrupted = self.send(messages)
|
||||||
|
if interrupted:
|
||||||
|
content += "\n^C KeyboardInterrupt"
|
||||||
|
|
||||||
|
self.cur_messages += [
|
||||||
|
dict(role="assistant", content=content),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.console.print()
|
||||||
|
if interrupted:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
edited = self.update_files(content, inp)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
print()
|
||||||
|
traceback.print_exc()
|
||||||
|
edited = None
|
||||||
|
|
||||||
|
if not edited:
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.check_for_local_edits(True)
|
||||||
|
self.done_messages += self.cur_messages
|
||||||
|
self.done_messages += [
|
||||||
|
dict(role="user", content=prompts.files_content_gpt_edits),
|
||||||
|
dict(role="assistant", content="Ok."),
|
||||||
|
]
|
||||||
|
self.cur_messages = []
|
||||||
|
return True
|
||||||
|
|
||||||
def show_messages(self, messages, title):
|
def show_messages(self, messages, title):
|
||||||
print(title.upper(), "*" * 50)
|
print(title.upper(), "*" * 50)
|
||||||
|
|
||||||
|
@ -229,20 +237,26 @@ class Coder:
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if show_progress:
|
interrupted = False
|
||||||
return self.show_send_progress(completion, show_progress)
|
try:
|
||||||
elif self.pretty:
|
if show_progress:
|
||||||
return self.show_send_output_color(completion)
|
self.show_send_progress(completion, show_progress)
|
||||||
else:
|
elif self.pretty:
|
||||||
return self.show_send_output_plain(completion)
|
self.show_send_output_color(completion)
|
||||||
|
else:
|
||||||
|
self.show_send_output_plain(completion)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
interrupted = True
|
||||||
|
|
||||||
|
return self.resp, interrupted
|
||||||
|
|
||||||
def show_send_progress(self, completion, show_progress):
|
def show_send_progress(self, completion, show_progress):
|
||||||
resp = []
|
self.resp = ""
|
||||||
pbar = tqdm(total=show_progress)
|
pbar = tqdm(total=show_progress)
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
try:
|
try:
|
||||||
text = chunk.choices[0].delta.content
|
text = chunk.choices[0].delta.content
|
||||||
resp.append(text)
|
self.resp += text
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -251,50 +265,43 @@ class Coder:
|
||||||
pbar.update(show_progress)
|
pbar.update(show_progress)
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
resp = "".join(resp)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def show_send_output_plain(self, completion):
|
def show_send_output_plain(self, completion):
|
||||||
resp = ""
|
self.resp = ""
|
||||||
|
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
if chunk.choices[0].finish_reason not in (None, "stop"):
|
if chunk.choices[0].finish_reason not in (None, "stop"):
|
||||||
dump(chunk.choices[0].finish_reason)
|
dump(chunk.choices[0].finish_reason)
|
||||||
try:
|
try:
|
||||||
text = chunk.choices[0].delta.content
|
text = chunk.choices[0].delta.content
|
||||||
resp += text
|
self.resp += text
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sys.stdout.write(text)
|
sys.stdout.write(text)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def show_send_output_color(self, completion):
|
def show_send_output_color(self, completion):
|
||||||
resp = ""
|
self.resp = ""
|
||||||
|
|
||||||
with Live(vertical_overflow="scroll") as live:
|
with Live(vertical_overflow="scroll") as live:
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
if chunk.choices[0].finish_reason not in (None, "stop"):
|
if chunk.choices[0].finish_reason not in (None, "stop"):
|
||||||
dump(chunk.choices[0].finish_reason)
|
assert False, "Exceeded context window!"
|
||||||
try:
|
try:
|
||||||
text = chunk.choices[0].delta.content
|
text = chunk.choices[0].delta.content
|
||||||
resp += text
|
self.resp += text
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
md = Markdown(resp, style="blue", code_theme="default")
|
md = Markdown(self.resp, style="blue", code_theme="default")
|
||||||
live.update(md)
|
live.update(md)
|
||||||
|
|
||||||
live.update(Text(""))
|
live.update(Text(""))
|
||||||
live.stop()
|
live.stop()
|
||||||
|
|
||||||
md = Markdown(resp, style="blue", code_theme="default")
|
md = Markdown(self.resp, style="blue", code_theme="default")
|
||||||
self.console.print(md)
|
self.console.print(md)
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
|
r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
|
||||||
re.MULTILINE | re.DOTALL,
|
re.MULTILINE | re.DOTALL,
|
||||||
|
@ -363,9 +370,12 @@ class Coder:
|
||||||
dict(role="system", content=prompts.editor_system),
|
dict(role="system", content=prompts.editor_system),
|
||||||
dict(role="user", content=prompt),
|
dict(role="user", content=prompt),
|
||||||
]
|
]
|
||||||
res = self.send(
|
res, interrupted = self.send(
|
||||||
messages, show_progress=len(content) + len(edit) / 2, model=model
|
messages, show_progress=len(content) + len(edit) / 2, model=model
|
||||||
)
|
)
|
||||||
|
if interrupted:
|
||||||
|
return
|
||||||
|
|
||||||
res = self.strip_quoted_wrapping(res, fname)
|
res = self.strip_quoted_wrapping(res, fname)
|
||||||
fname.write_text(res)
|
fname.write_text(res)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue