Added --openai-api-engine

This commit is contained in:
Paul Gauthier 2023-07-11 15:11:35 -07:00
parent 084132a5f8
commit d97707a5c0
2 changed files with 14 additions and 5 deletions

View file

@ -624,8 +624,12 @@ class Coder:
) )
if functions is not None: if functions is not None:
kwargs["functions"] = self.functions kwargs["functions"] = self.functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"): if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode()) hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())

View file

@ -76,22 +76,27 @@ def main(args=None, input=None, output=None):
model_group.add_argument( model_group.add_argument(
"--openai-api-base", "--openai-api-base",
metavar="OPENAI_API_BASE", metavar="OPENAI_API_BASE",
help="Specify the OpenAI API base endpoint (default: https://api.openai.com/v1)", help="Specify the openai.api_base (default: https://api.openai.com/v1)",
) )
model_group.add_argument( model_group.add_argument(
"--openai-api-type", "--openai-api-type",
metavar="OPENAI_API_TYPE", metavar="OPENAI_API_TYPE",
help="Specify the OpenAI API type", help="Specify the openai.api_type",
) )
model_group.add_argument( model_group.add_argument(
"--openai-api-version", "--openai-api-version",
metavar="OPENAI_API_VERSION", metavar="OPENAI_API_VERSION",
help="Specify the OpenAI API version", help="Specify the openai.api_version",
) )
model_group.add_argument( model_group.add_argument(
"--openai-api-deployment-id", "--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID", metavar="OPENAI_API_DEPLOYMENT_ID",
help="Specify the OpenAI API deployment ID", help="Specify the deployment_id arg to be passed to openai.ChatCompletion.create()",
)
model_group.add_argument(
"--openai-api-engine",
metavar="OPENAI_API_ENGINE",
help="Specify the engine arg to be passed to openai.ChatCompletion.create()",
) )
model_group.add_argument( model_group.add_argument(
"--edit-format", "--edit-format",
@ -348,7 +353,7 @@ def main(args=None, input=None, output=None):
main_model = models.Model(args.model) main_model = models.Model(args.model)
openai.api_key = args.openai_api_key openai.api_key = args.openai_api_key
for attr in ("base", "type", "version", "deployment_id"): for attr in ("base", "type", "version", "deployment_id", "engine"):
arg_key = f"openai_api_{attr}" arg_key = f"openai_api_{attr}"
val = getattr(args, arg_key) val = getattr(args, arg_key)
if val is not None: if val is not None: