mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 14:25:00 +00:00
merge in openai upgrade
This commit is contained in:
commit
fe9423d7b8
23 changed files with 304 additions and 201 deletions
|
@ -42,6 +42,7 @@ def wrap_fence(name):
|
|||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
|
||||
class Coder:
|
||||
client = None
|
||||
abs_fnames = None
|
||||
repo = None
|
||||
last_aider_commit_hash = None
|
||||
|
@ -58,6 +59,7 @@ class Coder:
|
|||
main_model=None,
|
||||
edit_format=None,
|
||||
io=None,
|
||||
client=None,
|
||||
skip_model_availabily_check=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -67,26 +69,28 @@ class Coder:
|
|||
main_model = models.GPT4
|
||||
|
||||
if not skip_model_availabily_check and not main_model.always_available:
|
||||
if not check_model_availability(io, main_model):
|
||||
if not check_model_availability(io, client, main_model):
|
||||
fallback_model = models.GPT35_1106
|
||||
if main_model != models.GPT4:
|
||||
io.tool_error(
|
||||
f"API key does not support {main_model.name}, falling back to"
|
||||
f" {models.GPT35_16k.name}"
|
||||
f" {fallback_model.name}"
|
||||
)
|
||||
main_model = models.GPT35_16k
|
||||
main_model = fallback_model
|
||||
|
||||
if edit_format is None:
|
||||
edit_format = main_model.edit_format
|
||||
|
||||
if edit_format == "diff":
|
||||
return EditBlockCoder(main_model, io, **kwargs)
|
||||
return EditBlockCoder(client, main_model, io, **kwargs)
|
||||
elif edit_format == "whole":
|
||||
return WholeFileCoder(main_model, io, **kwargs)
|
||||
return WholeFileCoder(client, main_model, io, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown edit format {edit_format}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
main_model,
|
||||
io,
|
||||
fnames=None,
|
||||
|
@ -105,6 +109,8 @@ class Coder:
|
|||
voice_language=None,
|
||||
aider_ignore_file=None,
|
||||
):
|
||||
self.client = client
|
||||
|
||||
if not fnames:
|
||||
fnames = []
|
||||
|
||||
|
@ -161,7 +167,9 @@ class Coder:
|
|||
|
||||
if use_git:
|
||||
try:
|
||||
self.repo = GitRepo(self.io, fnames, git_dname, aider_ignore_file)
|
||||
self.repo = GitRepo(
|
||||
self.io, fnames, git_dname, aider_ignore_file, client=self.client
|
||||
)
|
||||
self.root = self.repo.root
|
||||
except FileNotFoundError:
|
||||
self.repo = None
|
||||
|
@ -192,6 +200,7 @@ class Coder:
|
|||
self.io.tool_output(f"Added {fname} to the chat.")
|
||||
|
||||
self.summarizer = ChatSummary(
|
||||
self.client,
|
||||
models.Model.weak_model(),
|
||||
self.main_model.max_chat_history_tokens,
|
||||
)
|
||||
|
@ -305,6 +314,13 @@ class Coder:
|
|||
|
||||
def get_files_messages(self):
|
||||
all_content = ""
|
||||
|
||||
repo_content = self.get_repo_map()
|
||||
if repo_content:
|
||||
if all_content:
|
||||
all_content += "\n"
|
||||
all_content += repo_content
|
||||
|
||||
if self.abs_fnames:
|
||||
files_content = self.gpt_prompts.files_content_prefix
|
||||
files_content += self.get_files_content()
|
||||
|
@ -313,12 +329,6 @@ class Coder:
|
|||
|
||||
all_content += files_content
|
||||
|
||||
repo_content = self.get_repo_map()
|
||||
if repo_content:
|
||||
if all_content:
|
||||
all_content += "\n"
|
||||
all_content += repo_content
|
||||
|
||||
files_messages = [
|
||||
dict(role="user", content=all_content),
|
||||
dict(role="assistant", content="Ok."),
|
||||
|
@ -500,7 +510,7 @@ class Coder:
|
|||
interrupted = self.send(messages, functions=self.functions)
|
||||
except ExhaustedContextWindow:
|
||||
exhausted = True
|
||||
except openai.error.InvalidRequestError as err:
|
||||
except openai.BadRequestError as err:
|
||||
if "maximum context length" in str(err):
|
||||
exhausted = True
|
||||
else:
|
||||
|
@ -617,7 +627,9 @@ class Coder:
|
|||
|
||||
interrupted = False
|
||||
try:
|
||||
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
|
||||
hash_object, completion = send_with_retries(
|
||||
self.client, model, messages, functions, self.stream
|
||||
)
|
||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||
|
||||
if self.stream:
|
||||
|
@ -971,9 +983,16 @@ class Coder:
|
|||
return True
|
||||
|
||||
|
||||
def check_model_availability(io, main_model):
|
||||
available_models = openai.Model.list()
|
||||
model_ids = sorted(model.id for model in available_models["data"])
|
||||
def check_model_availability(io, client, main_model):
|
||||
try:
|
||||
available_models = client.models.list()
|
||||
except openai.NotFoundError:
|
||||
# Azure sometimes returns 404?
|
||||
# https://discord.com/channels/1131200896827654144/1182327371232186459
|
||||
io.tool_error("Unable to list available models, proceeding with {main_model.name}")
|
||||
return True
|
||||
|
||||
model_ids = sorted(model.id for model in available_models)
|
||||
if main_model.name in model_ids:
|
||||
return True
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue