diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py
index 8aa4ecba1..fc01e9a33 100755
--- a/aider/coders/base_coder.py
+++ b/aider/coders/base_coder.py
@@ -38,6 +38,8 @@ class ExhaustedContextWindow(Exception):
 def wrap_fence(name):
     return f"<{name}>", f"</{name}>"
 
+# NOTE: currently duplicated in io.py and commands.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class Coder:
     abs_fnames = None
@@ -283,12 +285,13 @@ class Coder:
 
         prompt = ""
         for fname, content in self.get_abs_fnames_content():
-            relative_fname = self.get_rel_fname(fname)
-            prompt += "\n"
-            prompt += relative_fname
-            prompt += f"\n{self.fence[0]}\n"
-            prompt += content
-            prompt += f"{self.fence[1]}\n"
+            if not any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += relative_fname
+                prompt += f"\n{self.fence[0]}\n"
+                prompt += content
+                prompt += f"{self.fence[1]}\n"
 
         return prompt
 
@@ -321,8 +324,33 @@ class Coder:
             dict(role="assistant", content="Ok."),
         ]
 
+        images_message = self.get_images_message()
+        if images_message is not None:
+            files_messages.append(images_message)
+
         return files_messages
 
+    def get_images_message(self):
+        image_messages = []
+        for fname, content in self.get_abs_fnames_content():
+            if any(fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
+                image_messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url,
+                        "detail": "high"
+                    }
+                })
+
+        if not image_messages:
+            return None
+
+        return {
+            "role": "user",
+            "content": image_messages
+        }
+
     def run(self, with_message=None):
         while True:
             try:
diff --git a/aider/commands.py b/aider/commands.py
index c6a534a45..81045c246 100644
--- a/aider/commands.py
+++ b/aider/commands.py
@@ -10,6 +10,8 @@ from aider import prompts, voice
 
 from .dump import dump  # noqa: F401
 
+# NOTE: currently duplicated in io.py and base_coder.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class Commands:
     voice = None
@@ -138,9 +140,13 @@ class Commands:
         for fname in self.coder.abs_fnames:
             relative_fname = self.coder.get_rel_fname(fname)
             content = self.io.read_text(fname)
