state as a resource

This commit is contained in:
Paul Gauthier 2024-04-27 13:40:00 -07:00
parent 1c43722221
commit b4d1a0a2a2

View file

@ -11,37 +11,6 @@ from aider.dump import dump # noqa: F401
from aider.main import main as cli_main
def init_state(key, val=None):
if key in st.session_state:
return
setattr(st.session_state, key, val)
def recent_msgs():
msgs = [
"write a python program that shows off some python features",
"write a tsx program that shows off some language features",
"refactor the Frobulator.simplify method to be a stand alone function",
"lorem ipsum dolor",
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
(
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed labore"
" aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim aliqua enim"
" tempor adipiscing sit et"
),
]
# msgs = 30 * msgs
return st.selectbox(
"N/A",
msgs,
placeholder="Resend recent chat message",
label_visibility="collapsed",
index=None,
key=f"recent_msgs_{st.session_state.recent_msgs_num}",
)
def search(text=None):
results = []
for root, _, files in os.walk("aider"):
@ -54,16 +23,24 @@ def search(text=None):
return results
@st.cache_data
def get_nominal_cached_data():
print("get_nominal_cached_data")
# Keep state as a resource, which survives browser reloads (since Coder does too)
class State:
keys = set()
def init(self, key, val=None):
if key in self.keys:
return
self.keys.add(key)
setattr(self, key, val)
@st.cache_resource
def get_state():
return State()
@st.cache_resource
def get_coder():
print("get_coder")
get_nominal_cached_data()
coder = cli_main(return_coder=True)
if not isinstance(coder, Coder):
raise ValueError(coder)
@ -232,7 +209,7 @@ class GUI:
def reset_recent_msgs(self):
self.recent_msgs_empty.empty()
with self.recent_msgs_empty:
self.old_prompt = recent_msgs()
self.old_prompt = self.recent_msgs()
def do_messages_container(self):
self.messages = st.container()
@ -243,7 +220,7 @@ class GUI:
with self.messages:
self.announce()
for msg in st.session_state.messages:
for msg in self.state.messages:
dump(msg)
role = msg["role"]
@ -261,18 +238,22 @@ class GUI:
def initialize_state(self):
messages = [{"role": "assistant", "content": "How can I help you?"}]
init_state("messages", messages)
init_state("recent_msgs_num", 0)
init_state("last_aider_commit_hash", self.coder.last_aider_commit_hash)
init_state("prompt")
self.state.init("messages", messages)
self.state.init("recent_msgs_num", 0)
self.state.init("last_aider_commit_hash", self.coder.last_aider_commit_hash)
self.state.init("prompt")
dump(self.state.messages)
def button(self, args, **kwargs):
"Create a button, disabled if prompt pending"
kwargs["disabled"] = self.prompt_pending()
return st.button(args, **kwargs)
def __init__(self, coder):
def __init__(self, coder, state):
self.coder = coder
self.state = state
self.last_undo_button = None
# Force the coder to cooperate, regardless of cmd line args
@ -298,22 +279,22 @@ class GUI:
if self.old_prompt:
prompt = self.old_prompt
st.session_state.recent_msgs_num += 1
self.state.recent_msgs_num += 1
self.reset_recent_msgs()
self.chat(prompt)
return
def prompt_pending(self):
return st.session_state.prompt is not None
return self.state.prompt is not None
def cost(self):
cost = random.random() * 0.003 + 0.001
st.caption(f"${cost:0.4f}")
def chat(self, prompt):
st.session_state.prompt = prompt
self.state.prompt = prompt
st.session_state.messages.append({"role": "user", "content": prompt})
self.state.messages.append({"role": "user", "content": prompt})
with self.messages.chat_message("user"):
st.write(prompt)
@ -321,17 +302,17 @@ class GUI:
st.experimental_rerun()
def process_chat(self):
prompt = st.session_state.prompt
st.session_state.prompt = None
prompt = self.state.prompt
self.state.prompt = None
while prompt:
with self.messages.chat_message("assistant"):
res = st.write_stream(self.coder.run_stream(prompt))
st.session_state.messages.append({"role": "assistant", "content": res})
self.state.messages.append({"role": "assistant", "content": res})
# self.cost()
if self.coder.reflected_message:
info = dict(role="info", message=self.coder.reflected_message)
st.session_state.messages.append(info)
self.state.messages.append(info)
self.messages.info(self.coder.reflected_message)
prompt = self.coder.reflected_message
@ -340,7 +321,7 @@ class GUI:
role="edit",
fnames=self.coder.aider_edited_files,
)
if st.session_state.last_aider_commit_hash != self.coder.last_aider_commit_hash:
if self.state.last_aider_commit_hash != self.coder.last_aider_commit_hash:
edit["commit_hash"] = self.coder.last_aider_commit_hash
edit["commit_message"] = self.coder.last_aider_commit_message
commits = f"{self.coder.last_aider_commit_hash}~1"
@ -350,18 +331,43 @@ class GUI:
self.coder.last_aider_commit_hash,
)
edit["diff"] = diff
st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash
self.state.last_aider_commit_hash = self.coder.last_aider_commit_hash
st.session_state.messages.append(edit)
self.state.messages.append(edit)
self.show_edit_info(edit)
# re-render the UI for the non-prompt_pending state
st.experimental_rerun()
def recent_msgs(self):
msgs = [
"write a python program that shows off some python features",
"write a tsx program that shows off some language features",
"refactor the Frobulator.simplify method to be a stand alone function",
"lorem ipsum dolor",
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
(
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed"
" labore aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim"
" aliqua enim tempor adipiscing sit et"
),
]
# msgs = 30 * msgs
return st.selectbox(
"N/A",
msgs,
placeholder="Resend recent chat message",
label_visibility="collapsed",
index=None,
key=f"recent_msgs_{self.state.recent_msgs_num}",
)
def gui_main():
coder = get_coder()
GUI(coder)
state = get_state()
GUI(coder, state)
if __name__ == "__main__":