roughed in tokenizer, dropped openai, openrouter

This commit is contained in:
Paul Gauthier 2024-04-17 15:22:35 -07:00
parent 855e787175
commit c9bb22d6d5
6 changed files with 27 additions and 46 deletions

View file

@@ -68,7 +68,7 @@ class Coder:
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
if not main_model:
main_model = models.Model.create(models.DEFAULT_MODEL_NAME)
main_model = models.Model(models.DEFAULT_MODEL_NAME)
if edit_format is None:
edit_format = main_model.edit_format
@@ -214,7 +214,7 @@ class Coder:
self.summarizer = ChatSummary(
self.client,
models.Model.weak_model(),
self.main_model.weak_model(),
self.main_model.max_chat_history_tokens,
)

View file

@@ -7,7 +7,7 @@ from aider.sendchat import simple_send_with_retries
class ChatSummary:
def __init__(self, client, model=models.Model.weak_model(), max_tokens=1024):
def __init__(self, client, model=None, max_tokens=1024):
self.client = client
self.tokenizer = model.tokenizer
self.max_tokens = max_tokens
@@ -21,7 +21,7 @@ class ChatSummary:
def tokenize(self, messages):
sized = []
for msg in messages:
tokens = len(self.tokenizer.encode(json.dumps(msg)))
tokens = len(self.tokenizer(json.dumps(msg)))
sized.append((tokens, msg))
return sized
@@ -61,7 +61,7 @@ class ChatSummary:
summary = self.summarize_all(head)
tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
summary_tokens = len(self.tokenizer.encode(json.dumps(summary)))
summary_tokens = len(self.tokenizer(json.dumps(summary)))
result = summary + tail
if summary_tokens + tail_tokens < self.max_tokens:

View file

@@ -188,7 +188,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
const=default_4_turbo_model,
help=f"Use {default_4_turbo_model} model for the main chat",
)
default_3_model = models.GPT35_0125
default_3_model_name = "gpt-3.5-turbo-0125"
core_group.add_argument(
"--35turbo",
"--35-turbo",
@@ -196,8 +196,8 @@ def main(argv=None, input=None, output=None, force_git_root=None):
"-3",
action="store_const",
dest="model",
const=default_3_model.name,
help=f"Use {default_3_model.name} model for the main chat",
const=default_3_model_name,
help=f"Use {default_3_model_name} model for the main chat",
)
core_group.add_argument(
"--voice-language",
@@ -580,7 +580,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
io.tool_error(f"Unknown model {args.model}.")
return 1
main_model = models.Model.create(args.model, None)
main_model = models.Model(args.model)
try:
coder = Coder.create(

View file

@@ -1,17 +1,5 @@
from .model import Model
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
GPT4 = Model.create("gpt-4")
GPT35 = Model.create("gpt-3.5-turbo")
GPT35_0125 = Model.create("gpt-3.5-turbo-0125")
DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
__all__ = [
OpenAIModel,
OpenRouterModel,
GPT4,
GPT35,
GPT35_0125,
]
__all__ = [Model, DEFAULT_MODEL_NAME]

View file

@@ -1,14 +1,14 @@
import json
import math
import litellm
from PIL import Image
class Model:
name = None
edit_format = None
edit_format = "whole"
max_context_tokens = 0
tokenizer = None
max_chat_history_tokens = 1024
always_available = False
@@ -18,29 +18,24 @@ class Model:
prompt_price = None
completion_price = None
@classmethod
def create(cls, name, client=None):
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
if client and client.base_url.host == "openrouter.ai":
return OpenRouterModel(client, name)
return OpenAIModel(name)
def __init__(self, model):
self.name = model
def __str__(self):
return self.name
@staticmethod
def strong_model():
return Model.create("gpt-4-0613")
def weak_model(self):
model = "gpt-3.5-turbo-0125"
if self.name == model:
return self
@staticmethod
def weak_model():
return Model.create("gpt-3.5-turbo-0125")
return Model(model)
@staticmethod
def commit_message_models():
return [Model.weak_model()]
def commit_message_models(self):
return [self.weak_model()]
def tokenizer(self, text):
return litellm.encode(model=self.name, text=text)
def token_count(self, messages):
if not self.tokenizer:
@@ -51,7 +51,7 @@ class Model:
else:
msgs = json.dumps(messages)
return len(self.tokenizer.encode(msgs))
return len(self.tokenizer(msgs))
def token_count_for_image(self, fname):
"""

View file

@@ -15,8 +15,6 @@ from pygments.util import ClassNotFound
from tqdm import tqdm
from tree_sitter_languages import get_language, get_parser
from aider import models
from .dump import dump # noqa: F402
Tag = namedtuple("Tag", "rel_fname fname line name kind".split())
@@ -34,7 +32,7 @@ class RepoMap:
self,
map_tokens=1024,
root=None,
main_model=models.Model.strong_model(),
main_model=None,
io=None,
repo_content_prefix=None,
verbose=False,
@@ -88,7 +86,7 @@ class RepoMap:
return repo_content
def token_count(self, string):
return len(self.tokenizer.encode(string))
return len(self.tokenizer(string))
def get_rel_fname(self, fname):
return os.path.relpath(fname, self.root)