mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 14:25:00 +00:00
Auto switch to gpt-4-vision-preview if image files added to context
This commit is contained in:
parent
fd34766aa9
commit
d8f33a8124
7 changed files with 128 additions and 11 deletions
|
@ -38,6 +38,8 @@ class ExhaustedContextWindow(Exception):
|
|||
def wrap_fence(name):
|
||||
return f"<{name}>", f"</{name}>"
|
||||
|
||||
#NOTE currently duplicated in io.py
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
|
||||
class Coder:
|
||||
abs_fnames = None
|
||||
|
@ -283,6 +285,7 @@ class Coder:
|
|||
|
||||
prompt = ""
|
||||
for fname, content in self.get_abs_fnames_content():
|
||||
if not any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
|
||||
relative_fname = self.get_rel_fname(fname)
|
||||
prompt += "\n"
|
||||
prompt += relative_fname
|
||||
|
@ -321,8 +324,33 @@ class Coder:
|
|||
dict(role="assistant", content="Ok."),
|
||||
]
|
||||
|
||||
images_message = self.get_images_message()
|
||||
if images_message is not None:
|
||||
files_messages.append(images_message)
|
||||
|
||||
return files_messages
|
||||
|
||||
def get_images_message(self):
|
||||
image_messages = []
|
||||
for fname, content in self.get_abs_fnames_content():
|
||||
if any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
|
||||
image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
|
||||
image_messages.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
"detail": "high"
|
||||
}
|
||||
})
|
||||
|
||||
if not image_messages:
|
||||
return None
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": image_messages
|
||||
}
|
||||
|
||||
def run(self, with_message=None):
|
||||
while True:
|
||||
try:
|
||||
|
|
|
@ -10,6 +10,8 @@ from aider import prompts, voice
|
|||
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
#NOTE currently duplicated in io.py and base_coder.py
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
|
||||
class Commands:
|
||||
voice = None
|
||||
|
@ -138,6 +140,10 @@ class Commands:
|
|||
for fname in self.coder.abs_fnames:
|
||||
relative_fname = self.coder.get_rel_fname(fname)
|
||||
content = self.io.read_text(fname)
|
||||
if any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS):
|
||||
# If the file is an image, use the token_count_for_image method
|
||||
tokens = self.coder.main_model.token_count_for_image(fname)
|
||||
else:
|
||||
# approximate
|
||||
content = f"{relative_fname}\n```\n" + content + "```\n"
|
||||
tokens = self.coder.main_model.token_count(content)
|
||||
|
@ -167,7 +173,9 @@ class Commands:
|
|||
self.io.tool_output("=" * (width + cost_width + 1))
|
||||
self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
|
||||
|
||||
limit = self.coder.main_model.max_context_tokens
|
||||
# Check if any images are in the chat and override the max context window size if so
|
||||
image_in_chat = any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS for relative_fname in self.coder.get_inchat_relative_files())
|
||||
limit = 4096 if image_in_chat else self.coder.main_model.max_context_tokens
|
||||
remaining = limit - total
|
||||
if remaining > 1024:
|
||||
self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
|
||||
|
|
24
aider/io.py
24
aider/io.py
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
@ -17,6 +18,9 @@ from rich.text import Text
|
|||
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
#QUESTION what image extensions do we want to support?
|
||||
#QUESTION where should this live? Currently duplicated in base_coder
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
|
||||
class AutoCompleter(Completer):
|
||||
def __init__(self, root, rel_fnames, addable_rel_fnames, commands, encoding):
|
||||
|
@ -139,7 +143,27 @@ class InputOutput:
|
|||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
|
||||
|
||||
|
||||
def read_image(self, filename):
|
||||
try:
|
||||
with open(str(filename), "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return encoded_string.decode('utf-8')
|
||||
except FileNotFoundError:
|
||||
self.tool_error(f"{filename}: file not found error")
|
||||
return
|
||||
except IsADirectoryError:
|
||||
self.tool_error(f"{filename}: is a directory")
|
||||
return
|
||||
except Exception as e:
|
||||
self.tool_error(f"{filename}: {e}")
|
||||
return
|
||||
|
||||
def read_text(self, filename):
|
||||
file_extension = Path(filename).suffix.lower()
|
||||
if file_extension in IMAGE_EXTENSIONS:
|
||||
return self.read_image(filename)
|
||||
|
||||
try:
|
||||
with open(str(filename), "r", encoding=self.encoding) as f:
|
||||
return f.read()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import json
|
||||
import math
|
||||
|
||||
import openai
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Model:
|
||||
|
@ -51,3 +53,46 @@ class Model:
|
|||
msgs = json.dumps(messages)
|
||||
|
||||
return len(self.tokenizer.encode(msgs))
|
||||
|
||||
def token_count_for_image(self, fname):
|
||||
"""
|
||||
Calculate the token cost for an image assuming high detail.
|
||||
The token cost is determined by the size of the image.
|
||||
:param fname: The filename of the image.
|
||||
:return: The token cost for the image.
|
||||
"""
|
||||
# Placeholder for image size retrieval logic
|
||||
# TODO: Implement the logic to retrieve the image size from the file
|
||||
width, height = self.get_image_size(fname)
|
||||
|
||||
# If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
|
||||
max_dimension = max(width, height)
|
||||
if max_dimension > 2048:
|
||||
scale_factor = 2048 / max_dimension
|
||||
width = int(width * scale_factor)
|
||||
height = int(height * scale_factor)
|
||||
|
||||
# Scale the image such that the shortest side is 768 pixels long
|
||||
min_dimension = min(width, height)
|
||||
scale_factor = 768 / min_dimension
|
||||
width = int(width * scale_factor)
|
||||
height = int(height * scale_factor)
|
||||
|
||||
# Calculate the number of 512x512 tiles needed to cover the image
|
||||
tiles_width = math.ceil(width / 512)
|
||||
tiles_height = math.ceil(height / 512)
|
||||
num_tiles = tiles_width * tiles_height
|
||||
|
||||
# Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
|
||||
token_cost = num_tiles * 170 + 85
|
||||
return token_cost
|
||||
|
||||
|
||||
def get_image_size(self, fname):
|
||||
"""
|
||||
Retrieve the size of an image.
|
||||
:param fname: The filename of the image.
|
||||
:return: A tuple (width, height) representing the image size in pixels.
|
||||
"""
|
||||
with Image.open(fname) as img:
|
||||
return img.size
|
||||
|
|
|
@ -53,6 +53,13 @@ def send_with_retries(model_name, messages, functions, stream):
|
|||
if "openrouter.ai" in openai.api_base:
|
||||
kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
|
||||
|
||||
# Check conditions to switch to gpt-4-vision-preview
|
||||
if "openrouter.ai" not in openai.api_base and model_name.startswith("gpt-4"):
|
||||
if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
|
||||
kwargs['model'] = "gpt-4-vision-preview"
|
||||
# looks like gpt-4-vision is limited to max tokens of 4096
|
||||
kwargs["max_tokens"] = 4096
|
||||
|
||||
key = json.dumps(kwargs, sort_keys=True).encode()
|
||||
|
||||
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
||||
|
|
|
@ -16,7 +16,11 @@ def show_messages(messages, title=None, functions=None):
|
|||
for msg in messages:
|
||||
role = msg["role"].upper()
|
||||
content = msg.get("content")
|
||||
if content:
|
||||
if isinstance(content, list): # Handle list content (e.g., image messages)
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "image_url" in item:
|
||||
print(role, "Image URL:", item["image_url"]["url"])
|
||||
elif isinstance(content, str): # Handle string content
|
||||
for line in content.splitlines():
|
||||
print(role, line)
|
||||
content = msg.get("function_call")
|
||||
|
|
|
@ -16,3 +16,4 @@ packaging
|
|||
sounddevice
|
||||
soundfile
|
||||
PyYAML
|
||||
Pillow
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue