From d91707f73ee5345aa4c5e5e667612363632d6870 Mon Sep 17 00:00:00 2001 From: Krazer Date: Tue, 11 Jun 2024 13:07:19 -0500 Subject: [PATCH] all models to be registered with litellm --- aider/args.py | 17 +++++++++++++++++ aider/main.py | 10 ++++++++++ aider/models.py | 12 ++++++++++++ 3 files changed, 39 insertions(+) diff --git a/aider/args.py b/aider/args.py index 35e49b1e8..ba931f42d 100644 --- a/aider/args.py +++ b/aider/args.py @@ -135,6 +135,23 @@ def get_parser(default_config_files, git_root): env_var="OPENAI_ORGANIZATION_ID", help="Specify the OpenAI organization ID", ) + group.add_argument( + "--model_file", + metavar="MODEL_FILE", + default=None, + help={ + "File with model definitions to be registered for info/cost, json formated", + " {" + " \"gpt-4\": {", + " \"max_tokens\": 8192,", + " \"input_cost_per_token\": 0.00003,", + " \"output_cost_per_token\": 0.00006,", + " \"litellm_provider\": \"openai\",", + " \"mode\": \"chat\"", + " },", + " }" + } + ) group.add_argument( "--edit-format", metavar="EDIT_FORMAT", diff --git a/aider/main.py b/aider/main.py index 9cf0dda10..12ea6d6df 100644 --- a/aider/main.py +++ b/aider/main.py @@ -332,6 +332,16 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F if args.openai_organization_id: os.environ["OPENAI_ORGANIZATION"] = args.openai_organization_id + model_def_files = [] + model_def_fname = Path(".aider.models.json") + model_def_files.append(Path.home() / model_def_fname) # homedir + if git_root: + model_def_files.append(Path(git_root) / model_def_fname) # git root + if args.models: + model_def_files.append(args.models) + model_def_files = list(map(str, model_def_files)) + models.register_models(model_def_files) + main_model = models.Model(args.model, weak_model=args.weak_model) lint_cmds = parse_lint_cmds(args.lint_cmd, io) diff --git a/aider/models.py b/aider/models.py index 1d51c10c7..1a96bb5b3 100644 --- a/aider/models.py +++ b/aider/models.py @@ -426,6 +426,18 @@ class Model: return res +def register_models(model_def_fnames): + for model_def_fname in model_def_fnames: + if not os.path.exists(model_def_fname): + continue + print(f"Registering model definition from {model_def_fname}") + try: + with open(model_def_fname, "r") as model_def_file: + model_def = json.load(model_def_file) + except json.JSONDecodeError as e: + print(f"Error opening/decoding model definition: {e}") + + litellm.register_model(model_def) def validate_variables(vars): missing = []