refac and func update_files

This commit is contained in:
Paul Gauthier 2023-06-21 21:17:12 -07:00
parent 8c1e1c7267
commit eb062cc504
4 changed files with 35 additions and 10 deletions

View file

@ -377,7 +377,7 @@ class Coder:
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
return return
edited, edit_error = self.apply_updates(content) edited, edit_error = self.apply_updates()
if edit_error: if edit_error:
return edit_error return edit_error
@ -544,7 +544,7 @@ class Coder:
md = Markdown(show_resp, style=self.assistant_output_color, code_theme="default") md = Markdown(show_resp, style=self.assistant_output_color, code_theme="default")
live.update(md) live.update(md)
def modify_incremental_response(self): def modify_incremental_response(self, final):
return self.partial_response_content return self.partial_response_content
def get_context_from_history(self, history): def get_context_from_history(self, history):
@ -716,9 +716,9 @@ class Coder:
def get_addable_relative_files(self): def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files()) return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def apply_updates(self, content): def apply_updates(self):
try: try:
edited = self.update_files(content) edited = self.update_files()
return edited, None return edited, None
except ValueError as err: except ValueError as err:
err = err.args[0] err = err.args[0]

View file

@ -16,7 +16,9 @@ class EditBlockCoder(Coder):
def update_cur_messages(self, content, edited): def update_cur_messages(self, content, edited):
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
def update_files(self, content): def update_files(self):
content = self.partial_response_content
# might raise ValueError for malformed ORIG/UPD blocks # might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content)) edits = list(find_original_update_blocks(content))

View file

@ -1,4 +1,5 @@
import os import os
from pathlib import Path
from aider import diffs from aider import diffs
@ -102,5 +103,26 @@ class FunctionCoder(Coder):
return "\n".join(show_diff) return "\n".join(show_diff)
def update_files(self, content): def update_files(self):
pass args = self.parse_partial_args()
if not args:
return
files = args.get("files", [])
edited = set()
for file_upd in files:
path = file_upd.get("path")
if not path:
raise ValueError(f"Missing path: {file_upd}")
content = file_upd.get("content")
if not content:
raise ValueError(f"Missing content: {file_upd}")
full_path = os.path.abspath(os.path.join(self.root, path))
Path(full_path).write_text(content)
edited.add(path)
return edited

View file

@ -21,10 +21,11 @@ class WholeFileCoder(Coder):
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self, final): def modify_incremental_response(self, final):
resp = self.partial_response_content return self.update_files(mode="diff")
return self.update_files(resp, mode="diff")
def update_files(self, mode="update"):
content = self.partial_response_content
def update_files(self, content, mode="update"):
edited = set() edited = set()
chat_files = self.get_inchat_relative_files() chat_files = self.get_inchat_relative_files()
if not chat_files: if not chat_files: