roughed in openai 1.x

This commit is contained in:
Paul Gauthier 2023-12-05 07:37:05 -08:00
parent fd34766aa9
commit 6ebc142377
15 changed files with 136 additions and 110 deletions

View file

@ -176,27 +176,23 @@ def main(argv=None, input=None, output=None, force_git_root=None):
model_group.add_argument(
"--openai-api-base",
metavar="OPENAI_API_BASE",
help="Specify the openai.api_base (default: https://api.openai.com/v1)",
help="Specify the api_base (default: https://api.openai.com/v1)",
)
model_group.add_argument(
"--openai-api-type",
metavar="OPENAI_API_TYPE",
help="Specify the openai.api_type",
help="Specify the api_type",
)
model_group.add_argument(
"--openai-api-version",
metavar="OPENAI_API_VERSION",
help="Specify the openai.api_version",
help="Specify the api_version",
)
# TODO: use deployment_id
model_group.add_argument(
"--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID",
help="Specify the deployment_id arg to be passed to openai.ChatCompletion.create()",
)
model_group.add_argument(
"--openai-api-engine",
metavar="OPENAI_API_ENGINE",
help="Specify the engine arg to be passed to openai.ChatCompletion.create()",
help="Specify the deployment_id",
)
model_group.add_argument(
"--edit-format",
@ -492,19 +488,28 @@ def main(argv=None, input=None, output=None, force_git_root=None):
)
return 1
openai.api_key = args.openai_api_key
for attr in ("base", "type", "version", "deployment_id", "engine"):
arg_key = f"openai_api_{attr}"
val = getattr(args, arg_key)
if val is not None:
mod_key = f"api_{attr}"
setattr(openai, mod_key, val)
io.tool_output(f"Setting openai.{mod_key}={val}")
if args.openai_api_type == "azure":
client = openai.AzureOpenAI(
api_key=args.openai_api_key,
azure_endpoint=args.openai_api_base,
api_version=args.openai_api_version,
)
else:
kwargs = dict()
if args.openai_api_base and "openrouter.ai" in args.openai_api_base:
kwargs["default_headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
main_model = models.Model.create(args.model)
client = openai.OpenAI(
api_key=args.openai_api_key,
base_url=args.openai_api_base,
**kwargs,
)
main_model = models.Model.create(args.model, client)
try:
coder = Coder.create(
client,
main_model,
args.edit_format,
io,