diff --git a/aider/gui.py b/aider/gui.py index cef9ed8f9..bb862ffba 100755 --- a/aider/gui.py +++ b/aider/gui.py @@ -11,37 +11,6 @@ from aider.dump import dump # noqa: F401 from aider.main import main as cli_main -def init_state(key, val=None): - if key in st.session_state: - return - setattr(st.session_state, key, val) - - -def recent_msgs(): - msgs = [ - "write a python program that shows off some python features", - "write a tsx program that shows off some language features", - "refactor the Frobulator.simplify method to be a stand alone function", - "lorem ipsum dolor", - "lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt", - ( - "sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed labore" - " aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim aliqua enim" - " tempor adipiscing sit et" - ), - ] - # msgs = 30 * msgs - - return st.selectbox( - "N/A", - msgs, - placeholder="Resend recent chat message", - label_visibility="collapsed", - index=None, - key=f"recent_msgs_{st.session_state.recent_msgs_num}", - ) - - def search(text=None): results = [] for root, _, files in os.walk("aider"): @@ -54,16 +23,24 @@ def search(text=None): return results -@st.cache_data -def get_nominal_cached_data(): - print("get_nominal_cached_data") +# Keep state as a resource, which survives browser reloads (since Coder does too) +class State: + keys = set() + + def init(self, key, val=None): + if key in self.keys: + return + self.keys.add(key) + setattr(self, key, val) + + +@st.cache_resource +def get_state(): + return State() @st.cache_resource def get_coder(): - print("get_coder") - get_nominal_cached_data() - coder = cli_main(return_coder=True) if not isinstance(coder, Coder): raise ValueError(coder) @@ -232,7 +209,7 @@ class GUI: def reset_recent_msgs(self): self.recent_msgs_empty.empty() with self.recent_msgs_empty: - self.old_prompt = recent_msgs() + self.old_prompt = self.recent_msgs() def do_messages_container(self): self.messages = st.container() @@ -243,7 +220,7 @@ class GUI: with self.messages: self.announce() - for msg in st.session_state.messages: + for msg in self.state.messages: dump(msg) role = msg["role"] @@ -261,18 +238,22 @@ class GUI: def initialize_state(self): messages = [{"role": "assistant", "content": "How can I help you?"}] - init_state("messages", messages) - init_state("recent_msgs_num", 0) - init_state("last_aider_commit_hash", self.coder.last_aider_commit_hash) - init_state("prompt") + self.state.init("messages", messages) + self.state.init("recent_msgs_num", 0) + self.state.init("last_aider_commit_hash", self.coder.last_aider_commit_hash) + self.state.init("prompt") + + dump(self.state.messages) def button(self, args, **kwargs): "Create a button, disabled if prompt pending" kwargs["disabled"] = self.prompt_pending() return st.button(args, **kwargs) - def __init__(self, coder): + def __init__(self, coder, state): self.coder = coder + self.state = state + self.last_undo_button = None # Force the coder to cooperate, regardless of cmd line args @@ -298,22 +279,22 @@ class GUI: if self.old_prompt: prompt = self.old_prompt - st.session_state.recent_msgs_num += 1 + self.state.recent_msgs_num += 1 self.reset_recent_msgs() self.chat(prompt) return def prompt_pending(self): - return st.session_state.prompt is not None + return self.state.prompt is not None def cost(self): cost = random.random() * 0.003 + 0.001 st.caption(f"${cost:0.4f}") def chat(self, prompt): - st.session_state.prompt = prompt + self.state.prompt = prompt - st.session_state.messages.append({"role": "user", "content": prompt}) + self.state.messages.append({"role": "user", "content": prompt}) with self.messages.chat_message("user"): st.write(prompt) @@ -321,17 +302,17 @@ class GUI: st.experimental_rerun() def process_chat(self): - prompt = st.session_state.prompt - st.session_state.prompt = None + prompt = self.state.prompt + self.state.prompt = None while prompt: with self.messages.chat_message("assistant"): res = st.write_stream(self.coder.run_stream(prompt)) - st.session_state.messages.append({"role": "assistant", "content": res}) + self.state.messages.append({"role": "assistant", "content": res}) # self.cost() if self.coder.reflected_message: info = dict(role="info", message=self.coder.reflected_message) - st.session_state.messages.append(info) + self.state.messages.append(info) self.messages.info(self.coder.reflected_message) prompt = self.coder.reflected_message @@ -340,7 +321,7 @@ class GUI: role="edit", fnames=self.coder.aider_edited_files, ) - if st.session_state.last_aider_commit_hash != self.coder.last_aider_commit_hash: + if self.state.last_aider_commit_hash != self.coder.last_aider_commit_hash: edit["commit_hash"] = self.coder.last_aider_commit_hash edit["commit_message"] = self.coder.last_aider_commit_message commits = f"{self.coder.last_aider_commit_hash}~1" @@ -350,18 +331,43 @@ class GUI: self.coder.last_aider_commit_hash, ) edit["diff"] = diff - st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash + self.state.last_aider_commit_hash = self.coder.last_aider_commit_hash - st.session_state.messages.append(edit) + self.state.messages.append(edit) self.show_edit_info(edit) # re-render the UI for the non-prompt_pending state st.experimental_rerun() + def recent_msgs(self): + msgs = [ + "write a python program that shows off some python features", + "write a tsx program that shows off some language features", + "refactor the Frobulator.simplify method to be a stand alone function", + "lorem ipsum dolor", + "lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt", + ( + "sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed" + " labore aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim" + " aliqua enim tempor adipiscing sit et" + ), + ] + # msgs = 30 * msgs + + return st.selectbox( + "N/A", + msgs, + placeholder="Resend recent chat message", + label_visibility="collapsed", + index=None, + key=f"recent_msgs_{self.state.recent_msgs_num}", + ) + def gui_main(): coder = get_coder() - GUI(coder) + state = get_state() + GUI(coder, state) if __name__ == "__main__":