-            # approximate
-            content = f"{relative_fname}\n```\n" + content + "```\n"
-            tokens = self.coder.main_model.token_count(content)
+            if any(relative_fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
+                # If the file is an image, use the token_count_for_image method
+                tokens = self.coder.main_model.token_count_for_image(fname)
+            else:
+                # approximate
+                content = f"{relative_fname}\n```\n" + content + "```\n"
+                tokens = self.coder.main_model.token_count(content)
             res.append((tokens, f"{relative_fname}", "use /drop to drop from chat"))
 
         self.io.tool_output("Approximate context window usage, in tokens:")
@@ -167,7 +173,9 @@ class Commands:
         self.io.tool_output("=" * (width + cost_width + 1))
         self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
 
-        limit = self.coder.main_model.max_context_tokens
+        # Check if any images are in the chat and override the max context window size if so
+        image_in_chat = any(relative_fname.lower().endswith(ext) for ext in IMAGE_EXTENSIONS for relative_fname in self.coder.get_inchat_relative_files())
+        limit = 4096 if image_in_chat else self.coder.main_model.max_context_tokens
         remaining = limit - total
         if remaining > 1024:
             self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
diff --git a/aider/io.py b/aider/io.py
index 1da68e83b..255efbd79 100644
--- a/aider/io.py
+++ b/aider/io.py
@@ -1,6 +1,7 @@
 import os
 from collections import defaultdict
 from datetime import datetime
+import base64
 from pathlib import Path
 
 from prompt_toolkit.completion import Completer, Completion
@@ -17,6 +18,9 @@ from rich.text import Text
 
 from .dump import dump  # noqa: F401
 
+# QUESTION: what image extensions do we want to support?
+# QUESTION: where should this live? Currently duplicated in base_coder.py and commands.py
+IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
 
 class AutoCompleter(Completer):
     def __init__(self, root, rel_fnames, addable_rel_fnames, commands, encoding):
@@ -139,7 +143,27 @@ class InputOutput:
         current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
 
+
+    def read_image(self, filename):
+        try:
+            with open(str(filename), "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read())
+                return encoded_string.decode('utf-8')
+        except FileNotFoundError:
+            self.tool_error(f"{filename}: file not found error")
+            return
+        except IsADirectoryError:
+            self.tool_error(f"{filename}: is a directory")
+            return
+        except Exception as e:
+            self.tool_error(f"{filename}: {e}")
+            return
+
     def read_text(self, filename):
+        file_extension = Path(filename).suffix.lower()
+        if file_extension in IMAGE_EXTENSIONS:
+            return self.read_image(filename)
+
         try:
             with open(str(filename), "r", encoding=self.encoding) as f:
                 return f.read()
diff --git a/aider/models/model.py b/aider/models/model.py
index 9b1a3daab..3b394efb8 100644
--- a/aider/models/model.py
+++ b/aider/models/model.py
@@ -1,6 +1,8 @@
 import json
+import math
 
 import openai
+from PIL import Image
 
 
 class Model:
@@ -51,3 +53,46 @@ class Model:
 
         msgs = json.dumps(messages)
         return len(self.tokenizer.encode(msgs))
+
+    def token_count_for_image(self, fname):
+        """
+        Calculate the token cost for an image assuming high detail.
+        The token cost is determined by the size of the image.
+        :param fname: The filename of the image.
+        :return: The token cost for the image.
+        """
+        # Read the pixel dimensions from the image file; the steps below
+        # estimate the high-detail token cost from those dimensions alone.
+        width, height = self.get_image_size(fname)
+
+        # If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
+        max_dimension = max(width, height)
+        if max_dimension > 2048:
+            scale_factor = 2048 / max_dimension
+            width = int(width * scale_factor)
+            height = int(height * scale_factor)
+
+        # Scale the image such that the shortest side is 768 pixels long
+        min_dimension = min(width, height)
+        scale_factor = 768 / min_dimension
+        width = int(width * scale_factor)
+        height = int(height * scale_factor)
+
+        # Calculate the number of 512x512 tiles needed to cover the image
+        tiles_width = math.ceil(width / 512)
+        tiles_height = math.ceil(height / 512)
+        num_tiles = tiles_width * tiles_height
+
+        # Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
+        token_cost = num_tiles * 170 + 85
+        return token_cost
+
+
+    def get_image_size(self, fname):
+        """
+        Retrieve the size of an image.
+        :param fname: The filename of the image.
+        :return: A tuple (width, height) representing the image size in pixels.
+        """
+        with Image.open(fname) as img:
+            return img.size
diff --git a/aider/sendchat.py b/aider/sendchat.py
index 7c2994dcc..fb190f85c 100644
--- a/aider/sendchat.py
+++ b/aider/sendchat.py
@@ -53,6 +53,13 @@ def send_with_retries(model_name, messages, functions, stream):
     if "openrouter.ai" in openai.api_base:
         kwargs["headers"] = {"HTTP-Referer": "http://aider.chat", "X-Title": "Aider"}
 
+    # Check conditions to switch to gpt-4-vision-preview
+    if "openrouter.ai" not in openai.api_base and model_name.startswith("gpt-4"):
+        if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
+            kwargs['model'] = "gpt-4-vision-preview"
+            # looks like gpt-4-vision is limited to max tokens of 4096
+            kwargs["max_tokens"] = 4096
+
     key = json.dumps(kwargs, sort_keys=True).encode()
 
     # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
diff --git a/aider/utils.py b/aider/utils.py
index 5147314cc..98d6a27b3 100644
--- a/aider/utils.py
+++ b/aider/utils.py
@@ -16,7 +16,11 @@ def show_messages(messages, title=None, functions=None):
     for msg in messages:
         role = msg["role"].upper()
         content = msg.get("content")
-        if content:
+        if isinstance(content, list):  # Handle list content (e.g., image messages)
+            for item in content:
+                if isinstance(item, dict) and "image_url" in item:
+                    print(role, "Image URL:", item["image_url"]["url"])
+        elif isinstance(content, str):  # Handle string content
             for line in content.splitlines():
                 print(role, line)
         content = msg.get("function_call")
diff --git a/requirements.in b/requirements.in
index d7fdf182a..fe07f3dae 100644
--- a/requirements.in
+++ b/requirements.in
@@ -16,3 +16,4 @@ packaging
 sounddevice
 soundfile
 PyYAML
+Pillow