Mirror of https://github.com/Aider-AI/aider.git, synced 2025-05-24 14:25:00 +00:00
Auto switch to gpt-4-vision-preview if image files added to context
This commit is contained in:
parent fd34766aa9
commit d8f33a8124

7 changed files with 128 additions and 11 deletions
aider/coders/base_coder.py

```diff
@@ -38,6 +38,8 @@ class ExhaustedContextWindow(Exception):
 def wrap_fence(name):
     return f"<{name}>", f"</{name}>"


+#NOTE currently duplicated in io.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}


 class Coder:
     abs_fnames = None
@@ -283,12 +285,13 @@ class Coder:

         prompt = ""
         for fname, content in self.get_abs_fnames_content():
-            relative_fname = self.get_rel_fname(fname)
-            prompt += "\n"
-            prompt += relative_fname
-            prompt += f"\n{self.fence[0]}\n"
-            prompt += content
-            prompt += f"{self.fence[1]}\n"
+            if not any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += relative_fname
+                prompt += f"\n{self.fence[0]}\n"
+                prompt += content
+                prompt += f"{self.fence[1]}\n"

         return prompt

@@ -321,8 +324,33 @@ class Coder:
             dict(role="assistant", content="Ok."),
         ]

+        images_message = self.get_images_message()
+        if images_message is not None:
+            files_messages.append(images_message)
+
         return files_messages

+    def get_images_message(self):
+        image_messages = []
+        for fname, content in self.get_abs_fnames_content():
+            if any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
+                image_messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url,
+                        "detail": "high"
+                    }
+                })
+
+        if not image_messages:
+            return None
+
+        return {
+            "role": "user",
+            "content": image_messages
+        }
+
     def run(self, with_message=None):
         while True:
             try:
```
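For reference, the user message that `get_images_message()` assembles has this shape (a minimal sketch; the file name and the truncated base64 payload are illustrative, not from the commit):

```python
# Illustrative shape of the message returned by get_images_message() for a
# single in-chat image. The base64 payload comes from InputOutput.read_text(),
# which returns image files base64-encoded rather than as text.
images_message = {
    "role": "user",
    "content": [
        {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,iVBORw0KGgoAAAANS...",  # truncated
                "detail": "high",
            },
        }
    ],
}
```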
aider/commands.py

```diff
@@ -10,6 +10,8 @@ from aider import prompts, voice

 from .dump import dump  # noqa: F401

+#NOTE currently duplicated in io.py and base_coder.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}

 class Commands:
     voice = None
@@ -138,9 +140,13 @@ class Commands:
         for fname in self.coder.abs_fnames:
             relative_fname = self.coder.get_rel_fname(fname)
             content = self.io.read_text(fname)
-            # approximate
-            content = f"{relative_fname}\n```\n" + content + "```\n"
-            tokens = self.coder.main_model.token_count(content)
+            if any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS):
+                # If the file is an image, use the token_count_for_image method
+                tokens = self.coder.main_model.token_count_for_image(fname)
+            else:
+                # approximate
+                content = f"{relative_fname}\n```\n" + content + "```\n"
+                tokens = self.coder.main_model.token_count(content)
             res.append((tokens, f"{relative_fname}", "use /drop to drop from chat"))

         self.io.tool_output("Approximate context window usage, in tokens:")
@@ -167,7 +173,9 @@ class Commands:
         self.io.tool_output("=" * (width + cost_width + 1))
         self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")

-        limit = self.coder.main_model.max_context_tokens
+        # Check if any images are in the chat and override the max context window size if so
+        image_in_chat = any(relative_fname.endswith(ext) for ext in IMAGE_EXTENSIONS for relative_fname in self.coder.get_inchat_relative_files())
+        limit = 4096 if image_in_chat else self.coder.main_model.max_context_tokens
         remaining = limit - total
         if remaining > 1024:
             self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
```
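A minimal sketch of the new `/tokens` limit override in isolation (the file list and the 8192-token context size are hypothetical, not from the commit):

```python
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}

inchat_files = ["app.py", "docs/diagram.png"]  # hypothetical in-chat files
max_context_tokens = 8192  # hypothetical model context size

image_in_chat = any(f.endswith(ext) for ext in IMAGE_EXTENSIONS for f in inchat_files)
limit = 4096 if image_in_chat else max_context_tokens
assert limit == 4096  # with an image in chat, the window is capped at 4096
```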
aider/io.py (24 changes)

```diff
@@ -1,6 +1,7 @@
 import os
 from collections import defaultdict
 from datetime import datetime
+import base64
 from pathlib import Path

 from prompt_toolkit.completion import Completer, Completion
@@ -17,6 +18,9 @@ from rich.text import Text

 from .dump import dump  # noqa: F401

+#QUESTION what image extensions do we want to support?
+#QUESTION where should this live? Currently duplicated in base_coder
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}

 class AutoCompleter(Completer):
     def __init__(self, root, rel_fnames, addable_rel_fnames, commands, encoding):
@@ -139,7 +143,27 @@ class InputOutput:
             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")

+    def read_image(self, filename):
+        try:
+            with open(str(filename), "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read())
+                return encoded_string.decode('utf-8')
+        except FileNotFoundError:
+            self.tool_error(f"{filename}: file not found error")
+            return
+        except IsADirectoryError:
+            self.tool_error(f"{filename}: is a directory")
+            return
+        except Exception as e:
+            self.tool_error(f"{filename}: {e}")
+            return
+
     def read_text(self, filename):
+        file_extension = Path(filename).suffix.lower()
+        if file_extension in IMAGE_EXTENSIONS:
+            return self.read_image(filename)
+
         try:
             with open(str(filename), "r", encoding=self.encoding) as f:
                 return f.read()
```
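A standalone sketch of the `read_image` round trip, assuming a hypothetical local PNG: open the file in binary mode, base64-encode the bytes, and decode to a UTF-8 string that slots into the `data:` URL built in base_coder.py.

```python
import base64
from pathlib import Path

filename = "screenshot.png"  # hypothetical image path

with open(filename, "rb") as image_file:
    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")

# Mirrors how base_coder.py derives the MIME subtype from the file suffix.
image_url = f"data:image/{Path(filename).suffix.lstrip('.')};base64,{encoded_string}"
```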
aider/models.py

```diff
@@ -1,6 +1,8 @@
 import json
+import math

 import openai
+from PIL import Image


 class Model:
@@ -51,3 +53,46 @@ class Model:
         msgs = json.dumps(messages)

         return len(self.tokenizer.encode(msgs))
+
+    def token_count_for_image(self, fname):
+        """
+        Calculate the token cost for an image assuming high detail.
+        The token cost is determined by the size of the image.
+        :param fname: The filename of the image.
+        :return: The token cost for the image.
+        """
+        # Placeholder for image size retrieval logic
+        # TODO: Implement the logic to retrieve the image size from the file
+        width, height = self.get_image_size(fname)
+
+        # If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
+        max_dimension = max(width, height)
+        if max_dimension > 2048:
+            scale_factor = 2048 / max_dimension
+            width = int(width * scale_factor)
+            height = int(height * scale_factor)
+
+        # Scale the image such that the shortest side is 768 pixels long
+        min_dimension = min(width, height)
+        scale_factor = 768 / min_dimension
+        width = int(width * scale_factor)
+        height = int(height * scale_factor)
+
+        # Calculate the number of 512x512 tiles needed to cover the image
+        tiles_width = math.ceil(width / 512)
+        tiles_height = math.ceil(height / 512)
+        num_tiles = tiles_width * tiles_height
+
+        # Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
+        token_cost = num_tiles * 170 + 85
+        return token_cost
+
+    def get_image_size(self, fname):
+        """
+        Retrieve the size of an image.
+        :param fname: The filename of the image.
+        :return: A tuple (width, height) representing the image size in pixels.
+        """
+        with Image.open(fname) as img:
+            return img.size
```
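Worked by hand for a hypothetical 2048×4096 image (not an example from the commit), `token_count_for_image` proceeds as follows; the result matches OpenAI's documented high-detail example for that size.

```python
import math

width, height = 2048, 4096  # hypothetical image size

# 1) Larger than 2048 in one dimension: scale by 2048/4096 = 0.5 -> 1024x2048.
scale_factor = 2048 / max(width, height)
width, height = int(width * scale_factor), int(height * scale_factor)

# 2) Shortest side scaled to 768: scale by 768/1024 = 0.75 -> 768x1536.
scale_factor = 768 / min(width, height)
width, height = int(width * scale_factor), int(height * scale_factor)

# 3) 512x512 tiles: ceil(768/512) * ceil(1536/512) = 2 * 3 = 6 tiles.
num_tiles = math.ceil(width / 512) * math.ceil(height / 512)

# 4) 170 tokens per tile plus a fixed 85: 6 * 170 + 85 = 1105 tokens.
assert num_tiles * 170 + 85 == 1105
```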
aider/sendchat.py

```diff
@@ -53,6 +53,13 @@ def send_with_retries(model_name, messages, functions, stream):
     if "openrouter.ai" in openai.api_base:
         kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}

+    # Check conditions to switch to gpt-4-vision-preview
+    if "openrouter.ai" not in openai.api_base and model_name.startswith("gpt-4"):
+        if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
+            kwargs['model'] = "gpt-4-vision-preview"
+            # looks like gpt-4-vision is limited to max tokens of 4096
+            kwargs["max_tokens"] = 4096
+
     key = json.dumps(kwargs, sort_keys=True).encode()

     # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
```
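The switch hinges on spotting an `image_url` item anywhere in the outgoing messages. Pulled out as a standalone predicate (a sketch, not code from the commit):

```python
def has_image_content(messages):
    # True when any message carries list-style content containing an
    # "image_url" item, i.e. a vision message built by get_images_message().
    return any(
        isinstance(msg.get("content"), list)
        and any("image_url" in item for item in msg["content"] if isinstance(item, dict))
        for msg in messages
    )

messages = [
    {"role": "user", "content": "describe the screenshot"},
    {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "data:..."}}]},
]
assert has_image_content(messages)  # this chat would be re-routed to gpt-4-vision-preview
```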
aider/utils.py

```diff
@@ -16,7 +16,11 @@ def show_messages(messages, title=None, functions=None):
     for msg in messages:
         role = msg["role"].upper()
         content = msg.get("content")
-        if content:
+        if isinstance(content, list):  # Handle list content (e.g., image messages)
+            for item in content:
+                if isinstance(item, dict) and "image_url" in item:
+                    print(role, "Image URL:", item["image_url"]["url"])
+        elif isinstance(content, str):  # Handle string content
             for line in content.splitlines():
                 print(role, line)
         content = msg.get("function_call")
```
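Illustrative only: with list-style content, the new branch prints the image URL entry instead of calling `content.splitlines()`, which would raise `AttributeError` on a list.

```python
msg = {
    "role": "user",
    "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}],
}
role = msg["role"].upper()
for item in msg["content"]:
    if isinstance(item, dict) and "image_url" in item:
        print(role, "Image URL:", item["image_url"]["url"])
# -> USER Image URL: data:image/png;base64,iVBOR...
```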
requirements.txt

```diff
@@ -16,3 +16,4 @@ packaging
 sounddevice
 soundfile
 PyYAML
+Pillow
```