Merge pull request #88 from paul-gauthier/azure

Added args to configure openai module to access Azure
This commit is contained in:
paul-gauthier 2023-07-12 07:32:10 -07:00 committed by GitHub
commit 549a1a7640
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 46 deletions

View file

@ -53,8 +53,6 @@ class Coder:
main_model,
edit_format,
io,
openai_api_key,
openai_api_base="https://api.openai.com/v1",
**kwargs,
):
from . import (
@ -65,9 +63,6 @@ class Coder:
WholeFileFunctionCoder,
)
openai.api_key = openai_api_key
openai.api_base = openai_api_base
if not main_model:
main_model = models.GPT35_16k
@ -630,6 +625,12 @@ class Coder:
if functions is not None:
kwargs["functions"] = self.functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
self.chat_completion_call_hashes.append(hash_object.hexdigest())

View file

@ -4,6 +4,7 @@ from pathlib import Path
import configargparse
import git
import openai
from aider import __version__, models
from aider.coders import Coder
@ -75,8 +76,27 @@ def main(args=None, input=None, output=None):
model_group.add_argument(
"--openai-api-base",
metavar="OPENAI_API_BASE",
default="https://api.openai.com/v1",
help="Specify the OpenAI API base endpoint (default: https://api.openai.com/v1)",
help="Specify the openai.api_base (default: https://api.openai.com/v1)",
)
model_group.add_argument(
"--openai-api-type",
metavar="OPENAI_API_TYPE",
help="Specify the openai.api_type",
)
model_group.add_argument(
"--openai-api-version",
metavar="OPENAI_API_VERSION",
help="Specify the openai.api_version",
)
model_group.add_argument(
"--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID",
help="Specify the deployment_id arg to be passed to openai.ChatCompletion.create()",
)
model_group.add_argument(
"--openai-api-engine",
metavar="OPENAI_API_ENGINE",
help="Specify the engine arg to be passed to openai.ChatCompletion.create()",
)
model_group.add_argument(
"--edit-format",
@ -334,12 +354,19 @@ def main(args=None, input=None, output=None):
main_model = models.Model(args.model)
openai.api_key = args.openai_api_key
for attr in ("base", "type", "version", "deployment_id", "engine"):
arg_key = f"openai_api_{attr}"
val = getattr(args, arg_key)
if val is not None:
mod_key = f"api_{attr}"
setattr(openai, mod_key, val)
io.tool_output(f"Setting openai.{mod_key}={val}")
coder = Coder.create(
main_model,
args.edit_format,
io,
args.openai_api_key,
args.openai_api_base,
##
fnames=args.files,
pretty=args.pretty,