roughed in openai 1.x

This commit is contained in:
Paul Gauthier 2023-12-05 07:37:05 -08:00
parent fd34766aa9
commit 6ebc142377
15 changed files with 136 additions and 110 deletions

View file

@ -53,6 +53,7 @@ class Coder:
@classmethod
def create(
self,
client,
main_model=None,
edit_format=None,
io=None,
@ -65,7 +66,7 @@ class Coder:
main_model = models.GPT4
if not skip_model_availabily_check and not main_model.always_available:
if not check_model_availability(io, main_model):
if not check_model_availability(io, client, main_model):
if main_model != models.GPT4:
io.tool_error(
f"API key does not support {main_model.name}, falling back to"
@ -77,14 +78,15 @@ class Coder:
edit_format = main_model.edit_format
if edit_format == "diff":
return EditBlockCoder(main_model, io, **kwargs)
return EditBlockCoder(client, main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs)
return WholeFileCoder(client, main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
def __init__(
self,
client,
main_model,
io,
fnames=None,
@ -103,6 +105,8 @@ class Coder:
voice_language=None,
aider_ignore_file=None,
):
self.client = client
if not fnames:
fnames = []
@ -190,6 +194,7 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary(
self.client,
models.Model.weak_model(),
self.main_model.max_chat_history_tokens,
)
@ -470,7 +475,7 @@ class Coder:
interrupted = self.send(messages, functions=self.functions)
except ExhaustedContextWindow:
exhausted = True
except openai.error.InvalidRequestError as err:
except openai.BadRequestError as err:
if "maximum context length" in str(err):
exhausted = True
else:
@ -587,7 +592,9 @@ class Coder:
interrupted = False
try:
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
hash_object, completion = send_with_retries(
self.client, model, messages, functions, self.stream
)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
@ -941,8 +948,8 @@ class Coder:
return True
def check_model_availability(io, main_model):
available_models = openai.Model.list()
def check_model_availability(io, client, main_model):
available_models = client.models.list()
model_ids = sorted(model.id for model in available_models["data"])
if main_model.name in model_ids:
return True