Auto switch to gpt-4-vision-preview if image files added to context

This commit is contained in:
Joshua Vial 2023-11-29 21:20:29 +13:00
parent fd34766aa9
commit d8f33a8124
7 changed files with 128 additions and 11 deletions

View file

@ -38,6 +38,8 @@ class ExhaustedContextWindow(Exception):
def wrap_fence(name):
return f"<{name}>", f"</{name}>"
#NOTE currently duplicated in io.py
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
class Coder:
abs_fnames = None
@ -283,6 +285,7 @@ class Coder:
prompt = ""
for fname, content in self.get_abs_fnames_content():
if not any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
relative_fname = self.get_rel_fname(fname)
prompt += "\n"
prompt += relative_fname
@ -321,8 +324,33 @@ class Coder:
dict(role="assistant", content="Ok."),
]
images_message = self.get_images_message()
if images_message is not None:
files_messages.append(images_message)
return files_messages
def get_images_message(self):
image_messages = []
for fname, content in self.get_abs_fnames_content():
if any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
image_messages.append({
"type": "image_url",
"image_url": {
"url": image_url,
"detail": "high"
}
})
if not image_messages:
return None
return {
"role": "user",
"content": image_messages
}
def run(self, with_message=None):
while True:
try:

View file

@ -10,6 +10,8 @@ from aider import prompts, voice
from .dump import dump # noqa: F401
#NOTE currently duplicated in io.py and base_coder.py
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
class Commands:
voice = None
@ -138,6 +140,10 @@ class Commands:
for fname in self.coder.abs_fnames:
relative_fname = self.coder.get_rel_fname(fname)
content = self.io.read_text(fname)
if any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS):
# If the file is an image, use the token_count_for_image method
tokens = self.coder.main_model.token_count_for_image(fname)
else:
# approximate
content = f"{relative_fname}\n```\n" + content + "```\n"
tokens = self.coder.main_model.token_count(content)
@ -167,7 +173,9 @@ class Commands:
self.io.tool_output("=" * (width + cost_width + 1))
self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
limit = self.coder.main_model.max_context_tokens
# Check if any images are in the chat and override the max context window size if so
image_in_chat = any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS for relative_fname in self.coder.get_inchat_relative_files())
limit = 4096 if image_in_chat else self.coder.main_model.max_context_tokens
remaining = limit - total
if remaining > 1024:
self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")

View file

@ -1,6 +1,7 @@
import os
from collections import defaultdict
from datetime import datetime
import base64
from pathlib import Path
from prompt_toolkit.completion import Completer, Completion
@ -17,6 +18,9 @@ from rich.text import Text
from .dump import dump # noqa: F401
#QUESTION what image extensions do we want to support?
#QUESTION where should this live? Currently duplicated in base_coder
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
class AutoCompleter(Completer):
def __init__(self, root, rel_fnames, addable_rel_fnames, commands, encoding):
@ -139,7 +143,27 @@ class InputOutput:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
def read_image(self, filename):
try:
with open(str(filename), "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode('utf-8')
except FileNotFoundError:
self.tool_error(f"{filename}: file not found error")
return
except IsADirectoryError:
self.tool_error(f"{filename}: is a directory")
return
except Exception as e:
self.tool_error(f"{filename}: {e}")
return
def read_text(self, filename):
file_extension = Path(filename).suffix.lower()
if file_extension in IMAGE_EXTENSIONS:
return self.read_image(filename)
try:
with open(str(filename), "r", encoding=self.encoding) as f:
return f.read()

View file

@ -1,6 +1,8 @@
import json
import math
import openai
from PIL import Image
class Model:
@ -51,3 +53,46 @@ class Model:
msgs = json.dumps(messages)
return len(self.tokenizer.encode(msgs))
def token_count_for_image(self, fname):
"""
Calculate the token cost for an image assuming high detail.
The token cost is determined by the size of the image.
:param fname: The filename of the image.
:return: The token cost for the image.
"""
# Placeholder for image size retrieval logic
# TODO: Implement the logic to retrieve the image size from the file
width, height = self.get_image_size(fname)
# If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
max_dimension = max(width, height)
if max_dimension > 2048:
scale_factor = 2048 / max_dimension
width = int(width * scale_factor)
height = int(height * scale_factor)
# Scale the image such that the shortest side is 768 pixels long
min_dimension = min(width, height)
scale_factor = 768 / min_dimension
width = int(width * scale_factor)
height = int(height * scale_factor)
# Calculate the number of 512x512 tiles needed to cover the image
tiles_width = math.ceil(width / 512)
tiles_height = math.ceil(height / 512)
num_tiles = tiles_width * tiles_height
# Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
token_cost = num_tiles * 170 + 85
return token_cost
def get_image_size(self, fname):
"""
Retrieve the size of an image.
:param fname: The filename of the image.
:return: A tuple (width, height) representing the image size in pixels.
"""
with Image.open(fname) as img:
return img.size

View file

@ -53,6 +53,13 @@ def send_with_retries(model_name, messages, functions, stream):
if "openrouter.ai" in openai.api_base:
kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
# Check conditions to switch to gpt-4-vision-preview
if "openrouter.ai" not in openai.api_base and model_name.startswith("gpt-4"):
if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
kwargs['model'] = "gpt-4-vision-preview"
# looks like gpt-4-vision is limited to max tokens of 4096
kwargs["max_tokens"] = 4096
key = json.dumps(kwargs, sort_keys=True).encode()
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes

View file

@ -16,7 +16,11 @@ def show_messages(messages, title=None, functions=None):
for msg in messages:
role = msg["role"].upper()
content = msg.get("content")
if content:
if isinstance(content, list): # Handle list content (e.g., image messages)
for item in content:
if isinstance(item, dict) and "image_url" in item:
print(role, "Image URL:", item["image_url"]["url"])
elif isinstance(content, str): # Handle string content
for line in content.splitlines():
print(role, line)
content = msg.get("function_call")

View file

@ -16,3 +16,4 @@ packaging
sounddevice
soundfile
PyYAML
Pillow