This commit is contained in:
Paul Gauthier 2024-08-07 07:37:16 -03:00
parent 492738f325
commit 47295a1545
4 changed files with 32 additions and 18 deletions

View file

@ -1,10 +1,11 @@
from .ask_coder import AskCoder
from .base_coder import Coder
from .editblock_coder import EditBlockCoder
from .editblock_fenced_coder import EditBlockFencedCoder
from .help_coder import HelpCoder
from .single_wholefile_func_coder import SingleWholeFileFunctionCoder
from .udiff_coder import UnifiedDiffCoder
from .wholefile_coder import WholeFileCoder
from .ask_coder import AskCoder
__all__ = [
HelpCoder,
@ -14,4 +15,5 @@ __all__ = [
EditBlockFencedCoder,
WholeFileCoder,
UnifiedDiffCoder,
SingleWholeFileFunctionCoder,
]

View file

@ -1209,8 +1209,12 @@ class Coder:
show_func_err = None
show_content_err = None
try:
self.partial_response_function_call = completion.choices[0].message.function_call
self.partial_response_function_call = (
completion.choices[0].message.tool_calls[0].function
)
dump(str(self.partial_response_function_call))
except AttributeError as func_err:
dump(func_err)
show_func_err = func_err
try:
@ -1219,7 +1223,7 @@ class Coder:
show_content_err = content_err
resp_hash = dict(
function_call=self.partial_response_function_call,
function_call=str(self.partial_response_function_call),
content=self.partial_response_content,
)
resp_hash = hashlib.sha1(json.dumps(resp_hash, sort_keys=True).encode())

View file

@ -1,3 +1,5 @@
import json
from aider import diffs
from ..dump import dump # noqa: F401
@ -6,6 +8,8 @@ from .single_wholefile_func_prompts import SingleWholeFileFunctionPrompts
class SingleWholeFileFunctionCoder(Coder):
edit_format = "func"
functions = [
dict(
name="write_file",
@ -31,7 +35,6 @@ class SingleWholeFileFunctionCoder(Coder):
]
def __init__(self, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.gpt_prompts = SingleWholeFileFunctionPrompts()
super().__init__(*args, **kwargs)
@ -44,12 +47,18 @@ class SingleWholeFileFunctionCoder(Coder):
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def render_incremental_response(self, final=False):
res = ""
if self.partial_response_content:
return self.partial_response_content
res += self.partial_response_content
args = self.parse_partial_args()
return str(args)
for k, v in args.items():
res += "\n"
res += f"{k}:\n"
res += v
return res
if not args:
return
@ -95,18 +104,17 @@ class SingleWholeFileFunctionCoder(Coder):
return "\n".join(show_diff)
def _update_files(self):
name = self.partial_response_function_call.get("name")
if name and name != "write_file":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"')
def get_edits(self):
chat_files = self.get_inchat_relative_files()
assert len(chat_files) == 1, chat_files
args = self.parse_partial_args()
if not args:
return
content = args["content"]
path = self.get_inchat_relative_files()[0]
if self.allowed_to_edit(path, content):
return set([path])
res = chat_files[0], args["content"]
dump(res)
return [res]
return set()
def apply_edits(self, edits):
for path, content in edits:
full_path = self.abs_root_path(path)
self.io.write_text(full_path, content)

View file

@ -55,7 +55,7 @@ def send_with_retries(
stream=stream,
)
if functions is not None:
kwargs["functions"] = functions
kwargs["tools"] = [dict(type="functions", function=functions[0])]
if extra_headers is not None:
kwargs["extra_headers"] = extra_headers
if max_tokens is not None: