handle switching to gpt-4-vision-preview

This commit is contained in:
Joshua Vial 2023-12-11 21:13:07 +13:00
parent 90fb538015
commit c919f9f0c6
4 changed files with 14 additions and 12 deletions

View file

@ -10,8 +10,7 @@ from aider import prompts, voice
from .dump import dump # noqa: F401
#NOTE currently duplicated in io.py and base_coder.py
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
from aider.utils import is_image_file
class Commands:
voice = None
@ -140,7 +139,7 @@ class Commands:
for fname in self.coder.abs_fnames:
relative_fname = self.coder.get_rel_fname(fname)
content = self.io.read_text(fname)
if any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS):
if is_image_file(relative_fname):
# If the file is an image, use the token_count_for_image method
tokens = self.coder.main_model.token_count_for_image(fname)
else:

View file

@ -33,7 +33,6 @@ class OpenRouterModel(Model):
raise ValueError(f"invalid openrouter model: {name}")
# TODO run benchmarks and figure out which models support which edit-formats
def edit_format_for_model(name):
if any(str in name for str in ["gpt-4", "claude-2"]):
return "diff"

View file

@ -43,18 +43,10 @@ def send_with_retries(client, model_name, messages, functions, stream):
# Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
if client and model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host:
print('switch model')
if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
kwargs['model'] = "gpt-4-vision-preview"
# gpt-4-vision is limited to max tokens of 4096
kwargs["max_tokens"] = 4096
else:
# Strip out any image_url messages if not using gpt-4-vision-preview
print('strip img')
messages = [
{k: v for k, v in msg.items() if k != "content" or not any(isinstance(item, dict) and "image_url" in item for item in v)}
for msg in messages if isinstance(msg.get("content"), list)
] + [msg for msg in messages if not isinstance(msg.get("content"), list)]
key = json.dumps(kwargs, sort_keys=True).encode()

View file

@ -1,7 +1,19 @@
from pathlib import Path
# Set of image file extensions
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
from .dump import dump # noqa: F401
def is_image_file(file_name):
    """
    Check if the given file name has an image file extension.

    The comparison is case-insensitive, so ``photo.PNG`` is recognised the
    same way as ``photo.png``.

    :param file_name: The name of the file to check (a string or a
        path-like object; it is converted with ``str()``).
    :return: True if the name ends with one of IMAGE_EXTENSIONS, False otherwise.
    """
    # str() lets callers pass pathlib.Path objects, which have no endswith().
    name = str(file_name).lower()
    # str.endswith accepts a tuple of suffixes, so a single call covers
    # every known extension.
    return name.endswith(tuple(ext.lower() for ext in IMAGE_EXTENSIONS))
def safe_abs_path(res):
"Gives an abs path, which safely returns a full (not 8.3) windows path"