mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-05 04:05:04 +00:00
Refactored the code to use a global variable to store the OpenAI instance and added a helper function to determine the edit format for a given model.
This commit is contained in:
parent
a0d6efc13c
commit
22f3349862
2 changed files with 22 additions and 15 deletions
|
@ -1,6 +1,6 @@
|
|||
import importlib
|
||||
|
||||
using_openrouter = False
|
||||
saved_openai = None
|
||||
|
||||
class Model:
|
||||
name = None
|
||||
|
@ -16,22 +16,15 @@ class Model:
|
|||
completion_price = None
|
||||
|
||||
def __init__(self, name, openai=None):
|
||||
global using_openrouter
|
||||
global saved_openai
|
||||
if (openai and "openrouter.ai" in openai.api_base):
|
||||
using_openrouter = True
|
||||
saved_openai = openai
|
||||
|
||||
from .openai import OpenAIModel
|
||||
from .openrouter import OpenRouterModel
|
||||
model = None
|
||||
if using_openrouter:
|
||||
if name == 'gpt-4':
|
||||
name = 'openai/gpt-4'
|
||||
elif name == 'gpt-3.5-turbo':
|
||||
name = 'openai/gpt-3.5-turbo'
|
||||
elif name == 'gpt-3.5.turbo-16k':
|
||||
name = 'openai/gpt-3.5-turbo-16k'
|
||||
|
||||
model = OpenRouterModel(name, openai)
|
||||
if saved_openai:
|
||||
model = OpenRouterModel(name, saved_openai)
|
||||
else:
|
||||
model = OpenAIModel(name)
|
||||
|
||||
|
|
|
@ -4,9 +4,16 @@ from .model import Model
|
|||
|
||||
class OpenRouterModel(Model):
|
||||
def __init__(self, name, openai):
|
||||
if name == 'gpt-4':
|
||||
name = 'openai/gpt-4'
|
||||
elif name == 'gpt-3.5-turbo':
|
||||
name = 'openai/gpt-3.5-turbo'
|
||||
elif name == 'gpt-3.5-turbo-16k':
|
||||
name = 'openai/gpt-3.5-turbo-16k'
|
||||
|
||||
self.name = name
|
||||
self.edit_format = "diff"
|
||||
self.use_repo_map = True
|
||||
self.edit_format = edit_format_for_model(name)
|
||||
self.use_repo_map = self.edit_format == "diff"
|
||||
|
||||
# TODO: figure out proper encodings for non openai models
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
@ -20,4 +27,11 @@ class OpenRouterModel(Model):
|
|||
self.completion_price = float(found.get('pricing').get('completion')) * 1000
|
||||
|
||||
else:
|
||||
raise ValueError('invalid openrouter model for {name}')
|
||||
raise ValueError(f'invalid openrouter model: {name}')
|
||||
|
||||
|
||||
def edit_format_for_model(name):
|
||||
if any(str in name for str in ['gpt-4', 'claude-2']):
|
||||
return "diff"
|
||||
|
||||
return "whole"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue