stream to gui

This commit is contained in:
Paul Gauthier 2024-04-26 15:41:47 -07:00
parent 98d2997035
commit 15e6651e09
2 changed files with 35 additions and 16 deletions

View file

@ -37,6 +37,11 @@ class ExhaustedContextWindow(Exception):
pass pass
class ReflectMessage(Exception):
def __init__(self, message):
super().__init__(message)
def wrap_fence(name): def wrap_fence(name):
return f"<{name}>", f"</{name}>" return f"<{name}>", f"</{name}>"
@ -54,6 +59,7 @@ class Coder:
last_keyboard_interrupt = None last_keyboard_interrupt = None
max_apply_update_errors = 3 max_apply_update_errors = 3
edit_format = None edit_format = None
yield_stream = False
@classmethod @classmethod
def create( def create(
@ -405,6 +411,9 @@ class Coder:
return {"role": "user", "content": image_messages} return {"role": "user", "content": image_messages}
def run(self, with_message=None): def run(self, with_message=None):
list(self.run_stream(with_message))
def run_stream(self, with_message=None):
while True: while True:
try: try:
if with_message: if with_message:
@ -414,7 +423,12 @@ class Coder:
new_user_message = self.run_loop() new_user_message = self.run_loop()
while new_user_message: while new_user_message:
new_user_message = self.send_new_user_message(new_user_message) try:
for chunk in self.send_new_user_message(new_user_message):
yield chunk
new_user_message = None
except ReflectMessage as msg:
new_user_message = str(msg)
if with_message: if with_message:
return self.partial_response_content return self.partial_response_content
@ -495,7 +509,10 @@ class Coder:
self.check_for_file_mentions(inp) self.check_for_file_mentions(inp)
return self.send_new_user_message(inp) try:
list(self.send_new_user_message(inp))
except ReflectMessage as msg:
return str(msg)
def fmt_system_prompt(self, prompt): def fmt_system_prompt(self, prompt):
prompt = prompt.format(fence=self.fence) prompt = prompt.format(fence=self.fence)
@ -550,7 +567,10 @@ class Coder:
exhausted = False exhausted = False
interrupted = False interrupted = False
try: try:
interrupted = self.send(messages, functions=self.functions) for chunk in self.send(messages, functions=self.functions):
yield chunk
except KeyboardInterrupt:
interrupted = True
except ExhaustedContextWindow: except ExhaustedContextWindow:
exhausted = True exhausted = True
except openai.BadRequestError as err: except openai.BadRequestError as err:
@ -579,18 +599,17 @@ class Coder:
else: else:
content = "" content = ""
self.io.tool_output()
if interrupted: if interrupted:
content += "\n^C KeyboardInterrupt" content += "\n^C KeyboardInterrupt"
self.io.tool_output()
if interrupted:
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
return return
edited, edit_error = self.apply_updates() edited, edit_error = self.apply_updates()
if edit_error: if edit_error:
self.update_cur_messages(set()) self.update_cur_messages(set())
return edit_error raise ReflectMessage(edit_error)
self.update_cur_messages(edited) self.update_cur_messages(edited)
@ -606,7 +625,7 @@ class Coder:
add_rel_files_message = self.check_for_file_mentions(content) add_rel_files_message = self.check_for_file_mentions(content)
if add_rel_files_message: if add_rel_files_message:
return add_rel_files_message raise ReflectMessage(add_rel_files_message)
def update_cur_messages(self, edited): def update_cur_messages(self, edited):
if self.partial_response_content: if self.partial_response_content:
@ -674,7 +693,8 @@ class Coder:
self.chat_completion_call_hashes.append(hash_object.hexdigest()) self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream: if self.stream:
self.show_send_output_stream(completion) for chunk in self.show_send_output_stream(completion):
yield chunk
else: else:
self.show_send_output(completion) self.show_send_output(completion)
except KeyboardInterrupt: except KeyboardInterrupt:
@ -689,7 +709,8 @@ class Coder:
if args: if args:
self.io.ai_output(json.dumps(args, indent=4)) self.io.ai_output(json.dumps(args, indent=4))
return interrupted if interrupted:
raise KeyboardInterrupt
def show_send_output(self, completion): def show_send_output(self, completion):
if self.verbose: if self.verbose:
@ -790,6 +811,7 @@ class Coder:
elif text: elif text:
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
yield text
finally: finally:
if mdstream: if mdstream:
self.live_incremental_response(mdstream, True) self.live_incremental_response(mdstream, True)

View file

@ -270,16 +270,13 @@ class GUI:
with self.messages.chat_message("user"): with self.messages.chat_message("user"):
st.write(prompt) st.write(prompt)
res = self.coder.run(prompt)
# self.coder.io.user_input(with_message)
st.session_state.messages.append({"role": "assistant", "content": res})
with self.messages.chat_message("assistant"): with self.messages.chat_message("assistant"):
st.write(res) res = st.write(self.coder.run_stream(prompt))
cost = random.random() * 0.003 + 0.001 cost = random.random() * 0.003 + 0.001
st.caption(f"${cost:0.4f}") st.caption(f"${cost:0.4f}")
st.session_state.messages.append({"role": "assistant", "content": res})
with self.messages: with self.messages:
self.mock_tool_output() self.mock_tool_output()