Auto switch to gpt-4-vision-preview if image files added to context

commit d8f33a8124
parent fd34766aa9
Author: Joshua Vial
Date:   2023-11-29 21:20:29 +13:00
7 changed files with 128 additions and 11 deletions

aider/coders/base_coder.py

@@ -38,6 +38,8 @@ class ExhaustedContextWindow(Exception):
 def wrap_fence(name):
     return f"<{name}>", f"</{name}>"
 
+#NOTE currently duplicated in io.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class Coder:
     abs_fnames = None
@@ -283,12 +285,13 @@ class Coder:
         prompt = ""
         for fname, content in self.get_abs_fnames_content():
-            relative_fname = self.get_rel_fname(fname)
-            prompt += "\n"
-            prompt += relative_fname
-            prompt += f"\n{self.fence[0]}\n"
-            prompt += content
-            prompt += f"{self.fence[1]}\n"
+            if not any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += relative_fname
+                prompt += f"\n{self.fence[0]}\n"
+                prompt += content
+                prompt += f"{self.fence[1]}\n"
 
         return prompt
@@ -321,8 +324,33 @@ class Coder:
             dict(role="assistant", content="Ok."),
         ]
 
+        images_message = self.get_images_message()
+        if images_message is not None:
+            files_messages.append(images_message)
+
         return files_messages
 
+    def get_images_message(self):
+        image_messages = []
+        for fname, content in self.get_abs_fnames_content():
+            if any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
+                image_messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url,
+                        "detail": "high"
+                    }
+                })
+
+        if not image_messages:
+            return None
+
+        return {
+            "role": "user",
+            "content": image_messages
+        }
+
     def run(self, with_message=None):
         while True:
             try:
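
For reference, get_images_message() emits the OpenAI chat-completions vision format: each in-chat image becomes a base64 data URL inside a list-style user message. A minimal sketch of the message produced for a single PNG (base64 payload truncated for illustration):

    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": "data:image/png;base64,iVBORw0KGgo...",
                    "detail": "high"
                }
            }
        ]
    }

This list-style content is exactly what the sendchat and utils changes below detect.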

aider/commands.py

@@ -10,6 +10,8 @@ from aider import prompts, voice
 from .dump import dump  # noqa: F401
 
+#NOTE currently duplicated in io.py and base_coder.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class Commands:
     voice = None
@@ -138,9 +140,13 @@ class Commands:
         for fname in self.coder.abs_fnames:
             relative_fname = self.coder.get_rel_fname(fname)
             content = self.io.read_text(fname)
-            # approximate
-            content = f"{relative_fname}\n```\n" + content + "```\n"
-            tokens = self.coder.main_model.token_count(content)
+            if any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS):
+                # If the file is an image, use the token_count_for_image method
+                tokens = self.coder.main_model.token_count_for_image(fname)
+            else:
+                # approximate
+                content = f"{relative_fname}\n```\n" + content + "```\n"
+                tokens = self.coder.main_model.token_count(content)
             res.append((tokens, f"{relative_fname}", "use /drop to drop from chat"))
 
         self.io.tool_output("Approximate context window usage, in tokens:")
@@ -167,7 +173,9 @@ class Commands:
         self.io.tool_output("=" * (width + cost_width + 1))
         self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
 
-        limit = self.coder.main_model.max_context_tokens
+        # Check if any images are in the chat and override the max context window size if so
+        image_in_chat = any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS for relative_fname in self.coder.get_inchat_relative_files())
+        limit = 4096 if image_in_chat else self.coder.main_model.max_context_tokens
         remaining = limit - total
         if remaining > 1024:
             self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")

aider/io.py

@@ -1,6 +1,7 @@
 import os
 from collections import defaultdict
 from datetime import datetime
+import base64
 from pathlib import Path
 
 from prompt_toolkit.completion import Completer, Completion
@@ -17,6 +18,9 @@ from rich.text import Text
 from .dump import dump  # noqa: F401
 
+#QUESTION what image extensions do we want to support?
+#QUESTION where should this live? Currently duplicated in base_coder
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class AutoCompleter(Completer):
     def __init__(self, root, rel_fnames, addable_rel_fnames, commands, encoding):
@@ -139,7 +143,27 @@ class InputOutput:
             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
 
+    def read_image(self, filename):
+        try:
+            with open(str(filename), "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read())
+            return encoded_string.decode('utf-8')
+        except FileNotFoundError:
+            self.tool_error(f"{filename}: file not found error")
+            return
+        except IsADirectoryError:
+            self.tool_error(f"{filename}: is a directory")
+            return
+        except Exception as e:
+            self.tool_error(f"{filename}: {e}")
+            return
+
     def read_text(self, filename):
+        file_extension = Path(filename).suffix.lower()
+        if file_extension in IMAGE_EXTENSIONS:
+            return self.read_image(filename)
+
         try:
             with open(str(filename), "r", encoding=self.encoding) as f:
                 return f.read()
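
read_image() is plain standard-library base64 encoding; an equivalent standalone sketch (hypothetical file name):

    import base64

    with open("screenshot.png", "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode('utf-8')
    # `encoded` is the string read_text() now returns for image files

With this change read_text() becomes polymorphic: text files come back as decoded text and image files as a base64 string, so callers like the coder's get_abs_fnames_content() receive both through one interface.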

aider/models/model.py

@@ -1,6 +1,8 @@
 import json
+import math
 
 import openai
+from PIL import Image
 
 
 class Model:
@@ -51,3 +53,46 @@ class Model:
         msgs = json.dumps(messages)
 
         return len(self.tokenizer.encode(msgs))
+
+    def token_count_for_image(self, fname):
+        """
+        Calculate the token cost for an image assuming high detail.
+        The token cost is determined by the size of the image.
+        :param fname: The filename of the image.
+        :return: The token cost for the image.
+        """
+        width, height = self.get_image_size(fname)
+
+        # If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
+        max_dimension = max(width, height)
+        if max_dimension > 2048:
+            scale_factor = 2048 / max_dimension
+            width = int(width * scale_factor)
+            height = int(height * scale_factor)
+
+        # Scale the image such that the shortest side is 768 pixels long
+        min_dimension = min(width, height)
+        scale_factor = 768 / min_dimension
+        width = int(width * scale_factor)
+        height = int(height * scale_factor)
+
+        # Calculate the number of 512x512 tiles needed to cover the image
+        tiles_width = math.ceil(width / 512)
+        tiles_height = math.ceil(height / 512)
+        num_tiles = tiles_width * tiles_height
+
+        # Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
+        token_cost = num_tiles * 170 + 85
+        return token_cost
+
+    def get_image_size(self, fname):
+        """
+        Retrieve the size of an image.
+        :param fname: The filename of the image.
+        :return: A tuple (width, height) representing the image size in pixels.
+        """
+        with Image.open(fname) as img:
+            return img.size
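
The tiling math follows OpenAI's published high-detail vision pricing: 85 base tokens plus 170 tokens per 512x512 tile. A worked example, tracing the same steps for a hypothetical 1920x1080 screenshot:

    import math

    width, height = 1920, 1080                    # hypothetical screenshot
    if max(width, height) > 2048:                 # not triggered at this size
        scale = 2048 / max(width, height)
        width, height = int(width * scale), int(height * scale)
    scale = 768 / min(width, height)              # shortest side -> 768
    width, height = int(width * scale), int(height * scale)  # ~1365 x 768
    tiles = math.ceil(width / 512) * math.ceil(height / 512)  # 3 * 2 = 6
    print(tiles * 170 + 85)                       # 1105 tokens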

aider/sendchat.py

@@ -53,6 +53,13 @@ def send_with_retries(model_name, messages, functions, stream):
     if "openrouter.ai" in openai.api_base:
         kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
 
+    # Check conditions to switch to gpt-4-vision-preview
+    if "openrouter.ai" not in openai.api_base and model_name.startswith("gpt-4"):
+        if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
+            kwargs['model'] = "gpt-4-vision-preview"
+            # looks like gpt-4-vision is limited to max tokens of 4096
+            kwargs["max_tokens"] = 4096
+
     key = json.dumps(kwargs, sort_keys=True).encode()
 
     # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
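
The inline any(...) condition is dense; an equivalent unrolled sketch of the same predicate:

    def has_image_content(messages):
        # True if any message carries list-style content containing an image_url item
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                if any("image_url" in item for item in content if isinstance(item, dict)):
                    return True
        return False

Note the guard: the switch is skipped for openrouter.ai endpoints and for models that don't start with "gpt-4", so only direct OpenAI gpt-4 chats are redirected to the vision preview model.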

aider/utils.py

@@ -16,7 +16,11 @@ def show_messages(messages, title=None, functions=None):
     for msg in messages:
         role = msg["role"].upper()
         content = msg.get("content")
-        if content:
+        if isinstance(content, list):  # Handle list content (e.g., image messages)
+            for item in content:
+                if isinstance(item, dict) and "image_url" in item:
+                    print(role, "Image URL:", item["image_url"]["url"])
+        elif isinstance(content, str):  # Handle string content
             for line in content.splitlines():
                 print(role, line)
 
         content = msg.get("function_call")

requirements.txt

@@ -16,3 +16,4 @@ packaging
 sounddevice
 soundfile
 PyYAML
+Pillow