mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-06 04:35:00 +00:00
Refactored the code to use a global variable to store the OpenAI instance and added a helper function to determine the edit format for a given model.
This commit is contained in:
parent
a0d6efc13c
commit
22f3349862
2 changed files with 22 additions and 15 deletions
|
@ -1,6 +1,6 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
using_openrouter = False
|
saved_openai = None
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
name = None
|
name = None
|
||||||
|
@ -16,22 +16,15 @@ class Model:
|
||||||
completion_price = None
|
completion_price = None
|
||||||
|
|
||||||
def __init__(self, name, openai=None):
|
def __init__(self, name, openai=None):
|
||||||
global using_openrouter
|
global saved_openai
|
||||||
if (openai and "openrouter.ai" in openai.api_base):
|
if (openai and "openrouter.ai" in openai.api_base):
|
||||||
using_openrouter = True
|
saved_openai = openai
|
||||||
|
|
||||||
from .openai import OpenAIModel
|
from .openai import OpenAIModel
|
||||||
from .openrouter import OpenRouterModel
|
from .openrouter import OpenRouterModel
|
||||||
model = None
|
model = None
|
||||||
if using_openrouter:
|
if saved_openai:
|
||||||
if name == 'gpt-4':
|
model = OpenRouterModel(name, saved_openai)
|
||||||
name = 'openai/gpt-4'
|
|
||||||
elif name == 'gpt-3.5-turbo':
|
|
||||||
name = 'openai/gpt-3.5-turbo'
|
|
||||||
elif name == 'gpt-3.5.turbo-16k':
|
|
||||||
name = 'openai/gpt-3.5-turbo-16k'
|
|
||||||
|
|
||||||
model = OpenRouterModel(name, openai)
|
|
||||||
else:
|
else:
|
||||||
model = OpenAIModel(name)
|
model = OpenAIModel(name)
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,16 @@ from .model import Model
|
||||||
|
|
||||||
class OpenRouterModel(Model):
|
class OpenRouterModel(Model):
|
||||||
def __init__(self, name, openai):
|
def __init__(self, name, openai):
|
||||||
|
if name == 'gpt-4':
|
||||||
|
name = 'openai/gpt-4'
|
||||||
|
elif name == 'gpt-3.5-turbo':
|
||||||
|
name = 'openai/gpt-3.5-turbo'
|
||||||
|
elif name == 'gpt-3.5-turbo-16k':
|
||||||
|
name = 'openai/gpt-3.5-turbo-16k'
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.edit_format = "diff"
|
self.edit_format = edit_format_for_model(name)
|
||||||
self.use_repo_map = True
|
self.use_repo_map = self.edit_format == "diff"
|
||||||
|
|
||||||
# TODO: figure out proper encodings for non openai models
|
# TODO: figure out proper encodings for non openai models
|
||||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||||
|
@ -20,4 +27,11 @@ class OpenRouterModel(Model):
|
||||||
self.completion_price = float(found.get('pricing').get('completion')) * 1000
|
self.completion_price = float(found.get('pricing').get('completion')) * 1000
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid openrouter model for {name}')
|
raise ValueError(f'invalid openrouter model: {name}')
|
||||||
|
|
||||||
|
|
||||||
|
def edit_format_for_model(name):
|
||||||
|
if any(str in name for str in ['gpt-4', 'claude-2']):
|
||||||
|
return "diff"
|
||||||
|
|
||||||
|
return "whole"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue