diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 3bb929216..4b7d0faf4 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -37,6 +37,11 @@ class ExhaustedContextWindow(Exception): pass +class ReflectMessage(Exception): + def __init__(self, message): + super().__init__(message) + + def wrap_fence(name): return f"<{name}>", f"" @@ -54,6 +59,7 @@ class Coder: last_keyboard_interrupt = None max_apply_update_errors = 3 edit_format = None + yield_stream = False @classmethod def create( @@ -405,6 +411,9 @@ class Coder: return {"role": "user", "content": image_messages} def run(self, with_message=None): + list(self.run_stream(with_message)) + + def run_stream(self, with_message=None): while True: try: if with_message: @@ -414,7 +423,12 @@ class Coder: new_user_message = self.run_loop() while new_user_message: - new_user_message = self.send_new_user_message(new_user_message) + try: + for chunk in self.send_new_user_message(new_user_message): + yield chunk + new_user_message = None + except ReflectMessage as msg: + new_user_message = str(msg) if with_message: return self.partial_response_content @@ -495,7 +509,10 @@ class Coder: self.check_for_file_mentions(inp) - return self.send_new_user_message(inp) + try: + list(self.send_new_user_message(inp)) + except ReflectMessage as msg: + return str(msg) def fmt_system_prompt(self, prompt): prompt = prompt.format(fence=self.fence) @@ -550,7 +567,10 @@ class Coder: exhausted = False interrupted = False try: - interrupted = self.send(messages, functions=self.functions) + for chunk in self.send(messages, functions=self.functions): + yield chunk + except KeyboardInterrupt: + interrupted = True except ExhaustedContextWindow: exhausted = True except openai.BadRequestError as err: @@ -579,18 +599,17 @@ class Coder: else: content = "" + self.io.tool_output() + if interrupted: content += "\n^C KeyboardInterrupt" - - self.io.tool_output() - if interrupted: self.cur_messages += [dict(role="assistant", content=content)] return edited, edit_error = self.apply_updates() if edit_error: self.update_cur_messages(set()) - return edit_error + raise ReflectMessage(edit_error) self.update_cur_messages(edited) @@ -606,7 +625,7 @@ class Coder: add_rel_files_message = self.check_for_file_mentions(content) if add_rel_files_message: - return add_rel_files_message + raise ReflectMessage(add_rel_files_message) def update_cur_messages(self, edited): if self.partial_response_content: @@ -674,7 +693,8 @@ class Coder: self.chat_completion_call_hashes.append(hash_object.hexdigest()) if self.stream: - self.show_send_output_stream(completion) + for chunk in self.show_send_output_stream(completion): + yield chunk else: self.show_send_output(completion) except KeyboardInterrupt: @@ -689,7 +709,8 @@ class Coder: if args: self.io.ai_output(json.dumps(args, indent=4)) - return interrupted + if interrupted: + raise KeyboardInterrupt def show_send_output(self, completion): if self.verbose: @@ -790,6 +811,7 @@ class Coder: elif text: sys.stdout.write(text) sys.stdout.flush() + yield text finally: if mdstream: self.live_incremental_response(mdstream, True) diff --git a/aider/gui.py b/aider/gui.py index 4c91bcffe..788805acb 100755 --- a/aider/gui.py +++ b/aider/gui.py @@ -270,16 +270,13 @@ class GUI: with self.messages.chat_message("user"): st.write(prompt) - res = self.coder.run(prompt) - # self.coder.io.user_input(with_message) - - st.session_state.messages.append({"role": "assistant", "content": res}) - with self.messages.chat_message("assistant"): - st.write(res) + res = st.write(self.coder.run_stream(prompt)) cost = random.random() * 0.003 + 0.001 st.caption(f"${cost:0.4f}") + st.session_state.messages.append({"role": "assistant", "content": res}) + with self.messages: self.mock_tool_output()