mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 06:15:00 +00:00
Gracefully handle ^C and let gpt see that the user interrupted it
This commit is contained in:
parent
e1fd8b0edc
commit
8ac27839df
1 changed files with 90 additions and 80 deletions
170
coder.py
170
coder.py
|
@ -89,7 +89,6 @@ class Coder:
|
|||
print()
|
||||
|
||||
inp = ""
|
||||
num_control_c = 0
|
||||
if self.pretty:
|
||||
print(Fore.GREEN, end="\r")
|
||||
else:
|
||||
|
@ -100,12 +99,6 @@ class Coder:
|
|||
inp = input("> ")
|
||||
except EOFError:
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
num_control_c += 1
|
||||
self.console.print()
|
||||
if num_control_c >= 2:
|
||||
return
|
||||
self.console.print("[bold red]^C again to quit")
|
||||
|
||||
###
|
||||
if self.pretty:
|
||||
|
@ -145,66 +138,81 @@ class Coder:
|
|||
self.done_messages = []
|
||||
self.cur_messages = []
|
||||
|
||||
self.num_control_c = 0
|
||||
|
||||
while True:
|
||||
inp = self.get_input()
|
||||
if inp is None:
|
||||
return
|
||||
|
||||
if self.check_for_local_edits():
|
||||
# files changed, move cur messages back behind the files messages
|
||||
self.done_messages += self.cur_messages
|
||||
self.done_messages += [
|
||||
dict(role="user", content=prompts.files_content_local_edits),
|
||||
dict(role="assistant", content="Ok."),
|
||||
]
|
||||
self.cur_messages = []
|
||||
|
||||
self.cur_messages += [
|
||||
dict(role="user", content=inp),
|
||||
]
|
||||
|
||||
# self.show_messages(self.done_messages, "done")
|
||||
# self.show_messages(self.files_messages, "files")
|
||||
# self.show_messages(self.cur_messages, "cur")
|
||||
|
||||
messages = [
|
||||
dict(
|
||||
role="system", content=prompts.main_system + prompts.system_reminder
|
||||
),
|
||||
]
|
||||
messages += self.done_messages
|
||||
messages += self.get_files_messages()
|
||||
messages += self.cur_messages
|
||||
|
||||
# self.show_messages(messages, "all")
|
||||
|
||||
content = self.send(messages)
|
||||
|
||||
self.cur_messages += [
|
||||
dict(role="assistant", content=content),
|
||||
]
|
||||
|
||||
self.console.print()
|
||||
|
||||
try:
|
||||
edited = self.update_files(content, inp)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
print()
|
||||
traceback.print_exc()
|
||||
edited = None
|
||||
self.run_loop()
|
||||
except KeyboardInterrupt:
|
||||
self.num_control_c += 1
|
||||
if self.num_control_c >= 2:
|
||||
break
|
||||
self.console.print("[bold red]^C again to quit")
|
||||
|
||||
if not edited:
|
||||
continue
|
||||
if self.pretty:
|
||||
print(Style.RESET_ALL)
|
||||
|
||||
self.check_for_local_edits(True)
|
||||
def run_loop(self):
|
||||
inp = self.get_input()
|
||||
if inp is None:
|
||||
return
|
||||
|
||||
self.num_control_c = 0
|
||||
|
||||
if self.check_for_local_edits():
|
||||
# files changed, move cur messages back behind the files messages
|
||||
self.done_messages += self.cur_messages
|
||||
self.done_messages += [
|
||||
dict(role="user", content=prompts.files_content_gpt_edits),
|
||||
dict(role="user", content=prompts.files_content_local_edits),
|
||||
dict(role="assistant", content="Ok."),
|
||||
]
|
||||
self.cur_messages = []
|
||||
|
||||
self.cur_messages += [
|
||||
dict(role="user", content=inp),
|
||||
]
|
||||
|
||||
messages = [
|
||||
dict(role="system", content=prompts.main_system + prompts.system_reminder),
|
||||
]
|
||||
messages += self.done_messages
|
||||
messages += self.get_files_messages()
|
||||
messages += self.cur_messages
|
||||
|
||||
self.show_messages(messages, "all")
|
||||
|
||||
content, interrupted = self.send(messages)
|
||||
if interrupted:
|
||||
content += "\n^C KeyboardInterrupt"
|
||||
|
||||
self.cur_messages += [
|
||||
dict(role="assistant", content=content),
|
||||
]
|
||||
|
||||
self.console.print()
|
||||
if interrupted:
|
||||
return True
|
||||
|
||||
try:
|
||||
edited = self.update_files(content, inp)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
print()
|
||||
traceback.print_exc()
|
||||
edited = None
|
||||
|
||||
if not edited:
|
||||
return True
|
||||
|
||||
self.check_for_local_edits(True)
|
||||
self.done_messages += self.cur_messages
|
||||
self.done_messages += [
|
||||
dict(role="user", content=prompts.files_content_gpt_edits),
|
||||
dict(role="assistant", content="Ok."),
|
||||
]
|
||||
self.cur_messages = []
|
||||
return True
|
||||
|
||||
def show_messages(self, messages, title):
|
||||
print(title.upper(), "*" * 50)
|
||||
|
||||
|
@ -229,20 +237,26 @@ class Coder:
|
|||
stream=True,
|
||||
)
|
||||
|
||||
if show_progress:
|
||||
return self.show_send_progress(completion, show_progress)
|
||||
elif self.pretty:
|
||||
return self.show_send_output_color(completion)
|
||||
else:
|
||||
return self.show_send_output_plain(completion)
|
||||
interrupted = False
|
||||
try:
|
||||
if show_progress:
|
||||
self.show_send_progress(completion, show_progress)
|
||||
elif self.pretty:
|
||||
self.show_send_output_color(completion)
|
||||
else:
|
||||
self.show_send_output_plain(completion)
|
||||
except KeyboardInterrupt:
|
||||
interrupted = True
|
||||
|
||||
return self.resp, interrupted
|
||||
|
||||
def show_send_progress(self, completion, show_progress):
|
||||
resp = []
|
||||
self.resp = ""
|
||||
pbar = tqdm(total=show_progress)
|
||||
for chunk in completion:
|
||||
try:
|
||||
text = chunk.choices[0].delta.content
|
||||
resp.append(text)
|
||||
self.resp += text
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
|
@ -251,50 +265,43 @@ class Coder:
|
|||
pbar.update(show_progress)
|
||||
pbar.close()
|
||||
|
||||
resp = "".join(resp)
|
||||
return resp
|
||||
|
||||
def show_send_output_plain(self, completion):
|
||||
resp = ""
|
||||
self.resp = ""
|
||||
|
||||
for chunk in completion:
|
||||
if chunk.choices[0].finish_reason not in (None, "stop"):
|
||||
dump(chunk.choices[0].finish_reason)
|
||||
try:
|
||||
text = chunk.choices[0].delta.content
|
||||
resp += text
|
||||
self.resp += text
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
sys.stdout.write(text)
|
||||
sys.stdout.flush()
|
||||
|
||||
return resp
|
||||
|
||||
def show_send_output_color(self, completion):
|
||||
resp = ""
|
||||
self.resp = ""
|
||||
|
||||
with Live(vertical_overflow="scroll") as live:
|
||||
for chunk in completion:
|
||||
if chunk.choices[0].finish_reason not in (None, "stop"):
|
||||
dump(chunk.choices[0].finish_reason)
|
||||
assert False, "Exceeded context window!"
|
||||
try:
|
||||
text = chunk.choices[0].delta.content
|
||||
resp += text
|
||||
self.resp += text
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
md = Markdown(resp, style="blue", code_theme="default")
|
||||
md = Markdown(self.resp, style="blue", code_theme="default")
|
||||
live.update(md)
|
||||
|
||||
live.update(Text(""))
|
||||
live.stop()
|
||||
|
||||
md = Markdown(resp, style="blue", code_theme="default")
|
||||
md = Markdown(self.resp, style="blue", code_theme="default")
|
||||
self.console.print(md)
|
||||
|
||||
return resp
|
||||
|
||||
pattern = re.compile(
|
||||
r"(\S+)\s+(```\s*)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
|
@ -363,9 +370,12 @@ class Coder:
|
|||
dict(role="system", content=prompts.editor_system),
|
||||
dict(role="user", content=prompt),
|
||||
]
|
||||
res = self.send(
|
||||
res, interrupted = self.send(
|
||||
messages, show_progress=len(content) + len(edit) / 2, model=model
|
||||
)
|
||||
if interrupted:
|
||||
return
|
||||
|
||||
res = self.strip_quoted_wrapping(res, fname)
|
||||
fname.write_text(res)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue