refactor: move API key handling earlier in startup sequence

This commit is contained in:
Paul Gauthier 2024-12-07 10:52:19 -08:00 committed by Paul Gauthier (aider)
parent 7ddcc30e8d
commit 935c39e341

View file

@ -451,12 +451,63 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
args, unknown = parser.parse_known_args(argv)
dump(os.environ.get("ANTHROPIC_API_KEY"))
# Load the .env file specified in the arguments
loaded_dotenvs = load_dotenv_files(git_root, args.env_file, args.encoding)
dump(os.environ.get("ANTHROPIC_API_KEY"))
# Parse again to include any arguments that might have been defined in .env
args = parser.parse_args(argv)
dump(os.environ.get("ANTHROPIC_API_KEY"))
# Process any environment variables set via --set-env
if args.set_env:
for env_setting in args.set_env:
try:
name, value = env_setting.split("=", 1)
os.environ[name.strip()] = value.strip()
except ValueError:
io.tool_error(f"Invalid --set-env format: {env_setting}")
io.tool_output("Format should be: ENV_VAR_NAME=value")
return 1
# Process any API keys set via --api-key
if args.api_key:
for api_setting in args.api_key:
try:
provider, key = api_setting.split("=", 1)
env_var = f"{provider.strip().upper()}_API_KEY"
os.environ[env_var] = key.strip()
dump(env_var)
except ValueError:
io.tool_error(f"Invalid --api-key format: {api_setting}")
io.tool_output("Format should be: provider=key")
return 1
dump(os.environ.get("ANTHROPIC_API_KEY"))
# AI: for each of these, add io.tool_warning("--xxx is deprecated, use ---yyy") and either use --api-key foo=<key> or --set-env FOO=value
if args.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = args.anthropic_api_key
if args.openai_api_key:
os.environ["OPENAI_API_KEY"] = args.openai_api_key
if args.openai_api_base:
os.environ["OPENAI_API_BASE"] = args.openai_api_base
if args.openai_api_version:
os.environ["OPENAI_API_VERSION"] = args.openai_api_version
if args.openai_api_type:
os.environ["OPENAI_API_TYPE"] = args.openai_api_type
if args.openai_organization_id:
os.environ["OPENAI_ORGANIZATION"] = args.openai_organization_id
# ... down to here AI!
dump(os.environ.get("ANTHROPIC_API_KEY"))
if args.analytics_disable:
analytics = Analytics(permanently_disable=True)
print("Analytics have been permanently disabled.")
@ -526,29 +577,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
io = get_io(False)
io.tool_warning("Terminal does not support pretty output (UnicodeDecodeError)")
# Process any environment variables set via --set-env
if args.set_env:
for env_setting in args.set_env:
try:
name, value = env_setting.split("=", 1)
os.environ[name.strip()] = value.strip()
except ValueError:
io.tool_error(f"Invalid --set-env format: {env_setting}")
io.tool_output("Format should be: ENV_VAR_NAME=value")
return 1
# Process any API keys set via --api-key
if args.api_key:
for api_setting in args.api_key:
try:
provider, key = api_setting.split("=", 1)
env_var = f"{provider.strip().upper()}_API_KEY"
os.environ[env_var] = key.strip()
except ValueError:
io.tool_error(f"Invalid --api-key format: {api_setting}")
io.tool_output("Format should be: provider=key")
return 1
analytics = Analytics(logfile=args.analytics_log, permanently_disable=args.analytics_disable)
if args.analytics is not False:
if analytics.need_to_ask(args.analytics):
@ -669,20 +697,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
is_first_run = is_first_run_of_new_version(io, verbose=args.verbose)
check_and_load_imports(io, is_first_run, verbose=args.verbose)
if args.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = args.anthropic_api_key
if args.openai_api_key:
os.environ["OPENAI_API_KEY"] = args.openai_api_key
if args.openai_api_base:
os.environ["OPENAI_API_BASE"] = args.openai_api_base
if args.openai_api_version:
os.environ["OPENAI_API_VERSION"] = args.openai_api_version
if args.openai_api_type:
os.environ["OPENAI_API_TYPE"] = args.openai_api_type
if args.openai_organization_id:
os.environ["OPENAI_ORGANIZATION"] = args.openai_organization_id
register_models(git_root, args.model_settings_file, io, verbose=args.verbose)
register_litellm_models(git_root, args.model_metadata_file, io, verbose=args.verbose)