cleaned up Model init

This commit is contained in:
Paul Gauthier 2023-06-15 06:26:16 -07:00
parent bd91cb7fb7
commit 8fb4ab2be3
2 changed files with 6 additions and 10 deletions

View file

@ -78,7 +78,7 @@ class Coder:
else: else:
self.console = Console(force_terminal=True, no_color=True) self.console = Console(force_terminal=True, no_color=True)
main_model = models.get_model(main_model) main_model = models.Model(main_model)
if not main_model.is_always_available(): if not main_model.is_always_available():
if not self.check_model_availability(main_model): if not self.check_model_availability(main_model):
if main_model != models.GPT4: if main_model != models.GPT4:

View file

@ -13,6 +13,11 @@ class Model:
self.max_context_tokens = tokens * 1024 self.max_context_tokens = tokens * 1024
if self.is_gpt4() or self.is_gpt35():
return
raise ValueError(f"Unsupported model: {name}")
def is_gpt4(self): def is_gpt4(self):
return self.name.startswith("gpt-4") return self.name.startswith("gpt-4")
@ -26,12 +31,3 @@ class Model:
GPT4 = Model("gpt-4", 8) GPT4 = Model("gpt-4", 8)
GPT35 = Model("gpt-3.5-turbo") GPT35 = Model("gpt-3.5-turbo")
GPT35_16k = Model("gpt-3.5-turbo-16k") GPT35_16k = Model("gpt-3.5-turbo-16k")
def get_model(name):
model = Model(name)
if model.is_gpt4() or model.is_gpt35():
return model
raise ValueError(f"Unsupported model: {name}")