Added --weak-model

Paul Gauthier 2024-04-19 10:57:21 -07:00
parent 93bd187bf3
commit 7ec4de865d
4 changed files with 44 additions and 20 deletions
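
In short: this commit adds a --weak-model switch so the cheaper model used for commit messages and chat history summarization can be chosen independently of --model (e.g. "aider --model gpt-4-1106-preview --weak-model gpt-3.5-turbo"). The resolution logic moves into models.Model, which now accepts the override in its constructor.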

View file

@@ -136,7 +136,11 @@ class Coder:
         self.main_model = main_model

-        self.io.tool_output(f"Model: {main_model.name} using {self.edit_format} edit format")
+        weak_model = main_model.weak_model
+        self.io.tool_output(
+            f"Models: {main_model.name} with {self.edit_format} edit format, weak model"
+            f" {weak_model.name}"
+        )

         self.show_diffs = show_diffs

@@ -212,7 +216,7 @@ class Coder:
             self.io.tool_output(f"Added {fname} to the chat.")

         self.summarizer = ChatSummary(
-            self.main_model.get_weak_model(),
+            self.main_model.weak_model,
            self.main_model.max_chat_history_tokens,
         )
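
With resolution done once in Model.__init__, the Coder can read main_model.weak_model as a plain attribute instead of calling get_weak_model(). A minimal runnable sketch of that wiring, using stand-in classes for aider.models.Model and aider.history.ChatSummary (only the attributes used in the hunks above are modeled):

    class Model:
        def __init__(self, name, weak_model_name="gpt-3.5-turbo"):
            self.name = name
            self.max_chat_history_tokens = 1024
            # Resolved eagerly, so callers read a plain attribute.
            self.weak_model = self if weak_model_name == name else Model(weak_model_name)

    class ChatSummary:
        def __init__(self, model, max_tokens=1024):
            self.model, self.max_tokens = model, max_tokens

    main_model = Model("gpt-4-1106-preview")
    summarizer = ChatSummary(main_model.weak_model, main_model.max_chat_history_tokens)
    print(f"Models: {main_model.name}, weak model {summarizer.model.name}")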

View file

@@ -124,7 +124,7 @@ def main():
             assistant.append(line)

-    summarizer = ChatSummary(models.Model.get_weak_model())
+    summarizer = ChatSummary(models.Model(models.DEFAULT_WEAK_MODEL_NAME))
     summary = summarizer.summarize(messages[-40:])
     dump(summary)
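
Since weak-model resolution now lives on Model instances, this script builds a Model for the default weak model name directly. A hedged sketch of the equivalent driver, assuming aider is importable and messages holds a list of chat dicts:

    from aider import models
    from aider.history import ChatSummary

    messages = [{"role": "user", "content": "hello"}] * 50  # placeholder transcript

    summarizer = ChatSummary(models.Model(models.DEFAULT_WEAK_MODEL_NAME))
    summary = summarizer.summarize(messages[-40:])  # summarize the last 40 messages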

View file

@@ -196,7 +196,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         const=default_4_turbo_model,
         help=f"Use {default_4_turbo_model} model for the main chat",
     )
-    default_3_model_name = "gpt-3.5-turbo-0125"
+    default_3_model_name = "gpt-3.5-turbo"
     core_group.add_argument(
         "--35turbo",
         "--35-turbo",

@@ -252,6 +252,15 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         default=None,
         help="Specify what edit format GPT should use (default depends on model)",
     )
+    core_group.add_argument(
+        "--weak-model",
+        metavar="WEAK_MODEL",
+        default=None,
+        help=(
+            "Specify the model to use for commit messages and chat history summarization (default"
+            " depends on --model)"
+        ),
+    )
     model_group.add_argument(
         "--map-tokens",
         type=int,
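
The new option is plain argparse: a string flag defaulting to None, which main() later threads into models.Model. A self-contained sketch of the flag in isolation (names taken from the hunk above; the real parser lives alongside the other core_group options):

    import argparse

    parser = argparse.ArgumentParser(prog="aider")
    parser.add_argument("--model", default="gpt-4-1106-preview")
    parser.add_argument(
        "--weak-model",
        metavar="WEAK_MODEL",
        default=None,
        help=(
            "Specify the model to use for commit messages and chat history"
            " summarization (default depends on --model)"
        ),
    )

    args = parser.parse_args(["--weak-model", "gpt-3.5-turbo"])
    assert args.weak_model == "gpt-3.5-turbo"  # argparse maps --weak-model to args.weak_model
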
@@ -583,10 +592,9 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     # Check in advance that we have model metadata
     try:
-        main_model = models.Model(args.model)
+        main_model = models.Model(args.model, weak_model=args.weak_model)
     except models.NoModelInfo as err:
-        io.tool_error(f"Unknown model {args.model}.")
-        io.tool_error(str(err))
+        io.tool_error(f"Unknown model {err}.")
         return 1

     try:
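
Because NoModelInfo now carries the model name as its message (see the models.py hunks below), str(err) is the name itself and a single tool_error line suffices. A runnable sketch of the simplified error path:

    class NoModelInfo(Exception):
        def __init__(self, model):
            super().__init__(model)

    try:
        raise NoModelInfo("no-such-model")
    except NoModelInfo as err:
        print(f"Unknown model {err}.")  # -> Unknown model no-such-model.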

View file

@ -1,7 +1,6 @@
import json import json
import math import math
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Optional
import litellm import litellm
from PIL import Image from PIL import Image
@@ -9,6 +8,7 @@ from PIL import Image
 from aider.dump import dump  # noqa: F401

 DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
+DEFAULT_WEAK_MODEL_NAME = "gpt-3.5-turbo"


 class NoModelInfo(Exception):

@@ -16,15 +16,15 @@ class NoModelInfo(Exception):
     Exception raised when model information cannot be retrieved.
     """

-    def __init__(self, message: Optional[str] = None):
-        super().__init__(message or "No model information available.")
+    def __init__(self, model):
+        super().__init__(model)


 @dataclass
 class ModelSettings:
     name: str
     edit_format: str
-    weak_model_name: str = "gpt-3.5-turbo-0125"
+    weak_model_name: str = DEFAULT_WEAK_MODEL_NAME
     use_repo_map: bool = False
     send_undo_reply: bool = False
     accepts_images: bool = False
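
With the dataclass default pointing at DEFAULT_WEAK_MODEL_NAME, entries in MODEL_SETTINGS only need to name a weak model when it differs from the shared default. A minimal runnable sketch:

    from dataclasses import dataclass

    DEFAULT_WEAK_MODEL_NAME = "gpt-3.5-turbo"

    @dataclass
    class ModelSettings:
        name: str
        edit_format: str
        weak_model_name: str = DEFAULT_WEAK_MODEL_NAME

    ms = ModelSettings("gpt-4-1106-preview", "udiff")
    assert ms.weak_model_name == "gpt-3.5-turbo"  # inherits the shared default
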
@@ -112,23 +112,22 @@ MODEL_SETTINGS = [
 class Model:
     name = None
-    weak_model_name = "gpt-3.5-turbo-0125"
     edit_format = "whole"
     use_repo_map = False
     send_undo_reply = False
     accepts_images = False
+    weak_model_name = DEFAULT_WEAK_MODEL_NAME
     max_chat_history_tokens = 1024

     weak_model = None

-    def __init__(self, model):
+    def __init__(self, model, weak_model=None):
         self.name = model

         try:
             self.info = litellm.get_model_info(model)
-        except Exception as err:
-            raise NoModelInfo(str(err))
+        except Exception:
+            raise NoModelInfo(model)

         if self.info.get("max_input_tokens", 0) < 32 * 1024:
             self.max_chat_history_tokens = 1024

@@ -136,6 +135,7 @@ class Model:
             self.max_chat_history_tokens = 2 * 1024

         self.configure_model_settings(model)
+        self.get_weak_model(weak_model)

     def configure_model_settings(self, model):
         for ms in MODEL_SETTINGS:
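
The constructor now resolves the weak model eagerly, so construction alone fixes both models. A hypothetical use, assuming aider.models is importable and litellm has metadata for both names:

    from aider.models import Model

    main_model = Model("gpt-4-1106-preview")                          # default weak model
    custom = Model("gpt-4-1106-preview", weak_model="gpt-3.5-turbo")  # explicit override
    print(main_model.weak_model.name, custom.weak_model.name)
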
@@ -159,13 +159,25 @@ class Model:
     def __str__(self):
         return self.name

-    def get_weak_model(self):
-        if not self.weak_model:
-            self.weak_model = Model(self.weak_model_name)
+    def get_weak_model(self, provided_weak_model_name):
+        # If weak_model_name is provided, override the model settings
+        if provided_weak_model_name:
+            self.weak_model_name = provided_weak_model_name
+
+        if self.weak_model_name == self.name:
+            self.weak_model = self
+            return
+
+        try:
+            self.info = litellm.get_model_info(self.weak_model_name)
+        except Exception:
+            raise NoModelInfo(self.weak_model_name)
+
+        self.weak_model = Model(self.weak_model_name)
         return self.weak_model

     def commit_message_models(self):
-        return [self.get_weak_model()]
+        return [self.weak_model]

     def tokenizer(self, text):
         return litellm.encode(model=self.name, text=text)
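
get_weak_model() now resolves in three steps: an explicit --weak-model override rewrites weak_model_name, a name equal to the main model short-circuits to self (avoiding a recursive Model construction), and anything else is validated against litellm metadata before a second Model is built. A stand-in class demonstrating the short-circuit:

    class Model:
        def __init__(self, name, weak_model=None):
            self.name = name
            self.weak_model_name = "gpt-3.5-turbo"  # class default in the real code
            self.weak_model = None
            self.get_weak_model(weak_model)

        def get_weak_model(self, provided_weak_model_name):
            if provided_weak_model_name:
                self.weak_model_name = provided_weak_model_name
            if self.weak_model_name == self.name:
                self.weak_model = self  # reuse the main model; no recursion
                return
            self.weak_model = Model(self.weak_model_name)

        def commit_message_models(self):
            return [self.weak_model]

    m = Model("gpt-4-1106-preview", weak_model="gpt-4-1106-preview")
    assert m.commit_message_models() == [m]  # weak model is the main model itself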