state as a resource

This commit is contained in:
Paul Gauthier 2024-04-27 13:40:00 -07:00
parent 1c43722221
commit b4d1a0a2a2

View file

@ -11,37 +11,6 @@ from aider.dump import dump # noqa: F401
from aider.main import main as cli_main from aider.main import main as cli_main
def init_state(key, val=None):
if key in st.session_state:
return
setattr(st.session_state, key, val)
def recent_msgs():
msgs = [
"write a python program that shows off some python features",
"write a tsx program that shows off some language features",
"refactor the Frobulator.simplify method to be a stand alone function",
"lorem ipsum dolor",
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
(
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed labore"
" aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim aliqua enim"
" tempor adipiscing sit et"
),
]
# msgs = 30 * msgs
return st.selectbox(
"N/A",
msgs,
placeholder="Resend recent chat message",
label_visibility="collapsed",
index=None,
key=f"recent_msgs_{st.session_state.recent_msgs_num}",
)
def search(text=None): def search(text=None):
results = [] results = []
for root, _, files in os.walk("aider"): for root, _, files in os.walk("aider"):
@ -54,16 +23,24 @@ def search(text=None):
return results return results
@st.cache_data # Keep state as a resource, which survives browser reloads (since Coder does too)
def get_nominal_cached_data(): class State:
print("get_nominal_cached_data") keys = set()
def init(self, key, val=None):
if key in self.keys:
return
self.keys.add(key)
setattr(self, key, val)
@st.cache_resource
def get_state():
return State()
@st.cache_resource @st.cache_resource
def get_coder(): def get_coder():
print("get_coder")
get_nominal_cached_data()
coder = cli_main(return_coder=True) coder = cli_main(return_coder=True)
if not isinstance(coder, Coder): if not isinstance(coder, Coder):
raise ValueError(coder) raise ValueError(coder)
@ -232,7 +209,7 @@ class GUI:
def reset_recent_msgs(self): def reset_recent_msgs(self):
self.recent_msgs_empty.empty() self.recent_msgs_empty.empty()
with self.recent_msgs_empty: with self.recent_msgs_empty:
self.old_prompt = recent_msgs() self.old_prompt = self.recent_msgs()
def do_messages_container(self): def do_messages_container(self):
self.messages = st.container() self.messages = st.container()
@ -243,7 +220,7 @@ class GUI:
with self.messages: with self.messages:
self.announce() self.announce()
for msg in st.session_state.messages: for msg in self.state.messages:
dump(msg) dump(msg)
role = msg["role"] role = msg["role"]
@ -261,18 +238,22 @@ class GUI:
def initialize_state(self): def initialize_state(self):
messages = [{"role": "assistant", "content": "How can I help you?"}] messages = [{"role": "assistant", "content": "How can I help you?"}]
init_state("messages", messages) self.state.init("messages", messages)
init_state("recent_msgs_num", 0) self.state.init("recent_msgs_num", 0)
init_state("last_aider_commit_hash", self.coder.last_aider_commit_hash) self.state.init("last_aider_commit_hash", self.coder.last_aider_commit_hash)
init_state("prompt") self.state.init("prompt")
dump(self.state.messages)
def button(self, args, **kwargs): def button(self, args, **kwargs):
"Create a button, disabled if prompt pending" "Create a button, disabled if prompt pending"
kwargs["disabled"] = self.prompt_pending() kwargs["disabled"] = self.prompt_pending()
return st.button(args, **kwargs) return st.button(args, **kwargs)
def __init__(self, coder): def __init__(self, coder, state):
self.coder = coder self.coder = coder
self.state = state
self.last_undo_button = None self.last_undo_button = None
# Force the coder to cooperate, regardless of cmd line args # Force the coder to cooperate, regardless of cmd line args
@ -298,22 +279,22 @@ class GUI:
if self.old_prompt: if self.old_prompt:
prompt = self.old_prompt prompt = self.old_prompt
st.session_state.recent_msgs_num += 1 self.state.recent_msgs_num += 1
self.reset_recent_msgs() self.reset_recent_msgs()
self.chat(prompt) self.chat(prompt)
return return
def prompt_pending(self): def prompt_pending(self):
return st.session_state.prompt is not None return self.state.prompt is not None
def cost(self): def cost(self):
cost = random.random() * 0.003 + 0.001 cost = random.random() * 0.003 + 0.001
st.caption(f"${cost:0.4f}") st.caption(f"${cost:0.4f}")
def chat(self, prompt): def chat(self, prompt):
st.session_state.prompt = prompt self.state.prompt = prompt
st.session_state.messages.append({"role": "user", "content": prompt}) self.state.messages.append({"role": "user", "content": prompt})
with self.messages.chat_message("user"): with self.messages.chat_message("user"):
st.write(prompt) st.write(prompt)
@ -321,17 +302,17 @@ class GUI:
st.experimental_rerun() st.experimental_rerun()
def process_chat(self): def process_chat(self):
prompt = st.session_state.prompt prompt = self.state.prompt
st.session_state.prompt = None self.state.prompt = None
while prompt: while prompt:
with self.messages.chat_message("assistant"): with self.messages.chat_message("assistant"):
res = st.write_stream(self.coder.run_stream(prompt)) res = st.write_stream(self.coder.run_stream(prompt))
st.session_state.messages.append({"role": "assistant", "content": res}) self.state.messages.append({"role": "assistant", "content": res})
# self.cost() # self.cost()
if self.coder.reflected_message: if self.coder.reflected_message:
info = dict(role="info", message=self.coder.reflected_message) info = dict(role="info", message=self.coder.reflected_message)
st.session_state.messages.append(info) self.state.messages.append(info)
self.messages.info(self.coder.reflected_message) self.messages.info(self.coder.reflected_message)
prompt = self.coder.reflected_message prompt = self.coder.reflected_message
@ -340,7 +321,7 @@ class GUI:
role="edit", role="edit",
fnames=self.coder.aider_edited_files, fnames=self.coder.aider_edited_files,
) )
if st.session_state.last_aider_commit_hash != self.coder.last_aider_commit_hash: if self.state.last_aider_commit_hash != self.coder.last_aider_commit_hash:
edit["commit_hash"] = self.coder.last_aider_commit_hash edit["commit_hash"] = self.coder.last_aider_commit_hash
edit["commit_message"] = self.coder.last_aider_commit_message edit["commit_message"] = self.coder.last_aider_commit_message
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
@ -350,18 +331,43 @@ class GUI:
self.coder.last_aider_commit_hash, self.coder.last_aider_commit_hash,
) )
edit["diff"] = diff edit["diff"] = diff
st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash self.state.last_aider_commit_hash = self.coder.last_aider_commit_hash
st.session_state.messages.append(edit) self.state.messages.append(edit)
self.show_edit_info(edit) self.show_edit_info(edit)
# re-render the UI for the non-prompt_pending state # re-render the UI for the non-prompt_pending state
st.experimental_rerun() st.experimental_rerun()
def recent_msgs(self):
msgs = [
"write a python program that shows off some python features",
"write a tsx program that shows off some language features",
"refactor the Frobulator.simplify method to be a stand alone function",
"lorem ipsum dolor",
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
(
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed"
" labore aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim"
" aliqua enim tempor adipiscing sit et"
),
]
# msgs = 30 * msgs
return st.selectbox(
"N/A",
msgs,
placeholder="Resend recent chat message",
label_visibility="collapsed",
index=None,
key=f"recent_msgs_{self.state.recent_msgs_num}",
)
def gui_main(): def gui_main():
coder = get_coder() coder = get_coder()
GUI(coder) state = get_state()
GUI(coder, state)
if __name__ == "__main__": if __name__ == "__main__":