cleaning up openrouter code

This commit is contained in:
Joshua Vial 2023-08-23 22:03:09 +12:00
parent 668a0500ff
commit 0826e116da
3 changed files with 5 additions and 5 deletions

View file

@ -11,15 +11,14 @@ class Model:
prompt_price = None prompt_price = None
completion_price = None completion_price = None
openai=None
@classmethod @classmethod
def create(cls, name, **kwargs): def create(cls, name):
from .openai import OpenAIModel from .openai import OpenAIModel
from .openrouter import OpenRouterModel from .openrouter import OpenRouterModel
if ("openrouter.ai" in openai.api_base): if ("openrouter.ai" in openai.api_base):
return OpenRouterModel(name, **kwargs) return OpenRouterModel(name)
return OpenAIModel(name, **kwargs) return OpenAIModel(name)
def __str__(self): def __str__(self):
return self.name return self.name

View file

@ -21,7 +21,6 @@ class OpenRouterModel(Model):
# TODO: figure out proper encodings for non openai models # TODO: figure out proper encodings for non openai models
self.tokenizer = tiktoken.get_encoding("cl100k_base") self.tokenizer = tiktoken.get_encoding("cl100k_base")
# TODO cache the model list data to speed up using multiple models
global cached_model_details global cached_model_details
if cached_model_details == None: if cached_model_details == None:
cached_model_details = openai.Model.list().data cached_model_details = openai.Model.list().data

View file

@ -28,6 +28,7 @@ class TestModels(unittest.TestCase):
@patch('openai.Model.list') @patch('openai.Model.list')
def test_openrouter_model_properties(self, mock_model_list): def test_openrouter_model_properties(self, mock_model_list):
import openai import openai
old_base = openai.api_base
openai.api_base = 'https://openrouter.ai/api/v1' openai.api_base = 'https://openrouter.ai/api/v1'
mock_model_list.return_value = { mock_model_list.return_value = {
'data': [ 'data': [
@ -49,6 +50,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(model.max_context_tokens, 8192) self.assertEqual(model.max_context_tokens, 8192)
self.assertEqual(model.prompt_price, 0.06) self.assertEqual(model.prompt_price, 0.06)
self.assertEqual(model.completion_price, 0.12) self.assertEqual(model.completion_price, 0.12)
openai.api_base = old_base
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()