Refactored the code to use a global variable to store the OpenAI instance and added a helper function to determine the edit format for a given model.

This commit is contained in:
JV 2023-08-15 04:27:53 +12:00 committed by Joshua Vial
parent a0d6efc13c
commit 22f3349862
2 changed files with 22 additions and 15 deletions

View file

@ -1,6 +1,6 @@
import importlib
using_openrouter = False
saved_openai = None
class Model:
name = None
@ -16,22 +16,15 @@ class Model:
completion_price = None
def __init__(self, name, openai=None):
global using_openrouter
global saved_openai
if (openai and "openrouter.ai" in openai.api_base):
using_openrouter = True
saved_openai = openai
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
model = None
if using_openrouter:
if name == 'gpt-4':
name = 'openai/gpt-4'
elif name == 'gpt-3.5-turbo':
name = 'openai/gpt-3.5-turbo'
elif name == 'gpt-3.5.turbo-16k':
name = 'openai/gpt-3.5-turbo-16k'
model = OpenRouterModel(name, openai)
if saved_openai:
model = OpenRouterModel(name, saved_openai)
else:
model = OpenAIModel(name)

View file

@ -4,9 +4,16 @@ from .model import Model
class OpenRouterModel(Model):
def __init__(self, name, openai):
if name == 'gpt-4':
name = 'openai/gpt-4'
elif name == 'gpt-3.5-turbo':
name = 'openai/gpt-3.5-turbo'
elif name == 'gpt-3.5-turbo-16k':
name = 'openai/gpt-3.5-turbo-16k'
self.name = name
self.edit_format = "diff"
self.use_repo_map = True
self.edit_format = edit_format_for_model(name)
self.use_repo_map = self.edit_format == "diff"
# TODO: figure out proper encodings for non openai models
self.tokenizer = tiktoken.get_encoding("cl100k_base")
@ -20,4 +27,11 @@ class OpenRouterModel(Model):
self.completion_price = float(found.get('pricing').get('completion')) * 1000
else:
raise ValueError('invalid openrouter model for {name}')
raise ValueError(f'invalid openrouter model: {name}')
def edit_format_for_model(name):
    """Return the edit format to use for the model called *name*.

    Models whose name contains ``gpt-4`` or ``claude-2`` get ``"diff"``;
    every other model gets ``"whole"``.  Matching is by substring, so a
    provider-prefixed name such as ``openai/gpt-4`` matches as well.
    """
    diff_markers = ("gpt-4", "claude-2")
    # The loop variable is not called ``str`` so the builtin is not shadowed.
    if any(marker in name for marker in diff_markers):
        return "diff"
    return "whole"