roughed in openai 1.x

Paul Gauthier 2023-12-05 07:37:05 -08:00
parent fd34766aa9
commit 6ebc142377
15 changed files with 136 additions and 110 deletions

View file

@@ -53,6 +53,7 @@ class Coder:
     @classmethod
     def create(
         self,
+        client,
         main_model=None,
         edit_format=None,
         io=None,
@@ -65,7 +66,7 @@ class Coder:
             main_model = models.GPT4

         if not skip_model_availabily_check and not main_model.always_available:
-            if not check_model_availability(io, main_model):
+            if not check_model_availability(io, client, main_model):
                 if main_model != models.GPT4:
                     io.tool_error(
                         f"API key does not support {main_model.name}, falling back to"
@@ -77,14 +78,15 @@ class Coder:
             edit_format = main_model.edit_format

         if edit_format == "diff":
-            return EditBlockCoder(main_model, io, **kwargs)
+            return EditBlockCoder(client, main_model, io, **kwargs)
         elif edit_format == "whole":
-            return WholeFileCoder(main_model, io, **kwargs)
+            return WholeFileCoder(client, main_model, io, **kwargs)
         else:
             raise ValueError(f"Unknown edit format {edit_format}")

     def __init__(
         self,
+        client,
         main_model,
         io,
         fnames=None,
@@ -103,6 +105,8 @@ class Coder:
         voice_language=None,
         aider_ignore_file=None,
     ):
+        self.client = client
+
         if not fnames:
             fnames = []
@@ -190,6 +194,7 @@ class Coder:
             self.io.tool_output(f"Added {fname} to the chat.")

         self.summarizer = ChatSummary(
+            self.client,
             models.Model.weak_model(),
             self.main_model.max_chat_history_tokens,
         )
@@ -470,7 +475,7 @@ class Coder:
             interrupted = self.send(messages, functions=self.functions)
         except ExhaustedContextWindow:
             exhausted = True
-        except openai.error.InvalidRequestError as err:
+        except openai.BadRequestError as err:
             if "maximum context length" in str(err):
                 exhausted = True
             else:
@@ -587,7 +592,9 @@ class Coder:
         interrupted = False
         try:
-            hash_object, completion = send_with_retries(model, messages, functions, self.stream)
+            hash_object, completion = send_with_retries(
+                self.client, model, messages, functions, self.stream
+            )
             self.chat_completion_call_hashes.append(hash_object.hexdigest())

             if self.stream:
@@ -941,8 +948,8 @@ class Coder:
         return True


-def check_model_availability(io, main_model):
-    available_models = openai.Model.list()
+def check_model_availability(io, client, main_model):
+    available_models = client.models.list()
     model_ids = sorted(model.id for model in available_models["data"])
     if main_model.name in model_ids:
         return True
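
Note on the last hunk: in openai 1.x, client.models.list() returns a typed page object rather than a plain dict, so the surviving available_models["data"] subscript appears to be a leftover 0.x idiom that will fail at runtime; attribute access is the 1.x style. A minimal sketch of the availability check under that assumption, with OPENAI_API_KEY set in the environment:

import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

def model_is_available(client, model_name):
    # In 1.x, models.list() returns a SyncPage whose .data holds Model objects
    model_ids = sorted(model.id for model in client.models.list().data)
    return model_name in model_ids

print(model_is_available(client, "gpt-4"))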

View file

@@ -462,7 +462,7 @@ class Commands:
         if not self.voice:
             try:
-                self.voice = voice.Voice()
+                self.voice = voice.Voice(self.coder.client)
             except voice.SoundDeviceError:
                 self.io.tool_error(
                     "Unable to import `sounddevice` and/or `soundfile`, is portaudio installed?"

View file

@@ -7,7 +7,8 @@ from aider.sendchat import simple_send_with_retries


 class ChatSummary:
-    def __init__(self, model=models.Model.weak_model(), max_tokens=1024):
+    def __init__(self, client, model=models.Model.weak_model(), max_tokens=1024):
+        self.client = client
         self.tokenizer = model.tokenizer
         self.max_tokens = max_tokens
         self.model = model
@@ -84,7 +85,7 @@ class ChatSummary:
             dict(role="user", content=content),
         ]

-        summary = simple_send_with_retries(self.model.name, messages)
+        summary = simple_send_with_retries(self.client, self.model.name, messages)
         if summary is None:
             raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
         summary = prompts.summary_prefix + summary

View file

@@ -176,27 +176,23 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     model_group.add_argument(
         "--openai-api-base",
         metavar="OPENAI_API_BASE",
-        help="Specify the openai.api_base (default: https://api.openai.com/v1)",
+        help="Specify the api_base (default: https://api.openai.com/v1)",
     )
     model_group.add_argument(
         "--openai-api-type",
         metavar="OPENAI_API_TYPE",
-        help="Specify the openai.api_type",
+        help="Specify the api_type",
     )
     model_group.add_argument(
         "--openai-api-version",
         metavar="OPENAI_API_VERSION",
-        help="Specify the openai.api_version",
+        help="Specify the api_version",
     )
+    # TODO: use deployment_id
     model_group.add_argument(
         "--openai-api-deployment-id",
         metavar="OPENAI_API_DEPLOYMENT_ID",
-        help="Specify the deployment_id arg to be passed to openai.ChatCompletion.create()",
-    )
-    model_group.add_argument(
-        "--openai-api-engine",
-        metavar="OPENAI_API_ENGINE",
-        help="Specify the engine arg to be passed to openai.ChatCompletion.create()",
+        help="Specify the deployment_id",
     )
     model_group.add_argument(
         "--edit-format",
@@ -492,19 +488,28 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         )
         return 1

-    openai.api_key = args.openai_api_key
-    for attr in ("base", "type", "version", "deployment_id", "engine"):
-        arg_key = f"openai_api_{attr}"
-        val = getattr(args, arg_key)
-        if val is not None:
-            mod_key = f"api_{attr}"
-            setattr(openai, mod_key, val)
-            io.tool_output(f"Setting openai.{mod_key}={val}")
+    if args.openai_api_type == "azure":
+        client = openai.AzureOpenAI(
+            api_key=args.openai_api_key,
+            azure_endpoint=args.openai_api_base,
+            api_version=args.openai_api_version,
+        )
+    else:
+        kwargs = dict()
+        if args.openai_api_base and "openrouter.ai" in args.openai_api_base:
+            kwargs["default_headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
+        client = openai.OpenAI(
+            api_key=args.openai_api_key,
+            base_url=args.openai_api_base,
+            **kwargs,
+        )

-    main_model = models.Model.create(args.model)
+    main_model = models.Model.create(args.model, client)

     try:
         coder = Coder.create(
+            client,
             main_model,
             args.edit_format,
             io,
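
The second hunk above is the crux of the migration: 0.x configured the SDK by mutating module-level globals (openai.api_key, openai.api_base, openai.api_type), while 1.x moves all of that onto an explicit client object, with Azure getting its own client class. A minimal sketch of the two styles, using placeholder credentials:

import openai

# 0.x style (deleted above): configuration via module globals
#   openai.api_key = "sk-..."
#   openai.api_base = "https://openrouter.ai/api/v1"

# 1.x style: configuration travels with a client instance
client = openai.OpenAI(
    api_key="sk-...",                          # placeholder
    base_url="https://openrouter.ai/api/v1",   # optional override
    default_headers={"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"},
)

# Azure gets a dedicated class instead of api_type="azure"
azure_client = openai.AzureOpenAI(
    api_key="...",                             # placeholder
    azure_endpoint="https://example.openai.azure.com",
    api_version="2023-05-15",                  # example version string
)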

View file

@@ -1,7 +1,5 @@
 import json

-import openai
-

 class Model:
     name = None
@@ -18,12 +16,12 @@ class Model:
     completion_price = None

     @classmethod
-    def create(cls, name):
+    def create(cls, name, client=None):
         from .openai import OpenAIModel
         from .openrouter import OpenRouterModel

-        if "openrouter.ai" in openai.api_base:
-            return OpenRouterModel(name)
+        if client and client.base_url.host == "openrouter.ai":
+            return OpenRouterModel(client, name)

         return OpenAIModel(name)

     def __str__(self):
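
A related 1.x detail: client.base_url is an httpx.URL, which is why the OpenRouter check above can compare .host directly instead of substring-matching the old openai.api_base string. A quick illustration with a placeholder key:

import openai

client = openai.OpenAI(api_key="sk-...", base_url="https://openrouter.ai/api/v1")
print(client.base_url.host)  # -> "openrouter.ai"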

View file

@@ -1,4 +1,3 @@
-import openai
 import tiktoken

 from .model import Model
@@ -7,7 +6,7 @@ cached_model_details = None


 class OpenRouterModel(Model):
-    def __init__(self, name):
+    def __init__(self, client, name):
         if name == "gpt-4":
             name = "openai/gpt-4"
         elif name == "gpt-3.5-turbo":
@@ -24,7 +23,7 @@ class OpenRouterModel(Model):
         global cached_model_details
         if cached_model_details is None:
-            cached_model_details = openai.Model.list().data
+            cached_model_details = client.models.list().data
         found = next(
             (details for details in cached_model_details if details.get("id") == name), None
         )

View file

@@ -10,13 +10,18 @@ from aider.sendchat import simple_send_with_retries

 from .dump import dump  # noqa: F401


+class OpenAIClientNotProvided(Exception):
+    pass
+
+
 class GitRepo:
     repo = None
     aider_ignore_file = None
     aider_ignore_spec = None
     aider_ignore_ts = 0

-    def __init__(self, io, fnames, git_dname, aider_ignore_file=None):
+    def __init__(self, io, fnames, git_dname, aider_ignore_file=None, client=None):
+        self.client = client
         self.io = io

         if git_dname:
@@ -101,6 +106,9 @@ class GitRepo:
         return self.repo.git_dir

     def get_commit_message(self, diffs, context):
+        if not self.client:
+            raise OpenAIClientNotProvided
+
         if len(diffs) >= 4 * 1024 * 4:
             self.io.tool_error(
                 f"Diff is too large for {models.GPT35.name} to generate a commit message."
@@ -120,7 +128,7 @@ class GitRepo:
         ]

         for model in models.Model.commit_message_models():
-            commit_message = simple_send_with_retries(self.client, model.name, messages)
+            commit_message = simple_send_with_retries(self.client, model.name, messages)
             if commit_message:
                 break
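
Usage sketch for the new GitRepo signature: the client is optional at construction time, and commit-message generation now fails fast without one. The InputOutput import path and the "." working directory are assumptions for illustration, not part of this commit:

import openai
from aider.io import InputOutput
from aider.repo import GitRepo, OpenAIClientNotProvided

io = InputOutput()
# assumes the current directory is inside a git checkout
repo = GitRepo(io, fnames=[], git_dname=".", client=openai.OpenAI())

try:
    message = repo.get_commit_message(diffs="...", context="")
except OpenAIClientNotProvided:
    message = None  # only reachable when client= was omitted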

View file

@@ -6,11 +6,11 @@ import openai
 import requests

 # from diskcache import Cache
-from openai.error import (
+from openai import (
     APIConnectionError,
     APIError,
+    InternalServerError,
     RateLimitError,
-    ServiceUnavailableError,
     Timeout,
 )
@@ -24,7 +24,7 @@ CACHE = None
     (
         Timeout,
         APIError,
-        ServiceUnavailableError,
+        InternalServerError,
         RateLimitError,
         APIConnectionError,
         requests.exceptions.ConnectionError,
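
For reference, 1.x removes the openai.error module and re-homes exceptions at the top level, with some renames; the import above still keeps Timeout, which 1.x appears to rename to APITimeoutError, so this rough pass likely needs a follow-up. An approximate mapping sketch:

import openai

# 0.x                                  -> 1.x
# openai.error.InvalidRequestError     -> openai.BadRequestError
# openai.error.ServiceUnavailableError -> openai.InternalServerError
# openai.error.Timeout                 -> openai.APITimeoutError
# openai.error.RateLimitError          -> openai.RateLimitError
# openai.error.APIConnectionError      -> openai.APIConnectionError

# All of these exist as top-level names in openai 1.x:
RETRYABLE = (
    openai.APITimeoutError,
    openai.APIError,
    openai.InternalServerError,
    openai.RateLimitError,
    openai.APIConnectionError,
)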
@@ -34,7 +34,7 @@ CACHE = None
         f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
     ),
 )
-def send_with_retries(model_name, messages, functions, stream):
+def send_with_retries(client, model_name, messages, functions, stream):
     kwargs = dict(
         model=model_name,
         messages=messages,
@@ -44,15 +44,6 @@ def send_with_retries(model_name, messages, functions, stream):
     if functions is not None:
         kwargs["functions"] = functions

-    # we are abusing the openai object to stash these values
-    if hasattr(openai, "api_deployment_id"):
-        kwargs["deployment_id"] = openai.api_deployment_id
-    if hasattr(openai, "api_engine"):
-        kwargs["engine"] = openai.api_engine
-
-    if "openrouter.ai" in openai.api_base:
-        kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
-
     key = json.dumps(kwargs, sort_keys=True).encode()

     # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
@@ -61,7 +52,7 @@ def send_with_retries(model_name, messages, functions, stream):
     if not stream and CACHE is not None and key in CACHE:
         return hash_object, CACHE[key]

-    res = openai.ChatCompletion.create(**kwargs)
+    res = client.chat.completions.create(**kwargs)

     if not stream and CACHE is not None:
         CACHE[key] = res
@@ -69,14 +60,15 @@ def send_with_retries(model_name, messages, functions, stream):
     return hash_object, res


-def simple_send_with_retries(model_name, messages):
+def simple_send_with_retries(client, model_name, messages):
     try:
         _hash, response = send_with_retries(
+            client=client,
             model_name=model_name,
             messages=messages,
             functions=None,
             stream=False,
         )
         return response.choices[0].message.content
-    except (AttributeError, openai.error.InvalidRequestError):
+    except (AttributeError, openai.BadRequestError):
         return
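
The core rename above, openai.ChatCompletion.create(**kwargs) becoming client.chat.completions.create(**kwargs), also changes the response type: 1.x returns pydantic models, so fields are attributes rather than dict keys. A minimal sketch, assuming OPENAI_API_KEY is set:

import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=False,
)

# attribute access, not response["choices"][0]["message"]["content"]
print(response.choices[0].message.content)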

View file

@@ -4,7 +4,6 @@ import tempfile
 import time

 import numpy as np
-import openai

 try:
     import soundfile as sf
@@ -27,7 +26,7 @@ class Voice:

     threshold = 0.15

-    def __init__(self):
+    def __init__(self, client):
         if sf is None:
             raise SoundDeviceError
         try:
@@ -38,6 +37,8 @@ class Voice:
         except (OSError, ModuleNotFoundError):
             raise SoundDeviceError

+        self.client = client
+
     def callback(self, indata, frames, time, status):
         """This is called (from a separate thread) for each audio block."""
         rms = np.sqrt(np.mean(indata**2))
@@ -88,9 +89,11 @@ class Voice:
                 file.write(self.q.get())

         with open(filename, "rb") as fh:
-            transcript = openai.Audio.transcribe("whisper-1", fh, prompt=history, language=language)
+            transcript = self.client.audio.transcriptions.create(
+                model="whisper-1", file=fh, prompt=history, language=language
+            )

-        text = transcript["text"]
+        text = transcript.text
         return text
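
The voice change follows the same pattern: openai.Audio.transcribe() becomes client.audio.transcriptions.create(), and the result exposes .text instead of ["text"]. A standalone sketch, assuming OPENAI_API_KEY is set and a local speech.mp3 exists:

import openai

client = openai.OpenAI()

with open("speech.mp3", "rb") as fh:
    transcript = client.audio.transcriptions.create(model="whisper-1", file=fh)

print(transcript.text)  # Transcription object, not a dict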

View file

@@ -18,7 +18,6 @@ import git
 import lox
 import matplotlib.pyplot as plt
 import numpy as np
-import openai
 import pandas as pd
 import prompts
 import typer
@@ -631,8 +630,6 @@ def run_test(
     show_fnames = ",".join(map(str, fnames))
     print("fnames:", show_fnames)

-    openai.api_key = os.environ["OPENAI_API_KEY"]
-
     coder = Coder.create(
         main_model,
         edit_format,

View file

@@ -1,3 +1,6 @@
+#
+# pip-compile requirements.in
+#
 configargparse
 GitPython
 openai

View file

@@ -4,66 +4,67 @@
 #
 # pip-compile requirements.in
 #
-aiohttp==3.8.6
-    # via openai
-aiosignal==1.3.1
-    # via aiohttp
-async-timeout==4.0.3
-    # via aiohttp
+annotated-types==0.6.0
+    # via pydantic
+anyio==3.7.1
+    # via
+    #   httpx
+    #   openai
 attrs==23.1.0
     # via
-    #   aiohttp
     #   jsonschema
     #   referencing
 backoff==2.2.1
     # via -r requirements.in
-certifi==2023.7.22
-    # via requests
+certifi==2023.11.17
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
 cffi==1.16.0
     # via
     #   sounddevice
     #   soundfile
 charset-normalizer==3.3.2
-    # via
-    #   aiohttp
-    #   requests
+    # via requests
 configargparse==1.7
     # via -r requirements.in
 diskcache==5.6.3
     # via -r requirements.in
-frozenlist==1.4.0
-    # via
-    #   aiohttp
-    #   aiosignal
+distro==1.8.0
+    # via openai
 gitdb==4.0.11
     # via gitpython
 gitpython==3.1.40
     # via -r requirements.in
 grep-ast==0.2.4
     # via -r requirements.in
-idna==3.4
+h11==0.14.0
+    # via httpcore
+httpcore==1.0.2
+    # via httpx
+httpx==0.25.2
+    # via openai
+idna==3.6
     # via
+    #   anyio
+    #   httpx
     #   requests
-    #   yarl
-jsonschema==4.19.2
+jsonschema==4.20.0
     # via -r requirements.in
-jsonschema-specifications==2023.7.1
+jsonschema-specifications==2023.11.2
     # via jsonschema
 markdown-it-py==3.0.0
     # via rich
 mdurl==0.1.2
     # via markdown-it-py
-multidict==6.0.4
-    # via
-    #   aiohttp
-    #   yarl
 networkx==3.2.1
     # via -r requirements.in
-numpy==1.26.1
+numpy==1.26.2
     # via
     #   -r requirements.in
     #   scipy
-openai==0.28.1
+openai==1.3.7
     # via -r requirements.in
 packaging==23.2
     # via -r requirements.in
@@ -71,49 +72,59 @@ pathspec==0.11.2
     # via
     #   -r requirements.in
     #   grep-ast
-prompt-toolkit==3.0.39
+prompt-toolkit==3.0.41
     # via -r requirements.in
 pycparser==2.21
     # via cffi
-pygments==2.16.1
+pydantic==2.5.2
+    # via openai
+pydantic-core==2.14.5
+    # via pydantic
+pygments==2.17.2
     # via rich
 pyyaml==6.0.1
     # via -r requirements.in
-referencing==0.30.2
+referencing==0.31.1
     # via
     #   jsonschema
     #   jsonschema-specifications
 regex==2023.10.3
     # via tiktoken
 requests==2.31.0
-    # via
-    #   openai
-    #   tiktoken
-rich==13.6.0
+    # via tiktoken
+rich==13.7.0
     # via -r requirements.in
-rpds-py==0.10.6
+rpds-py==0.13.2
     # via
     #   jsonschema
     #   referencing
-scipy==1.11.3
+scipy==1.11.4
     # via -r requirements.in
 smmap==5.0.1
     # via gitdb
+sniffio==1.3.0
+    # via
+    #   anyio
+    #   httpx
+    #   openai
 sounddevice==0.4.6
     # via -r requirements.in
 soundfile==0.12.1
     # via -r requirements.in
-tiktoken==0.5.1
+tiktoken==0.5.2
     # via -r requirements.in
 tqdm==4.66.1
     # via openai
-tree-sitter==0.20.2
+tree-sitter==0.20.4
     # via tree-sitter-languages
 tree-sitter-languages==1.8.0
     # via grep-ast
-urllib3==2.0.7
+typing-extensions==4.8.0
+    # via
+    #   openai
+    #   pydantic
+    #   pydantic-core
+urllib3==2.1.0
     # via requests
-wcwidth==0.2.9
+wcwidth==0.2.12
     # via prompt-toolkit
-yarl==1.9.2
-    # via aiohttp

View file

@@ -341,12 +341,12 @@ class TestCoder(unittest.TestCase):
         coder = Coder.create(models.GPT4, None, mock_io)

         # Set up the mock to raise InvalidRequestError
-        mock_chat_completion_create.side_effect = openai.error.InvalidRequestError(
+        mock_chat_completion_create.side_effect = openai.BadRequestError(
             "Invalid request", "param"
         )

         # Call the run method and assert that InvalidRequestError is raised
-        with self.assertRaises(openai.error.InvalidRequestError):
+        with self.assertRaises(openai.BadRequestError):
             coder.run(with_message="hi")

     def test_new_file_edit_one_commit(self):

View file

@@ -24,12 +24,13 @@ class TestModels(unittest.TestCase):
         model = Model.create("gpt-4-32k-2123")
         self.assertEqual(model.max_context_tokens, 32 * 1024)

-    @patch("openai.Model.list")
+    @patch("openai.resources.Models.list")
     def test_openrouter_model_properties(self, mock_model_list):
-        import openai
-
-        old_base = openai.api_base
-        openai.api_base = "https://openrouter.ai/api/v1"
+        # import openai
+        # old_base = openai.api_base
+        # TODO: fixme
+        # openai.api_base = "https://openrouter.ai/api/v1"
+
         mock_model_list.return_value = {
             "data": [
                 {
@@ -49,7 +50,8 @@ class TestModels(unittest.TestCase):
         self.assertEqual(model.max_context_tokens, 8192)
         self.assertEqual(model.prompt_price, 0.06)
         self.assertEqual(model.completion_price, 0.12)
-        openai.api_base = old_base
+        # TODO: fixme
+        # openai.api_base = old_base


 if __name__ == "__main__":
if __name__ == "__main__": if __name__ == "__main__":

View file

@@ -14,7 +14,7 @@ class TestSendChat(unittest.TestCase):
         # Set up the mock to raise RateLimitError on
         # the first call and return None on the second call
         mock_chat_completion_create.side_effect = [
-            openai.error.RateLimitError("Rate limit exceeded"),
+            openai.RateLimitError("Rate limit exceeded"),
             None,
         ]
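
A side benefit of threading the client through explicitly, visible across these test changes: code that accepts a client can be handed a stub instead of patching openai module internals. A hedged sketch using unittest.mock, with a hypothetical call site in the comment:

from unittest.mock import MagicMock

mock_client = MagicMock()
mock_client.chat.completions.create.return_value = MagicMock(
    choices=[MagicMock(message=MagicMock(content="hello"))]
)

# Anything that takes a client can now run offline, e.g. (hypothetically):
#   simple_send_with_retries(mock_client, "gpt-3.5-turbo", messages)
response = mock_client.chat.completions.create(model="x", messages=[])
assert response.choices[0].message.content == "hello"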