diff --git a/aider/coders/base.py b/aider/coders/base.py index 31a7061c6..d00c5483c 100755 --- a/aider/coders/base.py +++ b/aider/coders/base.py @@ -33,25 +33,32 @@ class Coder: repo_map = None @classmethod - def create(self, edit_format, *args, **kwargs): + def create(self, main_model, io, openai_api_key, openai_api_base, **kwargs): from . import EditBlockCoder, WholeFileCoder - if edit_format == "diff": - return EditBlockCoder(*args, **kwargs) - elif edit_format == "whole": - return WholeFileCoder(*args, **kwargs) - else: - raise ValueError(f"Unknown edit format {edit_format}") + openai.api_key = openai_api_key + openai.api_base = openai_api_base - def check_model_availability(self, main_model): - available_models = openai.Model.list() - model_ids = [model.id for model in available_models["data"]] - return main_model.name in model_ids + if not main_model.always_available: + if not check_model_availability(main_model): + if main_model != models.GPT4: + io.tool_error( + f"API key does not support {main_model.name}, falling back to" + f" {models.GPT35_16k.name}" + ) + main_model = models.GPT35_16k + + if main_model.edit_format == "diff": + return EditBlockCoder(main_model, io, **kwargs) + elif main_model.edit_format == "whole": + return WholeFileCoder(main_model, io, **kwargs) + else: + raise ValueError(f"{main_model} has unknown edit format {main_model.edit_format}") def __init__( self, + main_model, io, - main_model=models.GPT4.name, fnames=None, pretty=True, show_diffs=False, @@ -60,15 +67,8 @@ class Coder: dry_run=False, map_tokens=1024, verbose=False, - openai_api_key=None, - openai_api_base=None, assistant_output_color="blue", ): - if not openai_api_key: - raise MissingAPIKeyError("No OpenAI API key provided.") - openai.api_key = openai_api_key - openai.api_base = openai_api_base - self.verbose = verbose self.abs_fnames = set() self.cur_messages = [] @@ -91,18 +91,7 @@ class Coder: else: self.console = Console(force_terminal=True, no_color=True) - main_model = models.Model(main_model) - if not main_model.always_available: - if not self.check_model_availability(main_model): - if main_model != models.GPT4: - self.io.tool_error( - f"API key does not support {main_model.name}, falling back to" - f" {models.GPT35_16k.name}" - ) - main_model = models.GPT35_16k - self.main_model = main_model - self.edit_format = self.main_model.edit_format self.io.tool_output(f"Model: {main_model.name}") @@ -687,3 +676,9 @@ class Coder: print() traceback.print_exc() return None, err + + +def check_model_availability(main_model): + available_models = openai.Model.list() + model_ids = [model.id for model in available_models["data"]] + return main_model.name in model_ids diff --git a/aider/main.py b/aider/main.py index 169d747f6..161406858 100644 --- a/aider/main.py +++ b/aider/main.py @@ -227,10 +227,14 @@ def main(args=None, input=None, output=None): io.tool_error("No OpenAI API key provided. Use --openai-api-key or env OPENAI_API_KEY.") return 1 + main_model = models.Model(args.model) + coder = Coder.create( - "diff", + main_model, io, - main_model=args.model, + args.openai_api_key, + args.openai_api_base, + ## fnames=args.files, pretty=args.pretty, show_diffs=args.show_diffs, @@ -239,8 +243,6 @@ def main(args=None, input=None, output=None): dry_run=args.dry_run, map_tokens=args.map_tokens, verbose=args.verbose, - openai_api_key=args.openai_api_key, - openai_api_base=args.openai_api_base, assistant_output_color=args.assistant_output_color, )