This commit is contained in:
Paul Gauthier 2023-06-20 16:47:37 -07:00
parent bc38fcde65
commit 02c9a30c45
4 changed files with 25 additions and 12 deletions

View file

@ -1,3 +1,5 @@
from .base import Coder
from .editblock import EditBlockCoder
from .wholefile import WholeFileCoder
__all__ = [Coder]
__all__ = [Coder, EditBlockCoder, WholeFileCoder]

View file

@ -14,7 +14,7 @@ from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from aider import diffs, editors, models, prompts, utils
from aider import diffs, models, prompts, utils
from aider.commands import Commands
from aider.repomap import RepoMap
@ -32,6 +32,17 @@ class Coder:
last_asked_for_commit_time = 0
repo_map = None
@classmethod
def create(self, edit_format, *args, **kwargs):
from . import EditBlockCoder, WholeFileCoder
if edit_format == "whole":
return EditBlockCoder(*args, **kwargs)
elif edit_format == "diff":
return WholeFileCoder(*args, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
def check_model_availability(self, main_model):
available_models = openai.Model.list()
model_ids = [model.id for model in available_models["data"]]
@ -93,15 +104,6 @@ class Coder:
self.main_model = main_model
self.edit_format = self.main_model.edit_format
if self.edit_format == "whole":
self.gpt_prompts = editors.WholeFilePrompts()
elif self.edit_format == "diff":
self.gpt_prompts = editors.EditBlockPrompts()
else:
raise ValueError(
f"Model {main_model} requesting unknown edit format {self.edit_format}"
)
self.io.tool_output(f"Model: {main_model.name}")
self.show_diffs = show_diffs

View file

@ -0,0 +1,8 @@
from ..editors import WholeFilePrompts
from .base import Coder
class WholeFileCoder(Coder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gpt_prompts = WholeFilePrompts()

View file

@ -227,7 +227,8 @@ def main(args=None, input=None, output=None):
io.tool_error("No OpenAI API key provided. Use --openai-api-key or env OPENAI_API_KEY.")
return 1
coder = Coder(
coder = Coder.create(
"diff",
io,
main_model=args.model,
fnames=args.files,