roughed in litellm

Paul Gauthier 2024-04-17 14:15:24 -07:00
parent 93de82c3da
commit b0245d3930
3 changed files with 34 additions and 50 deletions

aider/main.py

@@ -6,7 +6,7 @@ from pathlib import Path
 import configargparse
 import git
-import openai
+import litellm

 from aider import __version__, models
 from aider.coders import Coder
@ -16,6 +16,10 @@ from aider.versioncheck import check_version
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
litellm.suppress_debug_info = True
os.environ["OR_SITE_URL"] = "http://aider.chat"
os.environ["OR_APP_NAME"] = "Aider"
def get_git_root(): def get_git_root():
"""Try and guess the git repo, since the conf.yml can be at the repo root""" """Try and guess the git repo, since the conf.yml can be at the repo root"""
@@ -169,6 +173,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     core_group.add_argument(
         "--skip-model-availability-check",
         metavar="SKIP_MODEL_AVAILABILITY_CHECK",
+        action=argparse.BooleanOptionalAction,
         default=False,
         help="Override to skip model availability check (default: False)",
     )
@@ -559,39 +564,26 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     io.tool_output(*map(scrub_sensitive_info, sys.argv), log_only=True)

-    if not args.openai_api_key:
-        if os.name == "nt":
-            io.tool_error(
-                "No OpenAI API key provided. Use --openai-api-key or setx OPENAI_API_KEY."
-            )
-        else:
-            io.tool_error(
-                "No OpenAI API key provided. Use --openai-api-key or export OPENAI_API_KEY."
-            )
+    if args.openai_api_key:
+        os.environ["OPENAI_API_KEY"] = args.openai_api_key
+    if args.openai_api_base:
+        os.environ["OPENAI_API_BASE"] = args.openai_api_base
+    if args.openai_api_version:
+        os.environ["AZURE_API_VERSION"] = args.openai_api_version
+    if args.openai_api_type:
+        os.environ["AZURE_API_TYPE"] = args.openai_api_type
+    if args.openai_organization_id:
+        os.environ["OPENAI_ORGANIZATION"] = args.openai_organization_id
+
+    res = litellm.validate_environment(args.model)
+    missing_keys = res.get("missing_keys")
+    if missing_keys:
+        io.tool_error(f"To use model {args.model}, please set these environment variables:")
+        for key in missing_keys:
+            io.tool_error(f"- {key}")
         return 1

-    if args.openai_api_type == "azure":
-        client = openai.AzureOpenAI(
-            api_key=args.openai_api_key,
-            azure_endpoint=args.openai_api_base,
-            api_version=args.openai_api_version,
-            azure_deployment=args.openai_api_deployment_id,
-        )
-    else:
-        kwargs = dict()
-        if args.openai_api_base:
-            kwargs["base_url"] = args.openai_api_base
-            if "openrouter.ai" in args.openai_api_base:
-                kwargs["default_headers"] = {
-                    "HTTP-Referer": "http://aider.chat",
-                    "X-Title": "Aider",
-                }
-        if args.openai_organization_id:
-            kwargs["organization"] = args.openai_organization_id
-
-        client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
-
-    main_model = models.Model.create(args.model, client)
+    main_model = models.Model.create(args.model, None)

     try:
         coder = Coder.create(
@@ -599,7 +591,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
             edit_format=args.edit_format,
             io=io,
             skip_model_availabily_check=args.skip_model_availability_check,
-            client=client,
+            client=None,
             ##
             fnames=fnames,
             git_dname=git_dname,
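
Note: the removed hand-rolled key checks are replaced by litellm.validate_environment(),
which reports any environment variables a given model still needs. A minimal sketch of
the call (the model name here is just an example):

import litellm

# Returns a dict like {"keys_in_environment": bool, "missing_keys": [...]}
res = litellm.validate_environment("gpt-4")
for key in res.get("missing_keys", []):
    print(f"- {key}")  # e.g. OPENAI_API_KEY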

aider/models/model.py

@@ -18,15 +18,6 @@ class Model:
     prompt_price = None
     completion_price = None

-    @classmethod
-    def create(cls, name, client=None):
-        from .openai import OpenAIModel
-        from .openrouter import OpenRouterModel
-
-        if client and client.base_url.host == "openrouter.ai":
-            return OpenRouterModel(client, name)
-        return OpenAIModel(name)
-
     def __str__(self):
         return self.name
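
Note: with the Model.create() factory's client inspection gone, provider selection no
longer depends on a client's base_url; litellm routes by model-name prefix instead. A
hedged sketch (the model id is illustrative and OPENROUTER_API_KEY is assumed to be set):

import litellm

# The "openrouter/" prefix selects the provider; no client object is needed.
resp = litellm.completion(
    model="openrouter/openai/gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)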

aider/sendchat.py

@@ -3,13 +3,14 @@ import json

 import backoff
 import httpx
+import litellm
 import openai

 # from diskcache import Cache
 from openai import APIConnectionError, InternalServerError, RateLimitError

-from aider.utils import is_gpt4_with_openai_base_url
 from aider.dump import dump  # noqa: F401
+from aider.utils import is_gpt4_with_openai_base_url

 CACHE_PATH = "~/.aider.send.cache.v1"
 CACHE = None
@ -30,9 +31,6 @@ CACHE = None
), ),
) )
def send_with_retries(client, model_name, messages, functions, stream): def send_with_retries(client, model_name, messages, functions, stream):
if not client:
raise ValueError("No openai client provided")
kwargs = dict( kwargs = dict(
model=model_name, model=model_name,
messages=messages, messages=messages,
@@ -42,11 +40,14 @@ def send_with_retries(client, model_name, messages, functions, stream):
     if functions is not None:
         kwargs["functions"] = functions

     # Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
     if client and is_gpt4_with_openai_base_url(model_name, client):
-        if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
-            kwargs['model'] = "gpt-4-vision-preview"
+        if any(
+            isinstance(msg.get("content"), list)
+            and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
+            for msg in messages
+        ):
+            kwargs["model"] = "gpt-4-vision-preview"
             # gpt-4-vision is limited to max tokens of 4096
             kwargs["max_tokens"] = 4096
@@ -58,7 +59,7 @@ def send_with_retries(client, model_name, messages, functions, stream):
     if not stream and CACHE is not None and key in CACHE:
         return hash_object, CACHE[key]

-    res = client.chat.completions.create(**kwargs)
+    res = litellm.completion(**kwargs)

     if not stream and CACHE is not None:
         CACHE[key] = res
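
Note: litellm.completion() accepts the same keyword arguments that send_with_retries()
was passing to client.chat.completions.create(), so the kwargs dict built above can be
handed over unchanged. A small sketch of the call shape (the model and temperature
values are examples, not what every caller passes):

import litellm

kwargs = dict(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    temperature=0,
    stream=False,
)
resp = litellm.completion(**kwargs)
print(resp.choices[0].message.content)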