diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 61ae76734..3c10898de 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -45,6 +45,7 @@ class Coder: abs_fnames = None repo = None last_aider_commit_hash = None + aider_edited_files = None last_asked_for_commit_time = 0 repo_map = None functions = None @@ -546,6 +547,8 @@ class Coder: return messages def send_new_user_message(self, inp): + self.aider_edited_files = None + self.cur_messages += [ dict(role="user", content=inp), ] @@ -605,6 +608,7 @@ class Coder: self.update_cur_messages(edited) if edited: + self.aider_edited_files = edited if self.repo and self.auto_commits and not self.dry_run: saved_message = self.auto_commit(edited) elif hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"): @@ -1055,6 +1059,7 @@ class Coder: if res: commit_hash, commit_message = res self.last_aider_commit_hash = commit_hash + self.last_aider_commit_message = commit_message return self.gpt_prompts.files_content_gpt_edits.format( hash=commit_hash, diff --git a/aider/gui.py b/aider/gui.py index 6e97ae6d8..7846fbc3e 100755 --- a/aider/gui.py +++ b/aider/gui.py @@ -60,9 +60,11 @@ def search(text=None): @st.cache_resource def get_coder(): coder = cli_main(return_coder=True) - if isinstance(coder, Coder): - return coder - raise ValueError() + if not isinstance(coder, Coder): + raise ValueError(coder) + if not coder.repo: + raise ValueError("GUI can currently only be used inside a git repo") + return coder class GUI: @@ -71,39 +73,34 @@ class GUI: lines = " \n".join(lines) st.info(lines) - def mock_tool_output(self): - messages = """Applied edit to new_program.py""" - # st.info(messages) + def show_edit_info(self, edit): + commit_hash = edit.get("commit_hash") + commit_message = edit.get("commit_message") + fnames = edit.get("fnames") + if fnames: + fnames = sorted(fnames) - self.chat_controls.empty() + if not commit_hash and not fnames: + return - self.chat_controls = st.empty() - with self.chat_controls: - container = st.container() + show_undo = False + res = "" + if commit_hash: + res += f"Commit `{commit_hash}`: {commit_message} \n" + if commit_hash == self.coder.last_aider_commit_hash: + show_undo = True - with container: - # cols = st.columns([0.8,0.2]) - # with cols[0]: + if len(fnames) == 1: + res += f"Applied edits to `{fnames[0]}`" + elif len(fnames) > 1: + res += "Applied edits to: \n" + for fname in fnames: + res += f"- `{fname}` \n" - # with st.expander(messages): - # diff = Path("aider/tmp.diff").read_text() - # st.code(diff, language="diff") - with st.expander( - "Commit `33a242c`: Added sample python that highlights language features " - f" \n{messages}" - ): - # st.info(messages) - st.code(diff, language="diff") - st.button( - "Undo commit `33a242c`", - key=random.random(), - help="wtf?", - ) - - if False: - st.button("Allow edits to `foobar.py`", key=random.random(), help="??") - st.button("Allow creation of new file `some/new/file.js`", key=random.random()) - st.button("Add `baz/foo.py` to the chat", key=random.random()) + with st.container(border=True): + st.write(res) + if show_undo: + st.button(f"Undo commit `{commit_hash}`", key=f"undo_{commit_hash}") def do_sidebar(self): with st.sidebar: @@ -112,7 +109,7 @@ class GUI: def do_cmd_tab(self): with self.cmds_tab: - self.do_recommended_actions() + # self.do_recommended_actions() self.do_add_to_chat() self.do_tokens_and_cost() self.do_git() @@ -185,7 +182,7 @@ class GUI: st.button("Clear chat history") # st.metric("Cost of last message send & reply", "$0.0019", help="foo") # st.metric("Cost to send next message", "$0.0013", help="foo") - st.metric("Total cost this session", "$0.22") + # st.metric("Total cost this session", "$0.22") def do_git(self): with st.expander("Git", expanded=False): @@ -222,15 +219,21 @@ class GUI: with self.messages: self.announce() - for msg in st.session_state.messages: - with self.messages.chat_message(msg["role"]): - st.write(msg["content"]) - # self.cost() + for msg in st.session_state.messages: + dump(msg) - self.chat_controls = st.empty() + role = msg["role"] - def clear_controls(self): - self.chat_controls.empty() + if role == "edit": + self.show_edit_info(msg) + elif role == "info": + st.info(msg["message"]) + elif role in ("user", "assistant"): + with st.chat_message(role): + st.write(msg["content"]) + # self.cost() + else: + st.dict(msg) def init_state(self): if "messages" not in st.session_state: @@ -239,6 +242,9 @@ class GUI: if "recent_msgs_num" not in st.session_state: st.session_state.recent_msgs_num = 0 + if "last_aider_commit_hash" not in st.session_state: + st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash + def __init__(self, coder): self.coder = coder @@ -254,7 +260,7 @@ class GUI: self.do_cmd_tab() self.do_messages_container() - self.prompt = st.chat_input("Say something", on_submit=self.clear_controls) + self.prompt = st.chat_input("Say something") if self.prompt: self.chat(self.prompt) @@ -271,8 +277,6 @@ class GUI: st.caption(f"${cost:0.4f}") def chat(self, prompt): - self.clear_controls() - st.session_state.messages.append({"role": "user", "content": prompt}) with self.messages.chat_message("user"): st.write(prompt) @@ -283,12 +287,23 @@ class GUI: st.session_state.messages.append({"role": "assistant", "content": res}) # self.cost() if self.coder.reflected_message: + info = dict(role="info", message=self.coder.reflected_message) + st.session_state.messages.append(info) self.messages.info(self.coder.reflected_message) prompt = self.coder.reflected_message - dump(st.session_state.messages) with self.messages: - self.mock_tool_output() + edit = dict( + role="edit", + fnames=self.coder.aider_edited_files, + ) + if st.session_state.last_aider_commit_hash != self.coder.last_aider_commit_hash: + edit["commit_hash"] = self.coder.last_aider_commit_hash + edit["commit_message"] = self.coder.last_aider_commit_message + st.session_state.last_aider_commit_hash = self.coder.last_aider_commit_hash + + st.session_state.messages.append(edit) + self.show_edit_info(edit) def gui_main(): diff --git a/aider/main.py b/aider/main.py index ca778382b..eee9eee8f 100644 --- a/aider/main.py +++ b/aider/main.py @@ -513,7 +513,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F io = InputOutput( args.pretty, - args.yes, + args.yes or return_coder, # Force --yes if return_coder args.input_history_file, args.chat_history_file, input=input,