roughed in diff-func-string

This commit is contained in:
Paul Gauthier 2023-06-29 15:10:33 -07:00
parent 92c2d47e20
commit 58e763cee7
3 changed files with 41 additions and 7 deletions

View file

@ -87,7 +87,7 @@ class Coder:
elif edit_format == "whole-func": elif edit_format == "whole-func":
return WholeFileFunctionCoder(main_model, io, **kwargs) return WholeFileFunctionCoder(main_model, io, **kwargs)
elif edit_format == "diff-func": elif edit_format == "diff-func":
return EditBlockFunctionCoder(main_model, io, **kwargs) return EditBlockFunctionCoder("string", main_model, io, **kwargs)
else: else:
raise ValueError(f"Unknown edit format {edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
@ -432,7 +432,7 @@ class Coder:
messages += self.cur_messages messages += self.cur_messages
if self.verbose: if self.verbose:
utils.show_messages(messages) utils.show_messages(messages, functions=self.functions)
exhausted = False exhausted = False
interrupted = False interrupted = False

View file

@ -38,7 +38,7 @@ class EditBlockFunctionCoder(Coder):
type="string", type="string",
), ),
description=( description=(
"Lines from the original file, including all" "Some lines from the original file, including all"
" whitespace, without skipping any lines" " whitespace, without skipping any lines"
), ),
), ),
@ -57,7 +57,29 @@ class EditBlockFunctionCoder(Coder):
), ),
] ]
def __init__(self, *args, **kwargs): def __init__(self, code_format, *args, **kwargs):
self.code_format = code_format
if code_format == "string":
original_lines = dict(
type="string",
description=(
"Some lines from the original file, including all"
" whitespace and newlines, without skipping any lines"
),
)
updated_lines = dict(
type="string",
description="New content to replace the `original_lines` with",
)
self.functions[0]["parameters"]["properties"]["edits"]["items"]["properties"][
"original_lines"
] = original_lines
self.functions[0]["parameters"]["properties"]["edits"]["items"]["properties"][
"updated_lines"
] = updated_lines
self.gpt_prompts = EditBlockFunctionPrompts() self.gpt_prompts = EditBlockFunctionPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -96,8 +118,17 @@ class EditBlockFunctionCoder(Coder):
edited = set() edited = set()
for edit in edits: for edit in edits:
path = get_arg(edit, "path") path = get_arg(edit, "path")
original = "\n".join(get_arg(edit, "original_lines")) + "\n" original = get_arg(edit, "original_lines")
updated = "\n".join(get_arg(edit, "updated_lines")) + "\n" updated = get_arg(edit, "updated_lines")
if self.code_format == "list":
original = "\n".join(original)
updated = "\n".join(updated)
if original and not original.endswith("\n"):
original += "\n"
if updated and not updated.endswith("\n"):
updated += "\n"
full_path = self.allowed_to_edit(path) full_path = self.allowed_to_edit(path)
if not full_path: if not full_path:

View file

@ -19,7 +19,7 @@ def quoted_file(fname, display_fname, fence=("```", "```"), number=False):
return prompt return prompt
def show_messages(messages, title=None): def show_messages(messages, title=None, functions=None):
if title: if title:
print(title.upper(), "*" * 50) print(title.upper(), "*" * 50)
@ -32,3 +32,6 @@ def show_messages(messages, title=None):
content = msg.get("function_call") content = msg.get("function_call")
if content: if content:
print(role, content) print(role, content)
if functions:
dump(functions)