Don't assume gpt-3.5-turbo as the weak model, use main_model if unsure

This commit is contained in:
Paul Gauthier 2024-04-22 10:55:59 -07:00
parent 83c76a3c65
commit 30c095349a
2 changed files with 38 additions and 10 deletions

View file

@@ -1,6 +1,6 @@
 import argparse
-from aider import models, prompts
+from aider import prompts
 from aider.dump import dump  # noqa: F401
 from aider.sendchat import simple_send_with_retries
@@ -123,7 +123,7 @@ def main():
         assistant.append(line)

-    summarizer = ChatSummary(models.Model(models.DEFAULT_WEAK_MODEL_NAME, weak_model=False))
+    summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False)
     summary = summarizer.summarize(messages[-40:])
     dump(summary)

View file

@@ -3,6 +3,7 @@ import json
 import math
 import sys
 from dataclasses import dataclass, fields
+from typing import Optional

 import litellm
 from PIL import Image
@@ -10,7 +11,6 @@ from PIL import Image
 from aider.dump import dump  # noqa: F401

 DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
-DEFAULT_WEAK_MODEL_NAME = "gpt-3.5-turbo"


 class NoModelInfo(Exception):
@@ -35,7 +35,7 @@ class ModelEnvironmentError(Exception):
 class ModelSettings:
     name: str
     edit_format: str
-    weak_model_name: str = DEFAULT_WEAK_MODEL_NAME
+    weak_model_name: Optional[str] = None
     use_repo_map: bool = False
     send_undo_reply: bool = False
     accepts_images: bool = False
@@ -50,23 +50,28 @@ MODEL_SETTINGS = [
     ModelSettings(
         "gpt-3.5-turbo-0125",
         "whole",
+        weak_model_name="gpt-3.5-turbo",
     ),
     ModelSettings(
         "gpt-3.5-turbo-1106",
         "whole",
+        weak_model_name="gpt-3.5-turbo",
     ),
     ModelSettings(
         "gpt-3.5-turbo-0613",
         "whole",
+        weak_model_name="gpt-3.5-turbo",
     ),
     ModelSettings(
         "gpt-3.5-turbo-16k-0613",
         "whole",
+        weak_model_name="gpt-3.5-turbo",
     ),
     # gpt-4
     ModelSettings(
         "gpt-4-turbo-2024-04-09",
         "udiff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
         accepts_images=True,
@@ -74,6 +79,7 @@ MODEL_SETTINGS = [
     ModelSettings(
         "gpt-4-turbo",
         "udiff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
         accepts_images=True,
@@ -81,18 +87,21 @@ MODEL_SETTINGS = [
     ModelSettings(
         "gpt-4-0125-preview",
         "udiff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
     ),
     ModelSettings(
         "gpt-4-1106-preview",
         "udiff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
     ),
     ModelSettings(
         "gpt-4-vision-preview",
         "diff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
         accepts_images=True,
@@ -100,12 +109,14 @@ MODEL_SETTINGS = [
     ModelSettings(
         "gpt-4-0613",
         "diff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
     ),
     ModelSettings(
         "gpt-4-32k-0613",
         "diff",
+        weak_model_name="gpt-3.5-turbo",
         use_repo_map=True,
         send_undo_reply=True,
     ),
@@ -117,6 +128,11 @@ MODEL_SETTINGS = [
         use_repo_map=True,
         send_undo_reply=True,
     ),
+    ModelSettings(
+        "claude-3-sonnet-20240229",
+        "whole",
+        weak_model_name="claude-3-haiku-20240307",
+    ),
     # Cohere
     ModelSettings(
         "command-r-plus",
@@ -143,7 +159,7 @@ class Model:
     use_repo_map = False
     send_undo_reply = False
     accepts_images = False
-    weak_model_name = DEFAULT_WEAK_MODEL_NAME
+    weak_model_name = None
     max_chat_history_tokens = 1024

     weak_model = None
@@ -188,20 +204,28 @@ class Model:
     def configure_model_settings(self, model):
         for ms in MODEL_SETTINGS:
             # direct match, or match "provider/<model>"
-            if model == ms.name or model.endswith("/" + ms.name):
+            if model == ms.name:
                 for field in fields(ModelSettings):
+                    if field.name == "name":
+                        continue
                     val = getattr(ms, field.name)
                     setattr(self, field.name, val)
                 return  # <--

-        if "gpt-4" in model or "claude-2" in model:
+        if "llama3" in model and "70b" in model:
             self.edit_format = "diff"
             self.use_repo_map = True
             self.send_undo_reply = True
+            return  # <--
+
+        if "gpt-4-turbo" in model or ("gpt-4-" in model and "-preview" in model):
+            self.edit_format = "udiff"
+            self.use_repo_map = True
+            self.send_undo_reply = True
+            return  # <--
+
+        if "gpt-4" in model or "claude-2" in model or "claude-3-opus" in model:
+            self.edit_format = "diff"
+            self.use_repo_map = True
+            self.send_undo_reply = True
             return  # <--

         # use the defaults
@@ -216,6 +240,10 @@ class Model:
         if provided_weak_model_name:
             self.weak_model_name = provided_weak_model_name

+        if not self.weak_model_name:
+            self.weak_model = self
+            return
+
         if self.weak_model_name == self.name:
             self.weak_model = self
             return