restructured reflected messages

This commit is contained in:
Paul Gauthier 2024-04-26 16:02:45 -07:00
parent 15e6651e09
commit 86d7705166
2 changed files with 24 additions and 25 deletions

View file

@ -37,11 +37,6 @@ class ExhaustedContextWindow(Exception):
pass pass
class ReflectMessage(Exception):
def __init__(self, message):
super().__init__(message)
def wrap_fence(name): def wrap_fence(name):
return f"<{name}>", f"</{name}>" return f"<{name}>", f"</{name}>"
@ -410,10 +405,13 @@ class Coder:
return {"role": "user", "content": image_messages} return {"role": "user", "content": image_messages}
def run(self, with_message=None): def run_stream(self, user_message):
list(self.run_stream(with_message)) self.io.user_input(user_message)
self.reflected_message = None
for chunk in self.send_new_user_message(user_message):
yield chunk
def run_stream(self, with_message=None): def run(self, with_message=None):
while True: while True:
try: try:
if with_message: if with_message:
@ -423,12 +421,9 @@ class Coder:
new_user_message = self.run_loop() new_user_message = self.run_loop()
while new_user_message: while new_user_message:
try: self.reflected_message = None
for chunk in self.send_new_user_message(new_user_message): list(self.send_new_user_message(new_user_message))
yield chunk new_user_message = self.reflected_message
new_user_message = None
except ReflectMessage as msg:
new_user_message = str(msg)
if with_message: if with_message:
return self.partial_response_content return self.partial_response_content
@ -508,11 +503,7 @@ class Coder:
return self.commands.run(inp) return self.commands.run(inp)
self.check_for_file_mentions(inp) self.check_for_file_mentions(inp)
return inp
try:
list(self.send_new_user_message(inp))
except ReflectMessage as msg:
return str(msg)
def fmt_system_prompt(self, prompt): def fmt_system_prompt(self, prompt):
prompt = prompt.format(fence=self.fence) prompt = prompt.format(fence=self.fence)
@ -609,7 +600,7 @@ class Coder:
edited, edit_error = self.apply_updates() edited, edit_error = self.apply_updates()
if edit_error: if edit_error:
self.update_cur_messages(set()) self.update_cur_messages(set())
raise ReflectMessage(edit_error) self.reflected_message = edit_error
self.update_cur_messages(edited) self.update_cur_messages(edited)
@ -625,7 +616,7 @@ class Coder:
add_rel_files_message = self.check_for_file_mentions(content) add_rel_files_message = self.check_for_file_mentions(content)
if add_rel_files_message: if add_rel_files_message:
raise ReflectMessage(add_rel_files_message) self.reflected_message = add_rel_files_message
def update_cur_messages(self, edited): def update_cur_messages(self, edited):
if self.partial_response_content: if self.partial_response_content:

View file

@ -263,6 +263,10 @@ class GUI:
self.reset_recent_msgs() self.reset_recent_msgs()
self.chat(prompt) self.chat(prompt)
def cost(self):
cost = random.random() * 0.003 + 0.001
st.caption(f"${cost:0.4f}")
def chat(self, prompt): def chat(self, prompt):
self.clear_controls() self.clear_controls()
@ -270,10 +274,14 @@ class GUI:
with self.messages.chat_message("user"): with self.messages.chat_message("user"):
st.write(prompt) st.write(prompt)
with self.messages.chat_message("assistant"): while prompt:
res = st.write(self.coder.run_stream(prompt)) with self.messages.chat_message("assistant"):
cost = random.random() * 0.003 + 0.001 res = st.write(self.coder.run_stream(prompt))
st.caption(f"${cost:0.4f}") self.cost()
dump(self.coder.reflected_message)
if self.coder.reflected_message:
self.messages.info(self.coder.reflected_message)
prompt = self.coder.reflected_message
st.session_state.messages.append({"role": "assistant", "content": res}) st.session_state.messages.append({"role": "assistant", "content": res})