Mirror of https://github.com/Aider-AI/aider.git, synced 2025-05-25 06:44:59 +00:00

roughed in tokenizer, dropped openai, openrouter

commit c9bb22d6d5 (parent 855e787175)
6 changed files with 27 additions and 46 deletions
aider/coders/base_coder.py
@@ -68,7 +68,7 @@ class Coder:
         from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
 
         if not main_model:
-            main_model = models.Model.create(models.DEFAULT_MODEL_NAME)
+            main_model = models.Model(models.DEFAULT_MODEL_NAME)
 
         if edit_format is None:
             edit_format = main_model.edit_format
@@ -214,7 +214,7 @@ class Coder:
 
         self.summarizer = ChatSummary(
             self.client,
-            models.Model.weak_model(),
+            self.main_model.weak_model(),
             self.main_model.max_chat_history_tokens,
         )
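The factory indirection is gone here: `Coder.create` now instantiates `models.Model(...)` directly, and the summarizer's weak model is derived from the coder's own main model rather than a module-level default. A minimal sketch of the new wiring, grounded in the lines above:

```python
from aider import models

# Direct construction replaces the removed Model.create() factory.
main_model = models.Model(models.DEFAULT_MODEL_NAME)

# weak_model() is now an instance method; per the model.py diff below it
# returns Model("gpt-3.5-turbo-0125") unless the main model already is one.
summarizer_model = main_model.weak_model()
```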
aider/history.py
@@ -7,7 +7,7 @@ from aider.sendchat import simple_send_with_retries
 
 
 class ChatSummary:
-    def __init__(self, client, model=models.Model.weak_model(), max_tokens=1024):
+    def __init__(self, client, model=None, max_tokens=1024):
         self.client = client
         self.tokenizer = model.tokenizer
         self.max_tokens = max_tokens
@@ -21,7 +21,7 @@ class ChatSummary:
     def tokenize(self, messages):
         sized = []
         for msg in messages:
-            tokens = len(self.tokenizer.encode(json.dumps(msg)))
+            tokens = len(self.tokenizer(json.dumps(msg)))
             sized.append((tokens, msg))
         return sized
@@ -61,7 +61,7 @@ class ChatSummary:
         summary = self.summarize_all(head)
 
         tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
-        summary_tokens = len(self.tokenizer.encode(json.dumps(summary)))
+        summary_tokens = len(self.tokenizer(json.dumps(summary)))
 
         result = summary + tail
         if summary_tokens + tail_tokens < self.max_tokens:
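Two things change in `ChatSummary`: the `model` default becomes `None` (the old default eagerly called `models.Model.weak_model()`, a classmethod this commit deletes), and token counting calls `self.tokenizer(...)` instead of `self.tokenizer.encode(...)`, matching the tokenizer becoming a callable method on `Model`. Since `__init__` still dereferences `model.tokenizer`, callers are now expected to pass a model explicitly, as `base_coder.py` does above. A sketch, where `client` is assumed to be whatever OpenAI-compatible client the caller already holds:

```python
summarizer = ChatSummary(
    client,                           # assumed: the caller's existing API client
    main_model.weak_model(),          # must be supplied; the None default would
                                      # fail at model.tokenizer in __init__
    main_model.max_chat_history_tokens,
)
```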
aider/main.py
@@ -188,7 +188,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         const=default_4_turbo_model,
         help=f"Use {default_4_turbo_model} model for the main chat",
     )
-    default_3_model = models.GPT35_0125
+    default_3_model_name = "gpt-3.5-turbo-0125"
     core_group.add_argument(
         "--35turbo",
         "--35-turbo",
@@ -196,8 +196,8 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         "-3",
         action="store_const",
         dest="model",
-        const=default_3_model.name,
-        help=f"Use {default_3_model.name} model for the main chat",
+        const=default_3_model_name,
+        help=f"Use {default_3_model_name} model for the main chat",
     )
     core_group.add_argument(
         "--voice-language",
@@ -580,7 +580,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         io.tool_error(f"Unknown model {args.model}.")
         return 1
 
-    main_model = models.Model.create(args.model, None)
+    main_model = models.Model(args.model)
 
     try:
         coder = Coder.create(
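The `-3` flag now stores a plain model-name string instead of reading `.name` off a prebuilt `Model` constant. A self-contained sketch of the `store_const` pattern used here, trimmed to the essentials:

```python
import argparse

parser = argparse.ArgumentParser()
core_group = parser.add_argument_group("Main")
core_group.add_argument(
    "--35turbo",
    "--35-turbo",
    "-3",
    action="store_const",
    dest="model",
    const="gpt-3.5-turbo-0125",
    help="Use gpt-3.5-turbo-0125 model for the main chat",
)

# Passing -3 stores the constant string into args.model.
args = parser.parse_args(["-3"])
assert args.model == "gpt-3.5-turbo-0125"
```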
aider/models/__init__.py
@@ -1,17 +1,5 @@
 from .model import Model
-from .openai import OpenAIModel
-from .openrouter import OpenRouterModel
 
-GPT4 = Model.create("gpt-4")
-GPT35 = Model.create("gpt-3.5-turbo")
-GPT35_0125 = Model.create("gpt-3.5-turbo-0125")
-
 DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
 
-__all__ = [
-    OpenAIModel,
-    OpenRouterModel,
-    GPT4,
-    GPT35,
-    GPT35_0125,
-]
+__all__ = [Model, DEFAULT_MODEL_NAME]
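With `OpenAIModel`, `OpenRouterModel`, and the prebuilt `GPT4`/`GPT35`/`GPT35_0125` constants deleted, the package exports only the `Model` class and the default model name. Call sites construct models on demand:

```python
from aider import models

# Before: m = models.GPT35_0125  (a prebuilt Model instance)
# After:  construct by name when needed.
m = models.Model("gpt-3.5-turbo-0125")
default = models.Model(models.DEFAULT_MODEL_NAME)
```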
aider/models/model.py
@@ -1,14 +1,14 @@
 import json
 import math
 
+import litellm
 from PIL import Image
 
 
 class Model:
     name = None
-    edit_format = None
+    edit_format = "whole"
     max_context_tokens = 0
-    tokenizer = None
     max_chat_history_tokens = 1024
 
     always_available = False
@@ -18,29 +18,24 @@ class Model:
     prompt_price = None
     completion_price = None
 
-    @classmethod
-    def create(cls, name, client=None):
-        from .openai import OpenAIModel
-        from .openrouter import OpenRouterModel
-
-        if client and client.base_url.host == "openrouter.ai":
-            return OpenRouterModel(client, name)
-        return OpenAIModel(name)
+    def __init__(self, model):
+        self.name = model
 
     def __str__(self):
         return self.name
 
-    @staticmethod
-    def strong_model():
-        return Model.create("gpt-4-0613")
-
-    @staticmethod
-    def weak_model():
-        return Model.create("gpt-3.5-turbo-0125")
+    def weak_model(self):
+        model = "gpt-3.5-turbo-0125"
+        if self.name == model:
+            return self
+
+        return Model(model)
 
-    @staticmethod
-    def commit_message_models():
-        return [Model.weak_model()]
+    def commit_message_models(self):
+        return [self.weak_model()]
+
+    def tokenizer(self, text):
+        return litellm.encode(model=self.name, text=text)
 
     def token_count(self, messages):
         if not self.tokenizer:
@@ -51,7 +46,7 @@ class Model:
         else:
             msgs = json.dumps(messages)
 
-        return len(self.tokenizer.encode(msgs))
+        return len(self.tokenizer(msgs))
 
     def token_count_for_image(self, fname):
         """
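`Model` is reduced to a plain holder: a constructor that records the name, per-instance `weak_model()`/`commit_message_models()`, and a `tokenizer` method that defers to litellm's model-aware encoder. That last change is why every caller switches from `tokenizer.encode(text)` to `tokenizer(text)`. One rough edge, consistent with the "roughed in" commit message: `tokenizer` is now a bound method, which is always truthy, so the surviving `if not self.tokenizer:` guard in `token_count` can no longer trigger. A minimal usage sketch, assuming litellm is installed and recognizes the model name:

```python
import litellm

from aider.models import Model

m = Model("gpt-3.5-turbo-0125")
assert str(m) == "gpt-3.5-turbo-0125"
assert m.weak_model() is m            # already the weak model, returned as-is

# tokenizer() delegates to litellm's encoder for the named model.
tokens = m.tokenizer("hello world")
assert tokens == litellm.encode(model=m.name, text="hello world")
print(len(tokens))
```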
aider/repomap.py
@@ -15,8 +15,6 @@ from pygments.util import ClassNotFound
 from tqdm import tqdm
 from tree_sitter_languages import get_language, get_parser
 
-from aider import models
-
 from .dump import dump  # noqa: F402
 
 Tag = namedtuple("Tag", "rel_fname fname line name kind".split())
@@ -34,7 +32,7 @@ class RepoMap:
         self,
         map_tokens=1024,
         root=None,
-        main_model=models.Model.strong_model(),
+        main_model=None,
         io=None,
         repo_content_prefix=None,
         verbose=False,
@@ -88,7 +86,7 @@ class RepoMap:
         return repo_content
 
     def token_count(self, string):
-        return len(self.tokenizer.encode(string))
+        return len(self.tokenizer(string))
 
     def get_rel_fname(self, fname):
         return os.path.relpath(fname, self.root)
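`RepoMap` picks up the same two adjustments as the other call sites: the eager `models.Model.strong_model()` default (a classmethod deleted in `model.py` above) becomes `None`, and `token_count` calls the tokenizer directly. The call convention after this commit, as a standalone sketch:

```python
# A tokenizer is now a callable returning a list of tokens,
# not an object exposing .encode().
def token_count(tokenizer, string):
    return len(tokenizer(string))
```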