Mirror of https://github.com/Aider-AI/aider.git
Added --weak-model
commit 7ec4de865d (parent 93bd187bf3)
4 changed files with 44 additions and 20 deletions
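
In short: Model gains an optional weak-model override, main() exposes it as the new --weak-model switch, and Coder and ChatSummary consume the resolved main_model.weak_model directly. An illustrative invocation (the aider entry point and the existing --model flag are assumptions from context; only --weak-model is added by this commit):

    aider --model gpt-4-1106-preview --weak-model gpt-3.5-turbo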
@@ -136,7 +136,11 @@ class Coder:
         self.main_model = main_model

-        self.io.tool_output(f"Model: {main_model.name} using {self.edit_format} edit format")
+        weak_model = main_model.weak_model
+        self.io.tool_output(
+            f"Models: {main_model.name} with {self.edit_format} edit format, weak model"
+            f" {weak_model.name}"
+        )

         self.show_diffs = show_diffs

@@ -212,7 +216,7 @@ class Coder:
                 self.io.tool_output(f"Added {fname} to the chat.")

         self.summarizer = ChatSummary(
-            self.main_model.get_weak_model(),
+            self.main_model.weak_model,
             self.main_model.max_chat_history_tokens,
         )

@@ -124,7 +124,7 @@ def main():
             assistant.append(line)

-    summarizer = ChatSummary(models.Model.get_weak_model())
+    summarizer = ChatSummary(models.Model(models.DEFAULT_WEAK_MODEL_NAME))
     summary = summarizer.summarize(messages[-40:])
     dump(summary)

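A minimal sketch of the summarizer wiring after this change (import paths and message content are assumptions for illustration; the constructor arguments mirror the Coder hunk above):

    from aider import models
    from aider.history import ChatSummary

    # The weak model is now resolved when the main model is constructed.
    main_model = models.Model("gpt-4-1106-preview")

    summarizer = ChatSummary(
        main_model.weak_model,
        main_model.max_chat_history_tokens,
    )
    messages = [{"role": "user", "content": "long chat history..."}]  # illustrative
    summary = summarizer.summarize(messages[-40:])
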
@@ -196,7 +196,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         const=default_4_turbo_model,
         help=f"Use {default_4_turbo_model} model for the main chat",
     )
-    default_3_model_name = "gpt-3.5-turbo-0125"
+    default_3_model_name = "gpt-3.5-turbo"
     core_group.add_argument(
         "--35turbo",
         "--35-turbo",

@@ -252,6 +252,15 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         default=None,
         help="Specify what edit format GPT should use (default depends on model)",
     )
+    core_group.add_argument(
+        "--weak-model",
+        metavar="WEAK_MODEL",
+        default=None,
+        help=(
+            "Specify the model to use for commit messages and chat history summarization (default"
+            " depends on --model)"
+        ),
+    )
     model_group.add_argument(
         "--map-tokens",
         type=int,

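Because the new argument uses default=None, args.weak_model stays None unless the user passes the flag, which is what lets Model fall back to its per-model settings later in this diff. A self-contained sketch of that argparse pattern (names are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--weak-model", metavar="WEAK_MODEL", default=None)

    # Hyphens in the flag become underscores in the dest attribute.
    args = parser.parse_args(["--weak-model", "gpt-3.5-turbo"])
    assert args.weak_model == "gpt-3.5-turbo"

    # Omitted: stays None, so downstream code can apply model defaults.
    args = parser.parse_args([])
    assert args.weak_model is None
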
@@ -583,10 +592,9 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     # Check in advance that we have model metadata
     try:
-        main_model = models.Model(args.model)
+        main_model = models.Model(args.model, weak_model=args.weak_model)
     except models.NoModelInfo as err:
-        io.tool_error(f"Unknown model {args.model}.")
-        io.tool_error(str(err))
+        io.tool_error(f"Unknown model {err}.")
         return 1

     try:

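A detail worth noting: NoModelInfo is now constructed with the offending model name (see the NoModelInfo hunks below), so str(err) already names the model and the second tool_error call becomes redundant.
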
@@ -1,7 +1,6 @@
 import json
 import math
 from dataclasses import dataclass, fields
-from typing import Optional

 import litellm
 from PIL import Image

@@ -9,6 +8,7 @@ from PIL import Image
 from aider.dump import dump  # noqa: F401

 DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
+DEFAULT_WEAK_MODEL_NAME = "gpt-3.5-turbo"


 class NoModelInfo(Exception):

@@ -16,15 +16,15 @@ class NoModelInfo(Exception):
     Exception raised when model information cannot be retrieved.
     """

-    def __init__(self, message: Optional[str] = None):
-        super().__init__(message or "No model information available.")
+    def __init__(self, model):
+        super().__init__(model)


 @dataclass
 class ModelSettings:
     name: str
     edit_format: str
-    weak_model_name: str = "gpt-3.5-turbo-0125"
+    weak_model_name: str = DEFAULT_WEAK_MODEL_NAME
     use_repo_map: bool = False
     send_undo_reply: bool = False
     accepts_images: bool = False

@@ -112,23 +112,22 @@ MODEL_SETTINGS = [
 class Model:
     name = None

-    weak_model_name = "gpt-3.5-turbo-0125"
-
     edit_format = "whole"
     use_repo_map = False
     send_undo_reply = False
     accepts_images = False
+    weak_model_name = DEFAULT_WEAK_MODEL_NAME

     max_chat_history_tokens = 1024
     weak_model = None

-    def __init__(self, model):
+    def __init__(self, model, weak_model=None):
         self.name = model

         try:
             self.info = litellm.get_model_info(model)
-        except Exception as err:
-            raise NoModelInfo(str(err))
+        except Exception:
+            raise NoModelInfo(model)

         if self.info.get("max_input_tokens", 0) < 32 * 1024:
             self.max_chat_history_tokens = 1024

@@ -136,6 +135,7 @@ class Model:
             self.max_chat_history_tokens = 2 * 1024

         self.configure_model_settings(model)
+        self.get_weak_model(weak_model)

     def configure_model_settings(self, model):
         for ms in MODEL_SETTINGS:

@@ -159,13 +159,25 @@ class Model:
     def __str__(self):
         return self.name

-    def get_weak_model(self):
-        if not self.weak_model:
-            self.weak_model = Model(self.weak_model_name)
+    def get_weak_model(self, provided_weak_model_name):
+        # If weak_model_name is provided, override the model settings
+        if provided_weak_model_name:
+            self.weak_model_name = provided_weak_model_name
+
+        if self.weak_model_name == self.name:
+            self.weak_model = self
+            return
+
+        try:
+            self.info = litellm.get_model_info(self.weak_model_name)
+        except Exception:
+            raise NoModelInfo(self.weak_model_name)
+
+        self.weak_model = Model(self.weak_model_name)
         return self.weak_model

     def commit_message_models(self):
-        return [self.get_weak_model()]
+        return [self.weak_model]

     def tokenizer(self, text):
         return litellm.encode(model=self.name, text=text)

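Taken together, weak-model resolution now behaves roughly like this sketch (not part of the diff; it assumes the aider package layout and that litellm can fetch metadata for these model names):

    from aider import models

    # Explicit --weak-model override wins over MODEL_SETTINGS.
    main = models.Model("gpt-4-1106-preview", weak_model="gpt-3.5-turbo-0125")
    print(main.weak_model.name)  # gpt-3.5-turbo-0125

    # No override: weak_model_name falls back to DEFAULT_WEAK_MODEL_NAME.
    main = models.Model("gpt-4-1106-preview")
    print(main.weak_model.name)  # gpt-3.5-turbo

    # When the weak model *is* the main model, Model reuses itself rather
    # than constructing (and re-validating) a second instance.
    solo = models.Model("gpt-3.5-turbo", weak_model="gpt-3.5-turbo")
    assert solo.weak_model is solo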