mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
Merge pull request #379 from joshuavial/gpt4-vision
add support for gpt-4-vision
This commit is contained in:
commit
2824acb5c2
13 changed files with 197 additions and 37 deletions
|
@ -89,6 +89,7 @@ You can find more chat transcripts on the [examples page](https://aider.chat/exa
|
|||
* You can use aider with multiple source files at once, so GPT can make coordinated code changes across all of them in a single changeset/commit.
|
||||
* Aider can [give *GPT-4* a map of your entire git repo](https://aider.chat/docs/repomap.html), which helps it understand and modify large codebases.
|
||||
* You can also edit files by hand using your editor while chatting with aider. Aider will notice these out-of-band edits and keep GPT up to date with the latest versions of your files. This lets you bounce back and forth between the aider chat and your editor, to collaboratively code with GPT.
|
||||
* If you are using GPT-4 through OpenAI directly, you can add image files to your chat context, which will automatically switch you to the `gpt-4-vision-preview` model.
|
||||
|
||||
|
||||
## Usage
|
||||
|
|
|
@ -24,6 +24,7 @@ from aider.repo import GitRepo
|
|||
from aider.repomap import RepoMap
|
||||
from aider.sendchat import send_with_retries
|
||||
|
||||
from aider.utils import is_image_file
|
||||
from ..dump import dump # noqa: F401
|
||||
|
||||
|
||||
|
@ -298,6 +299,7 @@ class Coder:
|
|||
|
||||
prompt = ""
|
||||
for fname, content in self.get_abs_fnames_content():
|
||||
if not is_image_file(fname):
|
||||
relative_fname = self.get_rel_fname(fname)
|
||||
prompt += "\n"
|
||||
prompt += relative_fname
|
||||
|
@ -343,8 +345,36 @@ class Coder:
|
|||
dict(role="assistant", content="Ok."),
|
||||
]
|
||||
|
||||
images_message = self.get_images_message()
|
||||
if images_message is not None:
|
||||
files_messages.append(images_message)
|
||||
|
||||
return files_messages
|
||||
|
||||
def get_images_message(self):
    """Build one user message carrying every in-chat image as a data URL.

    Returns None when the model/client pair does not support vision
    (checked via utils.is_gpt4_with_openai_base_url) or when no image
    files are currently in the chat.
    """
    if not utils.is_gpt4_with_openai_base_url(self.main_model.name, self.client):
        return None

    # Each in-chat image becomes one high-detail image_url content part.
    # `content` is assumed to already be the base64-encoded file body
    # (produced by InputOutput.read_image) — TODO confirm against read_text.
    parts = [
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}",
                "detail": "high",
            },
        }
        for fname, content in self.get_abs_fnames_content()
        if is_image_file(fname)
    ]

    if not parts:
        return None

    return {"role": "user", "content": parts}
|
||||
|
||||
def run(self, with_message=None):
|
||||
while True:
|
||||
try:
|
||||
|
@ -412,6 +442,7 @@ class Coder:
|
|||
self.done_messages += self.cur_messages
|
||||
self.summarize_start()
|
||||
|
||||
#TODO check for impact on image messages
|
||||
if message:
|
||||
self.done_messages += [
|
||||
dict(role="user", content=message),
|
||||
|
@ -458,6 +489,7 @@ class Coder:
|
|||
dict(role="system", content=self.fmt_system_prompt(self.gpt_prompts.system_reminder)),
|
||||
]
|
||||
|
||||
#TODO review impact of token count on image messages
|
||||
messages_tokens = self.main_model.token_count(messages)
|
||||
reminder_tokens = self.main_model.token_count(reminder_message)
|
||||
cur_tokens = self.main_model.token_count(self.cur_messages)
|
||||
|
@ -661,7 +693,7 @@ class Coder:
|
|||
raise Exception("No data found in openai response!")
|
||||
|
||||
tokens = None
|
||||
if hasattr(completion, "usage"):
|
||||
if hasattr(completion, "usage") and completion.usage is not None:
|
||||
prompt_tokens = completion.usage.prompt_tokens
|
||||
completion_tokens = completion.usage.completion_tokens
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from aider import prompts, voice
|
|||
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
from aider.utils import is_image_file, is_gpt4_with_openai_base_url
|
||||
|
||||
class Commands:
|
||||
voice = None
|
||||
|
@ -138,6 +139,9 @@ class Commands:
|
|||
for fname in self.coder.abs_fnames:
|
||||
relative_fname = self.coder.get_rel_fname(fname)
|
||||
content = self.io.read_text(fname)
|
||||
if is_image_file(relative_fname):
|
||||
tokens = self.coder.main_model.token_count_for_image(fname)
|
||||
else:
|
||||
# approximate
|
||||
content = f"{relative_fname}\n```\n" + content + "```\n"
|
||||
tokens = self.coder.main_model.token_count(content)
|
||||
|
@ -167,7 +171,12 @@ class Commands:
|
|||
self.io.tool_output("=" * (width + cost_width + 1))
|
||||
self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
|
||||
|
||||
limit = self.coder.main_model.max_context_tokens
|
||||
# only switch to image model token count if gpt4 and openai and image in files
|
||||
image_in_chat = False
|
||||
if is_gpt4_with_openai_base_url(self.coder.main_model.name, self.coder.client):
|
||||
image_in_chat = any(is_image_file(relative_fname) for relative_fname in self.coder.get_inchat_relative_files())
|
||||
limit = 128000 if image_in_chat else self.coder.main_model.max_context_tokens
|
||||
|
||||
remaining = limit - total
|
||||
if remaining > 1024:
|
||||
self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
|
||||
|
@ -324,6 +333,9 @@ class Commands:
|
|||
if abs_file_path in self.coder.abs_fnames:
|
||||
self.io.tool_error(f"{matched_file} is already in the chat")
|
||||
else:
|
||||
if is_image_file(matched_file) and not is_gpt4_with_openai_base_url(self.coder.main_model.name, self.coder.client):
|
||||
self.io.tool_error(f"Cannot add image file {matched_file} as the model does not support image files")
|
||||
continue
|
||||
content = self.io.read_text(abs_file_path)
|
||||
if content is None:
|
||||
self.io.tool_error(f"Unable to read {matched_file}")
|
||||
|
|
21
aider/io.py
21
aider/io.py
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
@ -15,6 +16,7 @@ from pygments.util import ClassNotFound
|
|||
from rich.console import Console
|
||||
from rich.text import Text
|
||||
|
||||
from .utils import is_image_file
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
|
||||
|
@ -139,7 +141,26 @@ class InputOutput:
|
|||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
|
||||
|
||||
|
||||
def read_image(self, filename):
    """Read an image file and return its contents base64-encoded as a str.

    :param filename: Path of the image file to read (str or path-like).
    :return: The file's bytes encoded as a base64 utf-8 string, or None
             when the file cannot be read (the error is reported via
             self.tool_error rather than raised, so chat flow continues).
    """
    try:
        with open(str(filename), "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read())
            return encoded_string.decode("utf-8")
    except FileNotFoundError:
        # Include the filename so the user knows which file failed;
        # the previous messages were f-strings with no placeholder.
        self.tool_error(f"{filename}: file not found error")
        return
    except IsADirectoryError:
        self.tool_error(f"{filename}: is a directory")
        return
    except Exception as e:
        # Deliberately broad: any read failure is reported, not raised.
        self.tool_error(f"{filename}: {e}")
        return
|
||||
|
||||
def read_text(self, filename):
|
||||
if is_image_file(filename):
|
||||
return self.read_image(filename)
|
||||
|
||||
try:
|
||||
with open(str(filename), "r", encoding=self.encoding) as f:
|
||||
return f.read()
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
class Model:
|
||||
name = None
|
||||
|
@ -49,3 +51,44 @@ class Model:
|
|||
msgs = json.dumps(messages)
|
||||
|
||||
return len(self.tokenizer.encode(msgs))
|
||||
|
||||
def token_count_for_image(self, fname):
    """
    Calculate the token cost for an image assuming high detail.

    The token cost is determined by the size of the image.

    :param fname: The filename of the image.
    :return: The token cost for the image.
    """
    width, height = self.get_image_size(fname)

    # Step 1: fit within a 2048x2048 bounding box, preserving aspect ratio.
    longest_side = max(width, height)
    if longest_side > 2048:
        shrink = 2048 / longest_side
        width, height = int(width * shrink), int(height * shrink)

    # Step 2: rescale so the shortest side is exactly 768 pixels.
    shortest_side = min(width, height)
    grow = 768 / shortest_side
    width, height = int(width * grow), int(height * grow)

    # Step 3: count the 512x512 tiles needed to cover the scaled image.
    num_tiles = math.ceil(width / 512) * math.ceil(height / 512)

    # 170 tokens per tile plus a fixed 85-token base cost.
    return num_tiles * 170 + 85
|
||||
|
||||
|
||||
def get_image_size(self, fname):
    """
    Retrieve the size of an image.

    :param fname: The filename of the image.
    :return: A tuple (width, height) giving the image's pixel dimensions.
    """
    # Image.open as a context manager closes the file handle promptly.
    with Image.open(fname) as img:
        return img.size
|
||||
|
|
|
@ -21,19 +21,18 @@ class OpenRouterModel(Model):
|
|||
if cached_model_details is None:
|
||||
cached_model_details = client.models.list().data
|
||||
found = next(
|
||||
(details for details in cached_model_details if details.get("id") == name), None
|
||||
(details for details in cached_model_details if details.id == name), None
|
||||
)
|
||||
|
||||
if found:
|
||||
self.max_context_tokens = int(found.get("context_length"))
|
||||
self.prompt_price = round(float(found.get("pricing").get("prompt")) * 1000, 6)
|
||||
self.completion_price = round(float(found.get("pricing").get("completion")) * 1000, 6)
|
||||
self.max_context_tokens = int(found.context_length)
|
||||
self.prompt_price = round(float(found.pricing.get("prompt")) * 1000, 6)
|
||||
self.completion_price = round(float(found.pricing.get("completion")) * 1000, 6)
|
||||
|
||||
else:
|
||||
raise ValueError(f"invalid openrouter model: {name}")
|
||||
|
||||
|
||||
# TODO run benchmarks and figure out which models support which edit-formats
|
||||
def edit_format_for_model(name):
|
||||
if any(str in name for str in ["gpt-4", "claude-2"]):
|
||||
return "diff"
|
||||
|
|
0
aider/retrievers/fleet.py
Normal file
0
aider/retrievers/fleet.py
Normal file
|
@ -8,6 +8,7 @@ import openai
|
|||
# from diskcache import Cache
|
||||
from openai import APIConnectionError, InternalServerError, RateLimitError
|
||||
|
||||
from aider.utils import is_gpt4_with_openai_base_url
|
||||
from aider.dump import dump # noqa: F401
|
||||
|
||||
CACHE_PATH = "~/.aider.send.cache.v1"
|
||||
|
@ -41,6 +42,14 @@ def send_with_retries(client, model_name, messages, functions, stream):
|
|||
if functions is not None:
|
||||
kwargs["functions"] = functions
|
||||
|
||||
|
||||
# Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
|
||||
if client and is_gpt4_with_openai_base_url(model_name, client):
|
||||
if any(isinstance(msg.get("content"), list) and any("image_url" in item for item in msg.get("content") if isinstance(item, dict)) for msg in messages):
|
||||
kwargs['model'] = "gpt-4-vision-preview"
|
||||
# gpt-4-vision is limited to max tokens of 4096
|
||||
kwargs["max_tokens"] = 4096
|
||||
|
||||
key = json.dumps(kwargs, sort_keys=True).encode()
|
||||
|
||||
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import git
|
||||
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
|
||||
from aider.dump import dump # noqa: F401
|
||||
|
||||
|
||||
|
@ -64,6 +65,16 @@ def make_repo(path=None):
|
|||
|
||||
return repo
|
||||
|
||||
def is_image_file(file_name):
    """
    Check if the given file name has an image file extension.

    The comparison is case-insensitive, so e.g. both ``photo.png`` and
    ``PHOTO.PNG`` are recognized.

    :param file_name: The name of the file to check (str or path-like).
    :return: True if the file is an image, False otherwise.
    """
    # Lowercase first: IMAGE_EXTENSIONS holds lowercase suffixes, and
    # the old case-sensitive check missed files like "scan.PNG".
    file_name = str(file_name).lower()
    return any(file_name.endswith(ext) for ext in IMAGE_EXTENSIONS)
|
||||
|
||||
|
||||
def safe_abs_path(res):
|
||||
"Gives an abs path, which safely returns a full (not 8.3) windows path"
|
||||
|
@ -78,7 +89,11 @@ def show_messages(messages, title=None, functions=None):
|
|||
for msg in messages:
|
||||
role = msg["role"].upper()
|
||||
content = msg.get("content")
|
||||
if content:
|
||||
if isinstance(content, list): # Handle list content (e.g., image messages)
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "image_url" in item:
|
||||
print(role, "Image URL:", item["image_url"]["url"])
|
||||
elif isinstance(content, str): # Handle string content
|
||||
for line in content.splitlines():
|
||||
print(role, line)
|
||||
content = msg.get("function_call")
|
||||
|
@ -87,3 +102,15 @@ def show_messages(messages, title=None, functions=None):
|
|||
|
||||
if functions:
|
||||
dump(functions)
|
||||
|
||||
def is_gpt4_with_openai_base_url(model_name, client):
    """
    Check if the model_name starts with 'gpt-4' and the client's base URL
    host includes 'api.openai.com'.

    :param model_name: The name of the model to check.
    :param client: The OpenAI client instance.
    :return: True if both conditions are met, False otherwise.
    """
    if client is None:
        return False
    if not hasattr(client, "base_url"):
        return False
    if not model_name.startswith("gpt-4"):
        return False
    return "api.openai.com" in client.base_url.host
|
||||
|
|
|
@ -787,7 +787,22 @@ def run_test(
|
|||
chat_history_file=history_fname,
|
||||
)
|
||||
|
||||
main_model = models.Model.create(model_name)
|
||||
if "OPENAI_API_BASE" in os.environ and "openrouter.ai" in os.environ["OPENAI_API_BASE"]:
|
||||
client = openai.OpenAI(
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ.get("OPENAI_API_BASE"),
|
||||
default_headers={
|
||||
"HTTP-Referer": "http://aider.chat",
|
||||
"X-Title": "Aider",
|
||||
}
|
||||
)
|
||||
else:
|
||||
client = openai.OpenAI(
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
)
|
||||
|
||||
main_model = models.Model.create(model_name, client)
|
||||
edit_format = edit_format or main_model.edit_format
|
||||
|
||||
dump(main_model)
|
||||
|
@ -795,8 +810,6 @@ def run_test(
|
|||
show_fnames = ",".join(map(str, fnames))
|
||||
print("fnames:", show_fnames)
|
||||
|
||||
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
|
||||
coder = Coder.create(
|
||||
main_model,
|
||||
edit_format,
|
||||
|
|
|
@ -19,4 +19,5 @@ packaging
|
|||
sounddevice
|
||||
soundfile
|
||||
PyYAML
|
||||
Pillow
|
||||
diff-match-patch
|
||||
|
|
|
@ -74,6 +74,8 @@ pathspec==0.11.2
|
|||
# via
|
||||
# -r requirements.in
|
||||
# grep-ast
|
||||
pillow==10.1.0
|
||||
# via -r requirements.in
|
||||
prompt-toolkit==3.0.41
|
||||
# via -r requirements.in
|
||||
pycparser==2.21
|
||||
|
|
|
@ -29,19 +29,19 @@ class TestModels(unittest.TestCase):
|
|||
|
||||
def test_openrouter_model_properties(self):
|
||||
client = MagicMock()
|
||||
client.models.list.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "openai/gpt-4",
|
||||
"object": "model",
|
||||
"context_length": "8192",
|
||||
"pricing": {"prompt": "0.00006", "completion": "0.00012"},
|
||||
}
|
||||
]
|
||||
}
|
||||
client.models.list.return_value = type(
|
||||
"", (), {"data": client.models.list.return_value["data"]}
|
||||
)()
|
||||
class ModelData:
|
||||
def __init__(self, id, object, context_length, pricing):
|
||||
self.id = id
|
||||
self.object = object
|
||||
self.context_length = context_length
|
||||
self.pricing = pricing
|
||||
|
||||
model_data = ModelData("openai/gpt-4", "model", "8192", {"prompt": "0.00006", "completion": "0.00012"})
|
||||
class ModelList:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
client.models.list.return_value = ModelList([model_data])
|
||||
|
||||
model = OpenRouterModel(client, "gpt-4")
|
||||
self.assertEqual(model.name, "openai/gpt-4")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue