mirror of
https://github.com/Aider-AI/aider.git
synced 2025-06-12 07:35:00 +00:00
roughed in openai 1.x
This commit is contained in:
parent
fd34766aa9
commit
6ebc142377
15 changed files with 136 additions and 110 deletions
|
@ -1,7 +1,5 @@
|
|||
import json
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
class Model:
|
||||
name = None
|
||||
|
@ -18,12 +16,12 @@ class Model:
|
|||
completion_price = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, name):
|
||||
def create(cls, name, client=None):
|
||||
from .openai import OpenAIModel
|
||||
from .openrouter import OpenRouterModel
|
||||
|
||||
if "openrouter.ai" in openai.api_base:
|
||||
return OpenRouterModel(name)
|
||||
if client and client.base_url.host == "openrouter.ai":
|
||||
return OpenRouterModel(client, name)
|
||||
return OpenAIModel(name)
|
||||
|
||||
def __str__(self):
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import openai
|
||||
import tiktoken
|
||||
|
||||
from .model import Model
|
||||
|
@ -7,7 +6,7 @@ cached_model_details = None
|
|||
|
||||
|
||||
class OpenRouterModel(Model):
|
||||
def __init__(self, name):
|
||||
def __init__(self, client, name):
|
||||
if name == "gpt-4":
|
||||
name = "openai/gpt-4"
|
||||
elif name == "gpt-3.5-turbo":
|
||||
|
@ -24,7 +23,7 @@ class OpenRouterModel(Model):
|
|||
|
||||
global cached_model_details
|
||||
if cached_model_details is None:
|
||||
cached_model_details = openai.Model.list().data
|
||||
cached_model_details = client.models.list().data
|
||||
found = next(
|
||||
(details for details in cached_model_details if details.get("id") == name), None
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue