roughed in openai 1.x

Paul Gauthier 2023-12-05 07:37:05 -08:00
parent fd34766aa9
commit 6ebc142377
15 changed files with 136 additions and 110 deletions

View file

@@ -53,6 +53,7 @@ class Coder:
@classmethod
def create(
self,
client,
main_model=None,
edit_format=None,
io=None,
@@ -65,7 +66,7 @@ class Coder:
main_model = models.GPT4
if not skip_model_availabily_check and not main_model.always_available:
if not check_model_availability(io, main_model):
if not check_model_availability(io, client, main_model):
if main_model != models.GPT4:
io.tool_error(
f"API key does not support {main_model.name}, falling back to"
@@ -77,14 +78,15 @@ class Coder:
edit_format = main_model.edit_format
if edit_format == "diff":
return EditBlockCoder(main_model, io, **kwargs)
return EditBlockCoder(client, main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs)
return WholeFileCoder(client, main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
def __init__(
self,
client,
main_model,
io,
fnames=None,
@@ -103,6 +105,8 @@ class Coder:
voice_language=None,
aider_ignore_file=None,
):
self.client = client
if not fnames:
fnames = []
@@ -190,6 +194,7 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary(
self.client,
models.Model.weak_model(),
self.main_model.max_chat_history_tokens,
)
@@ -470,7 +475,7 @@ class Coder:
interrupted = self.send(messages, functions=self.functions)
except ExhaustedContextWindow:
exhausted = True
except openai.error.InvalidRequestError as err:
except openai.BadRequestError as err:
if "maximum context length" in str(err):
exhausted = True
else:
@@ -587,7 +592,9 @@ class Coder:
interrupted = False
try:
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
hash_object, completion = send_with_retries(
self.client, model, messages, functions, self.stream
)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
@@ -941,8 +948,8 @@ class Coder:
return True
def check_model_availability(io, main_model):
available_models = openai.Model.list()
def check_model_availability(io, client, main_model):
available_models = client.models.list()
model_ids = sorted(model.id for model in available_models["data"])
if main_model.name in model_ids:
return True
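
The pattern that recurs through the rest of this commit starts here: openai 0.x configured a module-level singleton (openai.api_key, openai.Model.list()), while 1.x wraps that state in an explicit client object that has to be threaded through Coder, ChatSummary, GitRepo and Voice. A minimal before/after sketch, assuming OPENAI_API_KEY is set in the environment:

import os

import openai

# 0.x style: configure the shared module, then call module-level APIs
#   openai.api_key = os.environ["OPENAI_API_KEY"]
#   available = openai.Model.list()["data"]

# 1.x style: build an explicit client and pass it to whoever needs it
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
model_ids = sorted(model.id for model in client.models.list())
print("gpt-4" in model_ids)

Note that client.models.list() returns a typed page object in 1.x, directly iterable as above, so the available_models["data"] subscription kept in this hunk is a 0.x idiom that presumably still needs a follow-up.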

View file

@@ -462,7 +462,7 @@ class Commands:
if not self.voice:
try:
self.voice = voice.Voice()
self.voice = voice.Voice(self.coder.client)
except voice.SoundDeviceError:
self.io.tool_error(
"Unable to import `sounddevice` and/or `soundfile`, is portaudio installed?"

View file

@@ -7,7 +7,8 @@ from aider.sendchat import simple_send_with_retries
class ChatSummary:
def __init__(self, model=models.Model.weak_model(), max_tokens=1024):
def __init__(self, client, model=models.Model.weak_model(), max_tokens=1024):
self.client = client
self.tokenizer = model.tokenizer
self.max_tokens = max_tokens
self.model = model
@@ -84,7 +85,7 @@ class ChatSummary:
dict(role="user", content=content),
]
summary = simple_send_with_retries(self.model.name, messages)
summary = simple_send_with_retries(self.client, self.model.name, messages)
if summary is None:
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
summary = prompts.summary_prefix + summary

View file

@@ -176,27 +176,23 @@ def main(argv=None, input=None, output=None, force_git_root=None):
model_group.add_argument(
"--openai-api-base",
metavar="OPENAI_API_BASE",
help="Specify the openai.api_base (default: https://api.openai.com/v1)",
help="Specify the api_base (default: https://api.openai.com/v1)",
)
model_group.add_argument(
"--openai-api-type",
metavar="OPENAI_API_TYPE",
help="Specify the openai.api_type",
help="Specify the api_type",
)
model_group.add_argument(
"--openai-api-version",
metavar="OPENAI_API_VERSION",
help="Specify the openai.api_version",
help="Specify the api_version",
)
# TODO: use deployment_id
model_group.add_argument(
"--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID",
help="Specify the deployment_id arg to be passed to openai.ChatCompletion.create()",
)
model_group.add_argument(
"--openai-api-engine",
metavar="OPENAI_API_ENGINE",
help="Specify the engine arg to be passed to openai.ChatCompletion.create()",
help="Specify the deployment_id",
)
model_group.add_argument(
"--edit-format",
@@ -492,19 +488,28 @@ def main(argv=None, input=None, output=None, force_git_root=None):
)
return 1
openai.api_key = args.openai_api_key
for attr in ("base", "type", "version", "deployment_id", "engine"):
arg_key = f"openai_api_{attr}"
val = getattr(args, arg_key)
if val is not None:
mod_key = f"api_{attr}"
setattr(openai, mod_key, val)
io.tool_output(f"Setting openai.{mod_key}={val}")
if args.openai_api_type == "azure":
client = openai.AzureOpenAI(
api_key=args.openai_api_key,
azure_endpoint=args.openai_api_base,
api_version=args.openai_api_version,
)
else:
kwargs = dict()
if args.openai_api_base and "openrouter.ai" in args.openai_api_base:
kwargs["default_headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
main_model = models.Model.create(args.model)
client = openai.OpenAI(
api_key=args.openai_api_key,
base_url=args.openai_api_base,
**kwargs,
)
main_model = models.Model.create(args.model, client)
try:
coder = Coder.create(
client,
main_model,
args.edit_format,
io,
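
This branch replaces the removed setattr loop above: Azure gets its own client class instead of openai.api_type = "azure", and the OpenRouter attribution headers move from per-request kwargs into default_headers at construction time. A condensed sketch of the same logic (argument names as in the hunk; base_url=None falls back to the standard endpoint):

import openai

def build_client(args):
    # Azure: a dedicated client class replaces the 0.x module attributes
    if args.openai_api_type == "azure":
        return openai.AzureOpenAI(
            api_key=args.openai_api_key,
            azure_endpoint=args.openai_api_base,
            api_version=args.openai_api_version,
        )

    kwargs = {}
    # OpenRouter attribution, formerly injected into every request
    if args.openai_api_base and "openrouter.ai" in args.openai_api_base:
        kwargs["default_headers"] = {
            "HTTP-Referer": "http://aider.chat",
            "X-Title": "Aider",
        }
    return openai.OpenAI(
        api_key=args.openai_api_key,
        base_url=args.openai_api_base,
        **kwargs,
    )

The TODO above presumably points at the azure_deployment parameter that AzureOpenAI accepts in 1.x; the 0.x engine kwarg has no 1.x counterpart, which appears to be why that argument is dropped entirely.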

View file

@@ -1,7 +1,5 @@
import json
import openai
class Model:
name = None
@@ -18,12 +16,12 @@ class Model:
completion_price = None
@classmethod
def create(cls, name):
def create(cls, name, client=None):
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
if "openrouter.ai" in openai.api_base:
return OpenRouterModel(name)
if client and client.base_url.host == "openrouter.ai":
return OpenRouterModel(client, name)
return OpenAIModel(name)
def __str__(self):
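
In 1.x, client.base_url is an httpx.URL rather than a plain string, which is why the OpenRouter check above can compare a parsed host instead of substring-matching openai.api_base. A small sketch:

import openai

client = openai.OpenAI(
    api_key="sk-placeholder",  # placeholder, not a real key
    base_url="https://openrouter.ai/api/v1",
)
print(client.base_url.host)  # openrouter.ai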

View file

@@ -1,4 +1,3 @@
import openai
import tiktoken
from .model import Model
@@ -7,7 +6,7 @@ cached_model_details = None
class OpenRouterModel(Model):
def __init__(self, name):
def __init__(self, client, name):
if name == "gpt-4":
name = "openai/gpt-4"
elif name == "gpt-3.5-turbo":
@@ -24,7 +23,7 @@ class OpenRouterModel(Model):
global cached_model_details
if cached_model_details is None:
cached_model_details = openai.Model.list().data
cached_model_details = client.models.list().data
found = next(
(details for details in cached_model_details if details.get("id") == name), None
)
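
One caveat with the new lookup: in 1.x, client.models.list().data is a list of typed pydantic objects rather than dicts, so the details.get("id") kept in the surviving context line would presumably need attribute access in a follow-up. Something like:

def find_model_details(client, name):
    # 1.x returns typed objects; use .id rather than dict-style .get("id")
    for details in client.models.list().data:
        if details.id == name:
            return details
    return None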

View file

@@ -10,13 +10,18 @@ from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401
class OpenAIClientNotProvided(Exception):
pass
class GitRepo:
repo = None
aider_ignore_file = None
aider_ignore_spec = None
aider_ignore_ts = 0
def __init__(self, io, fnames, git_dname, aider_ignore_file=None):
def __init__(self, io, fnames, git_dname, aider_ignore_file=None, client=None):
self.client = client
self.io = io
if git_dname:
@@ -101,6 +106,9 @@ class GitRepo:
return self.repo.git_dir
def get_commit_message(self, diffs, context):
if not self.client:
raise OpenAIClientNotProvided
if len(diffs) >= 4 * 1024 * 4:
self.io.tool_error(
f"Diff is too large for {models.GPT35.name} to generate a commit message."
@@ -120,7 +128,7 @@ class GitRepo:
]
for model in models.Model.commit_message_models():
commit_message = simple_send_with_retries(model.name, messages)
commit_message = simple_send_with_retries(self.client, model.name, messages)
if commit_message:
break
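
The new OpenAIClientNotProvided exception makes the dependency explicit: a GitRepo constructed without a client can still do everything git-related, and only commit-message generation fails fast. The shape of the guard, reduced to a self-contained sketch:

class OpenAIClientNotProvided(Exception):
    pass

class Repo:
    def __init__(self, client=None):
        self.client = client

    def get_commit_message(self):
        # fail fast instead of crashing deep inside the API call
        if not self.client:
            raise OpenAIClientNotProvided()
        return "commit message from the model"

try:
    Repo().get_commit_message()
except OpenAIClientNotProvided:
    print("no API client: auto commit messages disabled")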

View file

@@ -6,11 +6,11 @@ import openai
import requests
# from diskcache import Cache
from openai.error import (
from openai import (
APIConnectionError,
APIError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
Timeout,
)
@@ -24,7 +24,7 @@ CACHE = None
(
Timeout,
APIError,
ServiceUnavailableError,
InternalServerError,
RateLimitError,
APIConnectionError,
requests.exceptions.ConnectionError,
@@ -34,7 +34,7 @@ CACHE = None
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(model_name, messages, functions, stream):
def send_with_retries(client, model_name, messages, functions, stream):
kwargs = dict(
model=model_name,
messages=messages,
@@ -44,15 +44,6 @@ def send_with_retries(model_name, messages, functions, stream):
if functions is not None:
kwargs["functions"] = functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
if "openrouter.ai" in openai.api_base:
kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
key = json.dumps(kwargs, sort_keys=True).encode()
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
@@ -61,7 +52,7 @@ def send_with_retries(model_name, messages, functions, stream):
if not stream and CACHE is not None and key in CACHE:
return hash_object, CACHE[key]
res = openai.ChatCompletion.create(**kwargs)
res = client.chat.completions.create(**kwargs)
if not stream and CACHE is not None:
CACHE[key] = res
@@ -69,14 +60,15 @@ def send_with_retries(model_name, messages, functions, stream):
return hash_object, res
def simple_send_with_retries(model_name, messages):
def simple_send_with_retries(client, model_name, messages):
try:
_hash, response = send_with_retries(
client=client,
model_name=model_name,
messages=messages,
functions=None,
stream=False,
)
return response.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError):
except (AttributeError, openai.BadRequestError):
return
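
Taken together, the changes in this file track two 1.x renames: the exceptions move from openai.error.* to the top-level namespace (InvalidRequestError becomes BadRequestError, ServiceUnavailableError folds into InternalServerError), and openai.ChatCompletion.create becomes a method on the client. An end-to-end sketch of the new call path (model name and prompt are placeholders):

import os

import openai

client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

try:
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say hi"}],
        stream=False,
    )
    print(completion.choices[0].message.content)
except openai.BadRequestError as err:
    # 0.x: openai.error.InvalidRequestError
    print("bad request:", err)
except (openai.RateLimitError, openai.APIConnectionError) as err:
    # these stay retryable in the backoff decorator above
    print("transient:", err)

One loose end: 1.x names its timeout error APITimeoutError, so the bare Timeout kept in the import list above looks like a 0.x leftover that presumably needs a follow-up.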

View file

@@ -4,7 +4,6 @@ import tempfile
import time
import numpy as np
import openai
try:
import soundfile as sf
@@ -27,7 +26,7 @@ class Voice:
threshold = 0.15
def __init__(self):
def __init__(self, client):
if sf is None:
raise SoundDeviceError
try:
@@ -38,6 +37,8 @@ class Voice:
except (OSError, ModuleNotFoundError):
raise SoundDeviceError
self.client = client
def callback(self, indata, frames, time, status):
"""This is called (from a separate thread) for each audio block."""
rms = np.sqrt(np.mean(indata**2))
@@ -88,9 +89,11 @@ class Voice:
file.write(self.q.get())
with open(filename, "rb") as fh:
transcript = openai.Audio.transcribe("whisper-1", fh, prompt=history, language=language)
transcript = self.client.audio.transcriptions.create(
model="whisper-1", file=fh, prompt=history, language=language
)
text = transcript["text"]
text = transcript.text
return text
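
The transcription change doubles as an example of 1.x response handling in general: calls return typed objects, so dict indexing (transcript["text"]) gives way to attribute access. A standalone sketch, with the audio file name assumed:

import openai

client = openai.OpenAI()  # picks up OPENAI_API_KEY from the environment

with open("recording.wav", "rb") as fh:
    transcript = client.audio.transcriptions.create(
        model="whisper-1",
        file=fh,
    )

print(transcript.text)  # 0.x: openai.Audio.transcribe("whisper-1", fh)["text"]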

View file

@@ -18,7 +18,6 @@ import git
import lox
import matplotlib.pyplot as plt
import numpy as np
import openai
import pandas as pd
import prompts
import typer
@@ -631,8 +630,6 @@ def run_test(
show_fnames = ",".join(map(str, fnames))
print("fnames:", show_fnames)
openai.api_key = os.environ["OPENAI_API_KEY"]
coder = Coder.create(
main_model,
edit_format,

View file

@@ -1,3 +1,6 @@
#
# pip-compile requirements.in
#
configargparse
GitPython
openai

View file

@@ -4,66 +4,67 @@
#
# pip-compile requirements.in
#
aiohttp==3.8.6
# via openai
aiosignal==1.3.1
# via aiohttp
async-timeout==4.0.3
# via aiohttp
annotated-types==0.6.0
# via pydantic
anyio==3.7.1
# via
# httpx
# openai
attrs==23.1.0
# via
# aiohttp
# jsonschema
# referencing
backoff==2.2.1
# via -r requirements.in
certifi==2023.7.22
# via requests
certifi==2023.11.17
# via
# httpcore
# httpx
# requests
cffi==1.16.0
# via
# sounddevice
# soundfile
charset-normalizer==3.3.2
# via
# aiohttp
# requests
# via requests
configargparse==1.7
# via -r requirements.in
diskcache==5.6.3
# via -r requirements.in
frozenlist==1.4.0
# via
# aiohttp
# aiosignal
distro==1.8.0
# via openai
gitdb==4.0.11
# via gitpython
gitpython==3.1.40
# via -r requirements.in
grep-ast==0.2.4
# via -r requirements.in
idna==3.4
h11==0.14.0
# via httpcore
httpcore==1.0.2
# via httpx
httpx==0.25.2
# via openai
idna==3.6
# via
# anyio
# httpx
# requests
# yarl
jsonschema==4.19.2
jsonschema==4.20.0
# via -r requirements.in
jsonschema-specifications==2023.7.1
jsonschema-specifications==2023.11.2
# via jsonschema
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
multidict==6.0.4
# via
# aiohttp
# yarl
networkx==3.2.1
# via -r requirements.in
numpy==1.26.1
numpy==1.26.2
# via
# -r requirements.in
# scipy
openai==0.28.1
openai==1.3.7
# via -r requirements.in
packaging==23.2
# via -r requirements.in
@@ -71,49 +72,59 @@ pathspec==0.11.2
# via
# -r requirements.in
# grep-ast
prompt-toolkit==3.0.39
prompt-toolkit==3.0.41
# via -r requirements.in
pycparser==2.21
# via cffi
pygments==2.16.1
pydantic==2.5.2
# via openai
pydantic-core==2.14.5
# via pydantic
pygments==2.17.2
# via rich
pyyaml==6.0.1
# via -r requirements.in
referencing==0.30.2
referencing==0.31.1
# via
# jsonschema
# jsonschema-specifications
regex==2023.10.3
# via tiktoken
requests==2.31.0
# via
# openai
# tiktoken
rich==13.6.0
# via tiktoken
rich==13.7.0
# via -r requirements.in
rpds-py==0.10.6
rpds-py==0.13.2
# via
# jsonschema
# referencing
scipy==1.11.3
scipy==1.11.4
# via -r requirements.in
smmap==5.0.1
# via gitdb
sniffio==1.3.0
# via
# anyio
# httpx
# openai
sounddevice==0.4.6
# via -r requirements.in
soundfile==0.12.1
# via -r requirements.in
tiktoken==0.5.1
tiktoken==0.5.2
# via -r requirements.in
tqdm==4.66.1
# via openai
tree-sitter==0.20.2
tree-sitter==0.20.4
# via tree-sitter-languages
tree-sitter-languages==1.8.0
# via grep-ast
urllib3==2.0.7
typing-extensions==4.8.0
# via
# openai
# pydantic
# pydantic-core
urllib3==2.1.0
# via requests
wcwidth==0.2.9
wcwidth==0.2.12
# via prompt-toolkit
yarl==1.9.2
# via aiohttp

View file

@@ -341,12 +341,12 @@ class TestCoder(unittest.TestCase):
coder = Coder.create(models.GPT4, None, mock_io)
# Set up the mock to raise InvalidRequestError
mock_chat_completion_create.side_effect = openai.error.InvalidRequestError(
mock_chat_completion_create.side_effect = openai.BadRequestError(
"Invalid request", "param"
)
# Call the run method and assert that InvalidRequestError is raised
with self.assertRaises(openai.error.InvalidRequestError):
with self.assertRaises(openai.BadRequestError):
coder.run(with_message="hi")
def test_new_file_edit_one_commit(self):

View file

@@ -24,12 +24,13 @@ class TestModels(unittest.TestCase):
model = Model.create("gpt-4-32k-2123")
self.assertEqual(model.max_context_tokens, 32 * 1024)
@patch("openai.Model.list")
@patch("openai.resources.Models.list")
def test_openrouter_model_properties(self, mock_model_list):
import openai
# import openai
old_base = openai.api_base
openai.api_base = "https://openrouter.ai/api/v1"
# old_base = openai.api_base
# TODO: fixme
# openai.api_base = "https://openrouter.ai/api/v1"
mock_model_list.return_value = {
"data": [
{
@@ -49,7 +50,8 @@ class TestModels(unittest.TestCase):
self.assertEqual(model.max_context_tokens, 8192)
self.assertEqual(model.prompt_price, 0.06)
self.assertEqual(model.completion_price, 0.12)
openai.api_base = old_base
# TODO: fixme
# openai.api_base = old_base
if __name__ == "__main__":
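
Since the code under test now receives the client as a parameter, one way to resolve the TODOs above without touching module globals is plain dependency injection with a mock instead of patching openai.resources.Models.list; a sketch, with the payload reduced to the id field:

from unittest.mock import MagicMock

client = MagicMock()
client.base_url.host = "openrouter.ai"  # satisfies Model.create's host check
client.models.list.return_value.data = [
    {"id": "openai/gpt-4"},  # plus whatever fields the assertions read
]

Model.create("gpt-4", client) would then route to OpenRouterModel without any module-level api_base swapping.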

View file

@@ -14,7 +14,7 @@ class TestSendChat(unittest.TestCase):
# Set up the mock to raise RateLimitError on
# the first call and return None on the second call
mock_chat_completion_create.side_effect = [
openai.error.RateLimitError("Rate limit exceeded"),
openai.RateLimitError("Rate limit exceeded"),
None,
]
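
A wrinkle both test files will hit: the 1.x exception classes are APIStatusError subclasses whose constructors expect keyword response and body arguments, so two-positional-arg constructions like BadRequestError("Invalid request", "param") above would presumably fail at runtime. A sketch of building one for a mock side_effect (signature as of openai 1.3.x, stated as an assumption):

import httpx
import openai

request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")

rate_limit_err = openai.RateLimitError(
    "Rate limit exceeded",
    response=httpx.Response(429, request=request),
    body=None,
)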