From 22f33498628fdd91eff22d5bd3a28c8f14d429d9 Mon Sep 17 00:00:00 2001 From: JV Date: Tue, 15 Aug 2023 04:27:53 +1200 Subject: [PATCH] Refactored the code to use a global variable to store the OpenAI instance and added a helper function to determine the edit format for a given model. --- aider/models/model.py | 17 +++++------------ aider/models/openrouter.py | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/aider/models/model.py b/aider/models/model.py index 0c0ebafaf..4b3397a50 100644 --- a/aider/models/model.py +++ b/aider/models/model.py @@ -1,6 +1,6 @@ import importlib -using_openrouter = False +saved_openai = None class Model: name = None @@ -16,22 +16,15 @@ class Model: completion_price = None def __init__(self, name, openai=None): - global using_openrouter + global saved_openai if (openai and "openrouter.ai" in openai.api_base): - using_openrouter = True + saved_openai = openai from .openai import OpenAIModel from .openrouter import OpenRouterModel model = None - if using_openrouter: - if name == 'gpt-4': - name = 'openai/gpt-4' - elif name == 'gpt-3.5-turbo': - name = 'openai/gpt-3.5-turbo' - elif name == 'gpt-3.5.turbo-16k': - name = 'openai/gpt-3.5-turbo-16k' - - model = OpenRouterModel(name, openai) + if saved_openai: + model = OpenRouterModel(name, saved_openai) else: model = OpenAIModel(name) diff --git a/aider/models/openrouter.py b/aider/models/openrouter.py index 11e0bc323..7250a521a 100644 --- a/aider/models/openrouter.py +++ b/aider/models/openrouter.py @@ -4,9 +4,16 @@ from .model import Model class OpenRouterModel(Model): def __init__(self, name, openai): + if name == 'gpt-4': + name = 'openai/gpt-4' + elif name == 'gpt-3.5-turbo': + name = 'openai/gpt-3.5-turbo' + elif name == 'gpt-3.5-turbo-16k': + name = 'openai/gpt-3.5-turbo-16k' + self.name = name - self.edit_format = "diff" - self.use_repo_map = True + self.edit_format = edit_format_for_model(name) + self.use_repo_map = self.edit_format == "diff" # TODO: figure out proper encodings for non openai models self.tokenizer = tiktoken.get_encoding("cl100k_base") @@ -20,4 +27,11 @@ class OpenRouterModel(Model): self.completion_price = float(found.get('pricing').get('completion')) * 1000 else: - raise ValueError('invalid openrouter model for {name}') + raise ValueError(f'invalid openrouter model: {name}') + + +def edit_format_for_model(name): + if any(str in name for str in ['gpt-4', 'claude-2']): + return "diff" + + return "whole"