mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
roughed in diff-func-string
This commit is contained in:
parent
92c2d47e20
commit
58e763cee7
3 changed files with 41 additions and 7 deletions
|
@ -87,7 +87,7 @@ class Coder:
|
||||||
elif edit_format == "whole-func":
|
elif edit_format == "whole-func":
|
||||||
return WholeFileFunctionCoder(main_model, io, **kwargs)
|
return WholeFileFunctionCoder(main_model, io, **kwargs)
|
||||||
elif edit_format == "diff-func":
|
elif edit_format == "diff-func":
|
||||||
return EditBlockFunctionCoder(main_model, io, **kwargs)
|
return EditBlockFunctionCoder("string", main_model, io, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown edit format {edit_format}")
|
raise ValueError(f"Unknown edit format {edit_format}")
|
||||||
|
|
||||||
|
@ -432,7 +432,7 @@ class Coder:
|
||||||
messages += self.cur_messages
|
messages += self.cur_messages
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
utils.show_messages(messages)
|
utils.show_messages(messages, functions=self.functions)
|
||||||
|
|
||||||
exhausted = False
|
exhausted = False
|
||||||
interrupted = False
|
interrupted = False
|
||||||
|
|
|
@ -38,7 +38,7 @@ class EditBlockFunctionCoder(Coder):
|
||||||
type="string",
|
type="string",
|
||||||
),
|
),
|
||||||
description=(
|
description=(
|
||||||
"Lines from the original file, including all"
|
"Some lines from the original file, including all"
|
||||||
" whitespace, without skipping any lines"
|
" whitespace, without skipping any lines"
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
@ -57,7 +57,29 @@ class EditBlockFunctionCoder(Coder):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, code_format, *args, **kwargs):
|
||||||
|
self.code_format = code_format
|
||||||
|
|
||||||
|
if code_format == "string":
|
||||||
|
original_lines = dict(
|
||||||
|
type="string",
|
||||||
|
description=(
|
||||||
|
"Some lines from the original file, including all"
|
||||||
|
" whitespace and newlines, without skipping any lines"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
updated_lines = dict(
|
||||||
|
type="string",
|
||||||
|
description="New content to replace the `original_lines` with",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.functions[0]["parameters"]["properties"]["edits"]["items"]["properties"][
|
||||||
|
"original_lines"
|
||||||
|
] = original_lines
|
||||||
|
self.functions[0]["parameters"]["properties"]["edits"]["items"]["properties"][
|
||||||
|
"updated_lines"
|
||||||
|
] = updated_lines
|
||||||
|
|
||||||
self.gpt_prompts = EditBlockFunctionPrompts()
|
self.gpt_prompts = EditBlockFunctionPrompts()
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -96,8 +118,17 @@ class EditBlockFunctionCoder(Coder):
|
||||||
edited = set()
|
edited = set()
|
||||||
for edit in edits:
|
for edit in edits:
|
||||||
path = get_arg(edit, "path")
|
path = get_arg(edit, "path")
|
||||||
original = "\n".join(get_arg(edit, "original_lines")) + "\n"
|
original = get_arg(edit, "original_lines")
|
||||||
updated = "\n".join(get_arg(edit, "updated_lines")) + "\n"
|
updated = get_arg(edit, "updated_lines")
|
||||||
|
|
||||||
|
if self.code_format == "list":
|
||||||
|
original = "\n".join(original)
|
||||||
|
updated = "\n".join(updated)
|
||||||
|
|
||||||
|
if original and not original.endswith("\n"):
|
||||||
|
original += "\n"
|
||||||
|
if updated and not updated.endswith("\n"):
|
||||||
|
updated += "\n"
|
||||||
|
|
||||||
full_path = self.allowed_to_edit(path)
|
full_path = self.allowed_to_edit(path)
|
||||||
if not full_path:
|
if not full_path:
|
||||||
|
|
|
@ -19,7 +19,7 @@ def quoted_file(fname, display_fname, fence=("```", "```"), number=False):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def show_messages(messages, title=None):
|
def show_messages(messages, title=None, functions=None):
|
||||||
if title:
|
if title:
|
||||||
print(title.upper(), "*" * 50)
|
print(title.upper(), "*" * 50)
|
||||||
|
|
||||||
|
@ -32,3 +32,6 @@ def show_messages(messages, title=None):
|
||||||
content = msg.get("function_call")
|
content = msg.get("function_call")
|
||||||
if content:
|
if content:
|
||||||
print(role, content)
|
print(role, content)
|
||||||
|
|
||||||
|
if functions:
|
||||||
|
dump(functions)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue