refactor initialization of openai into main

This commit is contained in:
Paul Gauthier 2023-07-11 15:06:50 -07:00
parent fa3283802a
commit 084132a5f8
2 changed files with 12 additions and 11 deletions

View file

@ -53,8 +53,6 @@ class Coder:
main_model, main_model,
edit_format, edit_format,
io, io,
openai_api_key,
openai_api_base="https://api.openai.com/v1",
**kwargs, **kwargs,
): ):
from . import ( from . import (
@ -65,9 +63,6 @@ class Coder:
WholeFileFunctionCoder, WholeFileFunctionCoder,
) )
openai.api_key = openai_api_key
openai.api_base = openai_api_base
if not main_model: if not main_model:
main_model = models.GPT35_16k main_model = models.GPT35_16k
@ -629,6 +624,8 @@ class Coder:
) )
if functions is not None: if functions is not None:
kwargs["functions"] = self.functions kwargs["functions"] = self.functions
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode()) hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())

View file

@ -4,6 +4,7 @@ from pathlib import Path
import configargparse import configargparse
import git import git
import openai
from aider import __version__, models from aider import __version__, models
from aider.coders import Coder from aider.coders import Coder
@ -75,7 +76,6 @@ def main(args=None, input=None, output=None):
model_group.add_argument( model_group.add_argument(
"--openai-api-base", "--openai-api-base",
metavar="OPENAI_API_BASE", metavar="OPENAI_API_BASE",
default="https://api.openai.com/v1",
help="Specify the OpenAI API base endpoint (default: https://api.openai.com/v1)", help="Specify the OpenAI API base endpoint (default: https://api.openai.com/v1)",
) )
model_group.add_argument( model_group.add_argument(
@ -347,12 +347,19 @@ def main(args=None, input=None, output=None):
main_model = models.Model(args.model) main_model = models.Model(args.model)
openai.api_key = args.openai_api_key
for attr in ("base", "type", "version", "deployment_id"):
arg_key = f"openai_api_{attr}"
val = getattr(args, arg_key)
if val is not None:
mod_key = f"api_{attr}"
setattr(openai, mod_key, val)
io.tool_output(f"Setting openai.{mod_key}={val}")
coder = Coder.create( coder = Coder.create(
main_model, main_model,
args.edit_format, args.edit_format,
io, io,
args.openai_api_key,
args.openai_api_base,
## ##
fnames=args.files, fnames=args.files,
pretty=args.pretty, pretty=args.pretty,
@ -366,9 +373,6 @@ def main(args=None, input=None, output=None):
code_theme=args.code_theme, code_theme=args.code_theme,
stream=args.stream, stream=args.stream,
use_git=args.git, use_git=args.git,
openai_api_type=args.openai_api_type,
openai_api_version=args.openai_api_version,
openai_api_deployment_id=args.openai_api_deployment_id,
) )
if args.dirty_commits: if args.dirty_commits: