Merge pull request #379 from joshuavial/gpt4-vision

add support for gpt-4-vision
paul-gauthier 2024-01-02 13:37:17 -04:00 committed by GitHub
commit 2824acb5c2
13 changed files with 197 additions and 37 deletions

View file

@@ -89,6 +89,7 @@ You can find more chat transcripts on the [examples page](https://aider.chat/exa
* You can use aider with multiple source files at once, so GPT can make coordinated code changes across all of them in a single changeset/commit.
* Aider can [give *GPT-4* a map of your entire git repo](https://aider.chat/docs/repomap.html), which helps it understand and modify large codebases.
* You can also edit files by hand using your editor while chatting with aider. Aider will notice these out-of-band edits and keep GPT up to date with the latest versions of your files. This lets you bounce back and forth between the aider chat and your editor, to collaboratively code with GPT.
* If you are using GPT-4 through the OpenAI API directly, you can add image files to your chat context and aider will automatically switch to the gpt-4-vision-preview model.
## Usage

View file

@@ -24,6 +24,7 @@ from aider.repo import GitRepo
from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
from aider.utils import is_image_file
from ..dump import dump # noqa: F401
@@ -298,6 +299,7 @@ class Coder:
        prompt = ""
        for fname, content in self.get_abs_fnames_content():
            if not is_image_file(fname):
                relative_fname = self.get_rel_fname(fname)
                prompt += "\n"
                prompt += relative_fname
@@ -343,8 +345,36 @@
            dict(role="assistant", content="Ok."),
        ]

        images_message = self.get_images_message()
        if images_message is not None:
            files_messages.append(images_message)

        return files_messages
    def get_images_message(self):
        if not utils.is_gpt4_with_openai_base_url(self.main_model.name, self.client):
            return None

        image_messages = []
        for fname, content in self.get_abs_fnames_content():
            if is_image_file(fname):
                image_url = f"data:image/{Path(fname).suffix.lstrip('.')};base64,{content}"
                image_messages.append({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url,
                        "detail": "high",
                    },
                })

        if not image_messages:
            return None

        return {
            "role": "user",
            "content": image_messages,
        }
    def run(self, with_message=None):
        while True:
            try:
@@ -412,6 +442,7 @@ class Coder:
        self.done_messages += self.cur_messages
        self.summarize_start()

        # TODO: check for impact on image messages
        if message:
            self.done_messages += [
                dict(role="user", content=message),
@@ -458,6 +489,7 @@
            dict(role="system", content=self.fmt_system_prompt(self.gpt_prompts.system_reminder)),
        ]

        # TODO: review impact of token count on image messages
        messages_tokens = self.main_model.token_count(messages)
        reminder_tokens = self.main_model.token_count(reminder_message)
        cur_tokens = self.main_model.token_count(self.cur_messages)
@@ -661,7 +693,7 @@
            raise Exception("No data found in openai response!")

        tokens = None
-       if hasattr(completion, "usage"):
+       if hasattr(completion, "usage") and completion.usage is not None:
            prompt_tokens = completion.usage.prompt_tokens
            completion_tokens = completion.usage.completion_tokens
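For reference, here is a sketch of the user message that `get_images_message` (above) builds when an image file is in the chat; the file name and base64 payload are hypothetical:

```python
# Hypothetical message produced for a chat containing mockup.png (base64 data abbreviated)
images_message = {
    "role": "user",
    "content": [
        {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEU...",
                "detail": "high",
            },
        }
    ],
}
```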

View file

@@ -10,6 +10,7 @@ from aider import prompts, voice
from .dump import dump  # noqa: F401
from aider.utils import is_image_file, is_gpt4_with_openai_base_url


class Commands:
    voice = None
@@ -138,6 +139,9 @@ class Commands:
        for fname in self.coder.abs_fnames:
            relative_fname = self.coder.get_rel_fname(fname)
            content = self.io.read_text(fname)
            if is_image_file(relative_fname):
                tokens = self.coder.main_model.token_count_for_image(fname)
            else:
                # approximate
                content = f"{relative_fname}\n```\n" + content + "```\n"
                tokens = self.coder.main_model.token_count(content)
@ -167,7 +171,12 @@ class Commands:
self.io.tool_output("=" * (width + cost_width + 1))
self.io.tool_output(f"${total_cost:5.2f} {fmt(total)} tokens total")
limit = self.coder.main_model.max_context_tokens
# only switch to image model token count if gpt4 and openai and image in files
image_in_chat = False
if is_gpt4_with_openai_base_url(self.coder.main_model.name, self.coder.client):
image_in_chat = any(is_image_file(relative_fname) for relative_fname in self.coder.get_inchat_relative_files())
limit = 128000 if image_in_chat else self.coder.main_model.max_context_tokens
remaining = limit - total
if remaining > 1024:
self.io.tool_output(f"{cost_pad}{fmt(remaining)} tokens remaining in context window")
@@ -324,6 +333,9 @@ class Commands:
            if abs_file_path in self.coder.abs_fnames:
                self.io.tool_error(f"{matched_file} is already in the chat")
            else:
                if is_image_file(matched_file) and not is_gpt4_with_openai_base_url(self.coder.main_model.name, self.coder.client):
                    self.io.tool_error(f"Cannot add image file {matched_file} as the model does not support image files")
                    continue

                content = self.io.read_text(abs_file_path)
                if content is None:
                    self.io.tool_error(f"Unable to read {matched_file}")

View file

@@ -1,6 +1,7 @@
import os
from collections import defaultdict
from datetime import datetime
import base64
from pathlib import Path
from prompt_toolkit.completion import Completer, Completion
@@ -15,6 +16,7 @@ from pygments.util import ClassNotFound
from rich.console import Console
from rich.text import Text
from .utils import is_image_file
from .dump import dump # noqa: F401
@@ -139,7 +141,26 @@
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")

    def read_image(self, filename):
        try:
            with open(str(filename), "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read())
            return encoded_string.decode("utf-8")
        except FileNotFoundError:
            self.tool_error(f"{filename}: file not found error")
            return
        except IsADirectoryError:
            self.tool_error(f"{filename}: is a directory")
            return
        except Exception as e:
            self.tool_error(f"{filename}: {e}")
            return

    def read_text(self, filename):
        if is_image_file(filename):
            return self.read_image(filename)

        try:
            with open(str(filename), "r", encoding=self.encoding) as f:
                return f.read()
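A minimal standalone sketch of the encoding `read_image` performs (the file name is hypothetical); `Coder.get_images_message` later wraps the resulting string in a data URL:

```python
import base64

with open("mockup.png", "rb") as image_file:  # hypothetical image file
    content = base64.b64encode(image_file.read()).decode("utf-8")

# get_images_message then builds: f"data:image/png;base64,{content}"
```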

View file

@@ -1,5 +1,7 @@
import json
import math

from PIL import Image


class Model:
    name = None

@@ -49,3 +51,44 @@ class Model:
        msgs = json.dumps(messages)
        return len(self.tokenizer.encode(msgs))
    def token_count_for_image(self, fname):
        """
        Calculate the token cost for an image assuming high detail.
        The token cost is determined by the size of the image.

        :param fname: The filename of the image.
        :return: The token cost for the image.
        """
        width, height = self.get_image_size(fname)

        # If the image is larger than 2048 in any dimension, scale it down to fit within 2048x2048
        max_dimension = max(width, height)
        if max_dimension > 2048:
            scale_factor = 2048 / max_dimension
            width = int(width * scale_factor)
            height = int(height * scale_factor)

        # Scale the image such that the shortest side is 768 pixels long
        min_dimension = min(width, height)
        scale_factor = 768 / min_dimension
        width = int(width * scale_factor)
        height = int(height * scale_factor)

        # Calculate the number of 512x512 tiles needed to cover the image
        tiles_width = math.ceil(width / 512)
        tiles_height = math.ceil(height / 512)
        num_tiles = tiles_width * tiles_height

        # Each tile costs 170 tokens, and there's an additional fixed cost of 85 tokens
        token_cost = num_tiles * 170 + 85
        return token_cost

    def get_image_size(self, fname):
        """
        Retrieve the size of an image.

        :param fname: The filename of the image.
        :return: A tuple (width, height) representing the image size in pixels.
        """
        with Image.open(fname) as img:
            return img.size
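To make the tile arithmetic concrete, here is a worked example for a hypothetical 4096x2048 image:

```python
import math

width, height = 4096, 2048                 # hypothetical image
scale = 2048 / max(width, height)          # 0.5 -> fit within 2048x2048
width, height = int(width * scale), int(height * scale)       # 2048 x 1024
scale = 768 / min(width, height)           # 0.75 -> shortest side becomes 768
width, height = int(width * scale), int(height * scale)       # 1536 x 768
num_tiles = math.ceil(width / 512) * math.ceil(height / 512)  # 3 * 2 = 6
print(num_tiles * 170 + 85)                # 1105 tokens
```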

View file

@@ -21,19 +21,18 @@ class OpenRouterModel(Model):
        if cached_model_details is None:
            cached_model_details = client.models.list().data
        found = next(
-           (details for details in cached_model_details if details.get("id") == name), None
+           (details for details in cached_model_details if details.id == name), None
        )

        if found:
-           self.max_context_tokens = int(found.get("context_length"))
-           self.prompt_price = round(float(found.get("pricing").get("prompt")) * 1000, 6)
-           self.completion_price = round(float(found.get("pricing").get("completion")) * 1000, 6)
+           self.max_context_tokens = int(found.context_length)
+           self.prompt_price = round(float(found.pricing.get("prompt")) * 1000, 6)
+           self.completion_price = round(float(found.pricing.get("completion")) * 1000, 6)
        else:
            raise ValueError(f"invalid openrouter model: {name}")


# TODO: run benchmarks and figure out which models support which edit-formats
def edit_format_for_model(name):
    if any(str in name for str in ["gpt-4", "claude-2"]):
        return "diff"

View file

View file

@@ -8,6 +8,7 @@ import openai

# from diskcache import Cache
from openai import APIConnectionError, InternalServerError, RateLimitError

from aider.utils import is_gpt4_with_openai_base_url
from aider.dump import dump  # noqa: F401

CACHE_PATH = "~/.aider.send.cache.v1"

@@ -41,6 +42,14 @@ def send_with_retries(client, model_name, messages, functions, stream):
    if functions is not None:
        kwargs["functions"] = functions

    # Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
    if client and is_gpt4_with_openai_base_url(model_name, client):
        if any(
            isinstance(msg.get("content"), list)
            and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
            for msg in messages
        ):
            kwargs["model"] = "gpt-4-vision-preview"
            # gpt-4-vision is limited to max tokens of 4096
            kwargs["max_tokens"] = 4096

    key = json.dumps(kwargs, sort_keys=True).encode()

    # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
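A sketch of the check that flips the request over to gpt-4-vision-preview, with hypothetical message content:

```python
messages = [
    {"role": "user", "content": "describe this mockup"},  # plain text: no switch
    {
        "role": "user",
        "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}],
    },
]
has_image = any(
    isinstance(msg.get("content"), list)
    and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
    for msg in messages
)
print(has_image)  # True -> model becomes gpt-4-vision-preview, max_tokens capped at 4096
```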

View file

@@ -1,9 +1,10 @@
import os
import tempfile
from pathlib import Path

import git

IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}

from aider.dump import dump  # noqa: F401
@@ -64,6 +65,16 @@ def make_repo(path=None):
    return repo


def is_image_file(file_name):
    """
    Check if the given file name has an image file extension.

    :param file_name: The name of the file to check.
    :return: True if the file is an image, False otherwise.
    """
    file_name = str(file_name)  # Convert file_name to string
    return any(file_name.endswith(ext) for ext in IMAGE_EXTENSIONS)
def safe_abs_path(res):
    "Gives an abs path, which safely returns a full (not 8.3) windows path"

@@ -78,7 +89,11 @@ def show_messages(messages, title=None, functions=None):
    for msg in messages:
        role = msg["role"].upper()
        content = msg.get("content")
-       if content:
+       if isinstance(content, list):  # Handle list content (e.g., image messages)
+           for item in content:
+               if isinstance(item, dict) and "image_url" in item:
+                   print(role, "Image URL:", item["image_url"]["url"])
+       elif isinstance(content, str):  # Handle string content
            for line in content.splitlines():
                print(role, line)

        content = msg.get("function_call")
@@ -87,3 +102,15 @@
    if functions:
        dump(functions)


def is_gpt4_with_openai_base_url(model_name, client):
    """
    Check if the model_name starts with 'gpt-4' and the client base URL includes 'api.openai.com'.

    :param model_name: The name of the model to check.
    :param client: The OpenAI client instance.
    :return: True if conditions are met, False otherwise.
    """
    if client is None or not hasattr(client, "base_url"):
        return False
    return model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host
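For instance, a minimal sketch assuming the openai v1 client, whose `base_url` is an httpx URL exposing a `.host` attribute (an assumption worth verifying against your installed version):

```python
import openai

client = openai.OpenAI(api_key="sk-...")  # default base_url is https://api.openai.com/v1
print(is_gpt4_with_openai_base_url("gpt-4-1106-preview", client))  # True
print(is_gpt4_with_openai_base_url("gpt-3.5-turbo", client))       # False
```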

View file

@@ -787,7 +787,22 @@
        chat_history_file=history_fname,
    )

-   main_model = models.Model.create(model_name)
+   if "OPENAI_API_BASE" in os.environ and "openrouter.ai" in os.environ["OPENAI_API_BASE"]:
+       client = openai.OpenAI(
+           api_key=os.environ["OPENAI_API_KEY"],
+           base_url=os.environ.get("OPENAI_API_BASE"),
+           default_headers={
+               "HTTP-Referer": "http://aider.chat",
+               "X-Title": "Aider",
+           },
+       )
+   else:
+       client = openai.OpenAI(
+           api_key=os.environ["OPENAI_API_KEY"],
+           base_url=os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+       )
+
+   main_model = models.Model.create(model_name, client)

    edit_format = edit_format or main_model.edit_format

    dump(main_model)

@@ -795,8 +810,6 @@
    show_fnames = ",".join(map(str, fnames))
    print("fnames:", show_fnames)

-   client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    coder = Coder.create(
        main_model,
        edit_format,

View file

@ -19,4 +19,5 @@ packaging
sounddevice
soundfile
PyYAML
Pillow
diff-match-patch

View file

@@ -74,6 +74,8 @@ pathspec==0.11.2
    # via
    #   -r requirements.in
    #   grep-ast
pillow==10.1.0
    # via -r requirements.in
prompt-toolkit==3.0.41
    # via -r requirements.in
pycparser==2.21

View file

@@ -29,19 +29,19 @@ class TestModels(unittest.TestCase):
    def test_openrouter_model_properties(self):
        client = MagicMock()
-       client.models.list.return_value = {
-           "data": [
-               {
-                   "id": "openai/gpt-4",
-                   "object": "model",
-                   "context_length": "8192",
-                   "pricing": {"prompt": "0.00006", "completion": "0.00012"},
-               }
-           ]
-       }
-       client.models.list.return_value = type(
-           "", (), {"data": client.models.list.return_value["data"]}
-       )()
+       class ModelData:
+           def __init__(self, id, object, context_length, pricing):
+               self.id = id
+               self.object = object
+               self.context_length = context_length
+               self.pricing = pricing
+
+       model_data = ModelData(
+           "openai/gpt-4", "model", "8192", {"prompt": "0.00006", "completion": "0.00012"}
+       )
+
+       class ModelList:
+           def __init__(self, data):
+               self.data = data
+
+       client.models.list.return_value = ModelList([model_data])

        model = OpenRouterModel(client, "gpt-4")
        self.assertEqual(model.name, "openai/gpt-4")