merge in openai upgrade

This commit is contained in:
Joshua Vial 2023-12-11 20:43:18 +13:00
commit fe9423d7b8
23 changed files with 304 additions and 201 deletions

View file

@ -42,6 +42,7 @@ def wrap_fence(name):
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
class Coder:
client = None
abs_fnames = None
repo = None
last_aider_commit_hash = None
@ -58,6 +59,7 @@ class Coder:
main_model=None,
edit_format=None,
io=None,
client=None,
skip_model_availabily_check=False,
**kwargs,
):
@ -67,26 +69,28 @@ class Coder:
main_model = models.GPT4
if not skip_model_availabily_check and not main_model.always_available:
if not check_model_availability(io, main_model):
if not check_model_availability(io, client, main_model):
fallback_model = models.GPT35_1106
if main_model != models.GPT4:
io.tool_error(
f"API key does not support {main_model.name}, falling back to"
f" {models.GPT35_16k.name}"
f" {fallback_model.name}"
)
main_model = models.GPT35_16k
main_model = fallback_model
if edit_format is None:
edit_format = main_model.edit_format
if edit_format == "diff":
return EditBlockCoder(main_model, io, **kwargs)
return EditBlockCoder(client, main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs)
return WholeFileCoder(client, main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
def __init__(
self,
client,
main_model,
io,
fnames=None,
@ -105,6 +109,8 @@ class Coder:
voice_language=None,
aider_ignore_file=None,
):
self.client = client
if not fnames:
fnames = []
@ -161,7 +167,9 @@ class Coder:
if use_git:
try:
self.repo = GitRepo(self.io, fnames, git_dname, aider_ignore_file)
self.repo = GitRepo(
self.io, fnames, git_dname, aider_ignore_file, client=self.client
)
self.root = self.repo.root
except FileNotFoundError:
self.repo = None
@ -192,6 +200,7 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary(
self.client,
models.Model.weak_model(),
self.main_model.max_chat_history_tokens,
)
@ -305,6 +314,13 @@ class Coder:
def get_files_messages(self):
all_content = ""
repo_content = self.get_repo_map()
if repo_content:
if all_content:
all_content += "\n"
all_content += repo_content
if self.abs_fnames:
files_content = self.gpt_prompts.files_content_prefix
files_content += self.get_files_content()
@ -313,12 +329,6 @@ class Coder:
all_content += files_content
repo_content = self.get_repo_map()
if repo_content:
if all_content:
all_content += "\n"
all_content += repo_content
files_messages = [
dict(role="user", content=all_content),
dict(role="assistant", content="Ok."),
@ -500,7 +510,7 @@ class Coder:
interrupted = self.send(messages, functions=self.functions)
except ExhaustedContextWindow:
exhausted = True
except openai.error.InvalidRequestError as err:
except openai.BadRequestError as err:
if "maximum context length" in str(err):
exhausted = True
else:
@ -617,7 +627,9 @@ class Coder:
interrupted = False
try:
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
hash_object, completion = send_with_retries(
self.client, model, messages, functions, self.stream
)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
@ -971,9 +983,16 @@ class Coder:
return True
def check_model_availability(io, main_model):
available_models = openai.Model.list()
model_ids = sorted(model.id for model in available_models["data"])
def check_model_availability(io, client, main_model):
try:
available_models = client.models.list()
except openai.NotFoundError:
# Azure sometimes returns 404?
# https://discord.com/channels/1131200896827654144/1182327371232186459
io.tool_error("Unable to list available models, proceeding with {main_model.name}")
return True
model_ids = sorted(model.id for model in available_models)
if main_model.name in model_ids:
return True