mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-25 14:55:00 +00:00
cleaning up openrouter code
This commit is contained in:
parent
668a0500ff
commit
0826e116da
3 changed files with 5 additions and 5 deletions
|
@ -11,15 +11,14 @@ class Model:
|
||||||
|
|
||||||
prompt_price = None
|
prompt_price = None
|
||||||
completion_price = None
|
completion_price = None
|
||||||
openai=None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, name, **kwargs):
|
def create(cls, name):
|
||||||
from .openai import OpenAIModel
|
from .openai import OpenAIModel
|
||||||
from .openrouter import OpenRouterModel
|
from .openrouter import OpenRouterModel
|
||||||
if ("openrouter.ai" in openai.api_base):
|
if ("openrouter.ai" in openai.api_base):
|
||||||
return OpenRouterModel(name, **kwargs)
|
return OpenRouterModel(name)
|
||||||
return OpenAIModel(name, **kwargs)
|
return OpenAIModel(name)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
|
@ -21,7 +21,6 @@ class OpenRouterModel(Model):
|
||||||
# TODO: figure out proper encodings for non openai models
|
# TODO: figure out proper encodings for non openai models
|
||||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
# TODO cache the model list data to speed up using multiple models
|
|
||||||
global cached_model_details
|
global cached_model_details
|
||||||
if cached_model_details == None:
|
if cached_model_details == None:
|
||||||
cached_model_details = openai.Model.list().data
|
cached_model_details = openai.Model.list().data
|
||||||
|
|
|
@ -28,6 +28,7 @@ class TestModels(unittest.TestCase):
|
||||||
@patch('openai.Model.list')
|
@patch('openai.Model.list')
|
||||||
def test_openrouter_model_properties(self, mock_model_list):
|
def test_openrouter_model_properties(self, mock_model_list):
|
||||||
import openai
|
import openai
|
||||||
|
old_base = openai.api_base
|
||||||
openai.api_base = 'https://openrouter.ai/api/v1'
|
openai.api_base = 'https://openrouter.ai/api/v1'
|
||||||
mock_model_list.return_value = {
|
mock_model_list.return_value = {
|
||||||
'data': [
|
'data': [
|
||||||
|
@ -49,6 +50,7 @@ class TestModels(unittest.TestCase):
|
||||||
self.assertEqual(model.max_context_tokens, 8192)
|
self.assertEqual(model.max_context_tokens, 8192)
|
||||||
self.assertEqual(model.prompt_price, 0.06)
|
self.assertEqual(model.prompt_price, 0.06)
|
||||||
self.assertEqual(model.completion_price, 0.12)
|
self.assertEqual(model.completion_price, 0.12)
|
||||||
|
openai.api_base = old_base
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue