roughed in func coder

This commit is contained in:
Paul Gauthier 2023-06-21 12:16:51 -07:00
parent 75e84017b3
commit 3465bd57d1
5 changed files with 130 additions and 3 deletions

View file

@ -1,5 +1,6 @@
from .base_coder import Coder from .base_coder import Coder
from .editblock_coder import EditBlockCoder from .editblock_coder import EditBlockCoder
from .func_coder import FunctionCoder
from .wholefile_coder import WholeFileCoder from .wholefile_coder import WholeFileCoder
__all__ = [Coder, EditBlockCoder, WholeFileCoder] __all__ = [Coder, EditBlockCoder, WholeFileCoder, FunctionCoder]

View file

@ -42,7 +42,7 @@ class Coder:
openai_api_base="https://api.openai.com/v1", openai_api_base="https://api.openai.com/v1",
**kwargs, **kwargs,
): ):
from . import EditBlockCoder, WholeFileCoder from . import EditBlockCoder, FunctionCoder, WholeFileCoder
openai.api_key = openai_api_key openai.api_key = openai_api_key
openai.api_base = openai_api_base openai.api_base = openai_api_base
@ -63,8 +63,10 @@ class Coder:
return EditBlockCoder(main_model, io, **kwargs) return EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole": elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs) return WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "func":
return FunctionCoder(main_model, io, **kwargs)
else: else:
raise ValueError(f"{main_model} has unknown edit format {main_model.edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
def __init__( def __init__(
self, self,

100
aider/coders/func_coder.py Normal file
View file

@ -0,0 +1,100 @@
import os
from pathlib import Path
from aider import diffs
from .base_coder import Coder
from .func_prompts import FunctionPrompts
class FunctionCoder(Coder):
def __init__(self, *args, **kwargs):
self.gpt_prompts = FunctionPrompts()
super().__init__(*args, **kwargs)
def update_cur_messages(self, content, edited):
if edited:
self.cur_messages += [
dict(role="assistant", content=self.gpt_prompts.redacted_edit_message)
]
else:
self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self, resp):
return self.update_files(resp, mode="diff")
def update_files(self, content, mode="update"):
edited = set()
chat_files = self.get_inchat_relative_files()
if not chat_files:
if mode == "diff":
return content
return
output = []
lines = content.splitlines(keepends=True)
fname = None
new_lines = []
for i, line in enumerate(lines):
if line.startswith("```"):
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
final=True,
).splitlines()
output += show_diff
else:
new_lines = "".join(new_lines)
Path(full_path).write_text(new_lines)
edited.add(fname)
fname = None
new_lines = []
continue
# starting a new block
if i == 0:
raise ValueError("No filename provided before ``` block")
fname = lines[i - 1].strip()
if fname not in chat_files:
if len(chat_files) == 1:
fname = list(chat_files)[0]
else:
show_chat_files = " ".join(chat_files)
raise ValueError(f"{fname} is not one of: {show_chat_files}")
elif fname:
new_lines.append(line)
else:
output.append(line)
if mode == "diff":
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
).splitlines()
output += show_diff
return "\n".join(output)
if fname:
raise ValueError("Started a ``` block without closing it")
return edited

View file

@ -0,0 +1,21 @@
# flake8: noqa: E501
from .prompts_base import EditorPrompts
class FunctionPrompts(EditorPrompts):
main_system = """Act as an expert software developer.
Take requests for changes to the supplied code.
If the request is ambiguous, ask questions.
Once you understand the request you MUST:
1. Explain any needed changes.
2. Call functions to edit the code to make the needed changes.
"""
system_reminder = ""
files_content_prefix = "Here is the current content of the files:\n"
files_no_full_files = "I am not sharing any files yet."
redacted_edit_message = "No changes are needed."

View file

@ -36,6 +36,9 @@ class Model:
def is_gpt35(self): def is_gpt35(self):
return self.name.startswith("gpt-3.5-turbo") return self.name.startswith("gpt-3.5-turbo")
def __str__(self):
return self.name
GPT4 = Model("gpt-4", 8) GPT4 = Model("gpt-4", 8)
GPT35 = Model("gpt-3.5-turbo") GPT35 = Model("gpt-3.5-turbo")