This commit is contained in:
Paul Gauthier 2023-06-27 19:38:14 -07:00
parent f22ccf6195
commit 1ac366aa62
2 changed files with 29 additions and 70 deletions

View file

@ -13,7 +13,7 @@ import git
import openai import openai
import requests import requests
from openai.error import APIError, RateLimitError, ServiceUnavailableError from openai.error import APIError, RateLimitError, ServiceUnavailableError
from rich.console import Console from rich.console import Console, Text
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
@ -491,6 +491,9 @@ class Coder:
if add_rel_files_message: if add_rel_files_message:
return add_rel_files_message return add_rel_files_message
def update_cur_messages(self, content, edited):
self.cur_messages += [dict(role="assistant", content=content)]
def auto_commit(self): def auto_commit(self):
res = self.commit(history=self.cur_messages, prefix="aider: ") res = self.commit(history=self.cur_messages, prefix="aider: ")
if res: if res:
@ -646,6 +649,8 @@ class Coder:
show_resp = self.render_incremental_response(True) show_resp = self.render_incremental_response(True)
if self.pretty: if self.pretty:
show_resp = Markdown(show_resp, style=self.assistant_output_color, code_theme="default") show_resp = Markdown(show_resp, style=self.assistant_output_color, code_theme="default")
else:
show_resp = Text(show_resp)
self.io.console.print(show_resp) self.io.console.print(show_resp)
self.io.console.print(tokens) self.io.console.print(tokens)

View file

@ -1,9 +1,11 @@
import json
import os import os
from aider import diffs from aider import diffs
from ..dump import dump # noqa: F401 from ..dump import dump # noqa: F401
from .base_coder import Coder from .base_coder import Coder
from .editblock_coder import do_replace
from .editblock_func_prompts import EditBlockFunctionPrompts from .editblock_func_prompts import EditBlockFunctionPrompts
@ -56,92 +58,44 @@ class EditBlockFunctionCoder(Coder):
self.gpt_prompts = EditBlockFunctionPrompts() self.gpt_prompts = EditBlockFunctionPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def update_cur_messages(self, content, edited):
if edited:
self.cur_messages += [
dict(role="assistant", content=self.gpt_prompts.redacted_edit_message)
]
else:
self.cur_messages += [dict(role="assistant", content=content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False): def render_incremental_response(self, final=False):
if self.partial_response_content: if self.partial_response_content:
return self.partial_response_content return self.partial_response_content
args = self.parse_partial_args() args = self.parse_partial_args()
res = json.dumps(args, indent=4)
if not args: return "```\n" + res + "\n```\n"
return
explanation = args.get("explanation")
files = args.get("files", [])
res = ""
if explanation:
res += f"{explanation}\n\n"
for i, file_upd in enumerate(files):
path = file_upd.get("path")
if not path:
continue
content = file_upd.get("content")
if not content:
continue
this_final = (i < len(files) - 1) or final
res += self.live_diffs(path, content, this_final)
return res
def live_diffs(self, fname, content, final):
lines = content.splitlines(keepends=True)
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
lines,
final,
fname=fname,
).splitlines()
return "\n".join(show_diff)
def update_files(self): def update_files(self):
name = self.partial_response_function_call.get("name") name = self.partial_response_function_call.get("name")
if name and name != "replace_lines": if name and name != "replace_lines":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"') raise ValueError(f'Unknown function_call name="{name}", use name="replace_lines"')
args = self.parse_partial_args() args = self.parse_partial_args()
if not args: if not args:
return return
files = args.get("files", []) edits = args.get("edits", [])
edited = set() edited = set()
for file_upd in files: for edit in edits:
path = file_upd.get("path") path = get_arg(edit, "path")
if not path: original = get_arg(edit, "original_lines")
raise ValueError(f"Missing path parameter: {file_upd}") updated = get_arg(edit, "updated_lines")
content = file_upd.get("content") full_path = self.allowed_to_edit(path)
if not content: if not full_path:
raise ValueError(f"Missing content parameter: {file_upd}") continue
if do_replace(full_path, original, updated, self.dry_run):
if self.allowed_to_edit(path, content):
edited.add(path) edited.add(path)
continue
self.io.tool_error(f"Failed to apply edit to {path}")
return edited return edited
def get_arg(edit, arg):
if arg not in edit:
raise ValueError(f"Missing `{arg}` parameter: {edit}")
return edit[arg]