mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 14:25:00 +00:00
state as a resource
This commit is contained in:
parent
1c43722221
commit
b4d1a0a2a2
1 changed files with 62 additions and 56 deletions
118
aider/gui.py
118
aider/gui.py
|
@ -11,37 +11,6 @@ from aider.dump import dump # noqa: F401
|
|||
from aider.main import main as cli_main
|
||||
|
||||
|
||||
def init_state(key, val=None):
|
||||
if key in st.session_state:
|
||||
return
|
||||
setattr(st.session_state, key, val)
|
||||
|
||||
|
||||
def recent_msgs():
|
||||
msgs = [
|
||||
"write a python program that shows off some python features",
|
||||
"write a tsx program that shows off some language features",
|
||||
"refactor the Frobulator.simplify method to be a stand alone function",
|
||||
"lorem ipsum dolor",
|
||||
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
|
||||
(
|
||||
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed labore"
|
||||
" aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim aliqua enim"
|
||||
" tempor adipiscing sit et"
|
||||
),
|
||||
]
|
||||
# msgs = 30 * msgs
|
||||
|
||||
return st.selectbox(
|
||||
"N/A",
|
||||
msgs,
|
||||
placeholder="Resend recent chat message",
|
||||
label_visibility="collapsed",
|
||||
index=None,
|
||||
key=f"recent_msgs_{st.session_state.recent_msgs_num}",
|
||||
)
|
||||
|
||||
|
||||
def search(text=None):
|
||||
results = []
|
||||
for root, _, files in os.walk("aider"):
|
||||
|
@ -54,16 +23,24 @@ def search(text=None):
|
|||
return results
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def get_nominal_cached_data():
|
||||
print("get_nominal_cached_data")
|
||||
# Keep state as a resource, which survives browser reloads (since Coder does too)
|
||||
class State:
|
||||
keys = set()
|
||||
|
||||
def init(self, key, val=None):
|
||||
if key in self.keys:
|
||||
return
|
||||
self.keys.add(key)
|
||||
setattr(self, key, val)
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_state():
|
||||
return State()
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_coder():
|
||||
print("get_coder")
|
||||
get_nominal_cached_data()
|
||||
|
||||
coder = cli_main(return_coder=True)
|
||||
if not isinstance(coder, Coder):
|
||||
raise ValueError(coder)
|
||||
|
@ -232,7 +209,7 @@ class GUI:
|
|||
def reset_recent_msgs(self):
|
||||
self.recent_msgs_empty.empty()
|
||||
with self.recent_msgs_empty:
|
||||
self.old_prompt = recent_msgs()
|
||||
self.old_prompt = self.recent_msgs()
|
||||
|
||||
def do_messages_container(self):
|
||||
self.messages = st.container()
|
||||
|
@ -243,7 +220,7 @@ class GUI:
|
|||
with self.messages:
|
||||
self.announce()
|
||||
|
||||
for msg in st.session_state.messages:
|
||||
for msg in self.state.messages:
|
||||
dump(msg)
|
||||
|
||||
role = msg["role"]
|
||||
|
@ -261,18 +238,22 @@ class GUI:
|
|||
|
||||
def initialize_state(self):
|
||||
messages = [{"role": "assistant", "content": "How can I help you?"}]
|
||||
init_state("messages", messages)
|
||||
init_state("recent_msgs_num", 0)
|
||||
init_state("last_aider_commit_hash", self.coder.last_aider_commit_hash)
|
||||
init_state("prompt")
|
||||
self.state.init("messages", messages)
|
||||
self.state.init("recent_msgs_num", 0)
|
||||
self.state.init("last_aider_commit_hash", self.coder.last_aider_commit_hash)
|
||||
self.state.init("prompt")
|
||||
|
||||
dump(self.state.messages)
|
||||
|
||||
def button(self, args, **kwargs):
|
||||
"Create a button, disabled if prompt pending"
|
||||
kwargs["disabled"] = self.prompt_pending()
|
||||
return st.button(args, **kwargs)
|
||||
|
||||
def __init__(self, coder):
|
||||
def __init__(self, coder, state):
|
||||
self.coder = coder
|
||||
self.state = state
|
||||
|
||||
self.last_undo_button = None
|
||||
|
||||
# Force the coder to cooperate, regardless of cmd line args
|
||||
|
@ -298,22 +279,22 @@ class GUI:
|
|||
|
||||
if self.old_prompt:
|
||||
prompt = self.old_prompt
|
||||
st.session_state.recent_msgs_num += 1
|
||||
self.state.recent_msgs_num += 1
|
||||
self.reset_recent_msgs()
|
||||
self.chat(prompt)
|
||||
return
|
||||
|
||||
def prompt_pending(self):
|
||||
return st.session_state.prompt is not None
|
||||
return self.state.prompt is not None
|
||||
|
||||
def cost(self):
|
||||
cost = random.random() * 0.003 + 0.001
|
||||
st.caption(f"${cost:0.4f}")
|
||||
|
||||
def chat(self, prompt):
|
||||
st.session_state.prompt = prompt
|
||||
self.state.prompt = prompt
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
self.state.messages.append({"role": "user", "content": prompt})
|
||||
with self.messages.chat_message("user"):
|
||||
st.write(prompt)
|
||||
|
||||
|
@ -321,17 +302,17 @@ class GUI:
|
|||
st.experimental_rerun()
|
||||
|
||||
def process_chat(self):
|
||||
prompt = st.session_state.prompt
|
||||
st.session_state.prompt = None
|
||||
prompt = self.state.prompt
|
||||
self.state.prompt = None
|
||||
|
||||
while prompt:
|
||||
with self.messages.chat_message("assistant"):
|
||||
res = st.write_stream(self.coder.run_stream(prompt))
|
||||
st.session_state.messages.append({"role": "assistant", "content": res})
|
||||
self.state.messages.append({"role": "assistant", "content": res})
|
||||
# self.cost()
|
||||
if self.coder.reflected_message:
|
||||
info = dict(role="info", message=self.coder.reflected_message)
|
||||
st.session_state.messages.append(info)
|
||||
self.state.messages.append(info)
|
||||
self.messages.info(self.coder.reflected_message)
|
||||
prompt = self.coder.reflected_message
|
||||
|
||||
|
@ -340,7 +321,7 @@ class GUI:
|
|||
role="edit",
|
||||
fnames=self.coder.aider_edited_files,
|
||||
)
|
||||
if st.session_state.last_aider_commit_hash != self.coder.last_aider_commit_hash:
|
||||
if self.state.last_aider_commit_hash != self.coder.last_aider_commit_hash:
|
||||
edit["commit_hash"] = self.coder.last_aider_commit_hash
|
||||
edit["commit_message"] = self.coder.last_aider_commit_message
|
||||
commits = f"{self.coder.last_aider_commit_hash}~1"
|
||||
|
@ -350,18 +331,43 @@ class GUI:
|
|||
self.coder.last_aider_commit_hash,
|
||||
)
|
||||
edit["diff"] = diff
|
||||
st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash
|
||||
self.state.last_aider_commit_hash = self.coder.last_aider_commit_hash
|
||||
|
||||
st.session_state.messages.append(edit)
|
||||
self.state.messages.append(edit)
|
||||
self.show_edit_info(edit)
|
||||
|
||||
# re-render the UI for the non-prompt_pending state
|
||||
st.experimental_rerun()
|
||||
|
||||
def recent_msgs(self):
|
||||
msgs = [
|
||||
"write a python program that shows off some python features",
|
||||
"write a tsx program that shows off some language features",
|
||||
"refactor the Frobulator.simplify method to be a stand alone function",
|
||||
"lorem ipsum dolor",
|
||||
"lorem adipiscing adipiscing et dolore sit elit aliqua dolore ut incididunt",
|
||||
(
|
||||
"sed magna consectetur et quis do magna labore ad elit et elit ad eiusmod sed"
|
||||
" labore aliqua eiusmod enim ad nostrud\n\namet consectetur magna tempor do enim"
|
||||
" aliqua enim tempor adipiscing sit et"
|
||||
),
|
||||
]
|
||||
# msgs = 30 * msgs
|
||||
|
||||
return st.selectbox(
|
||||
"N/A",
|
||||
msgs,
|
||||
placeholder="Resend recent chat message",
|
||||
label_visibility="collapsed",
|
||||
index=None,
|
||||
key=f"recent_msgs_{self.state.recent_msgs_num}",
|
||||
)
|
||||
|
||||
|
||||
def gui_main():
|
||||
coder = get_coder()
|
||||
GUI(coder)
|
||||
state = get_state()
|
||||
GUI(coder, state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue