refactor tokenizer

commit 547ae142ba
parent 6f1cebc4c2
Author: Paul Gauthier
Date:   2024-04-19 12:08:35 -07:00

3 changed files with 4 additions and 9 deletions

aider/commands.py

@@ -27,7 +27,6 @@ class Commands:
             voice_language = None

         self.voice_language = voice_language
-        self.tokenizer = coder.main_model.tokenizer

     def cmd_web(self, args):
         "Use headless selenium to scrape a webpage and add the content to the chat"

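The deleted line above is the point of the refactor: callers stop caching a raw tokenizer and instead ask the model itself to count tokens. A minimal sketch of what a token_count() method on the model might look like, assuming it wraps the model's tokenizer and JSON-serializes non-string inputs; this is illustrative only, not aider's actual model code:

    import json

    class Model:
        def __init__(self, tokenizer):
            # tokenizer is any callable returning a token sequence,
            # e.g. a tiktoken encoding's encode function
            self.tokenizer = tokenizer

        def token_count(self, text):
            # Accept plain strings or message dicts/lists; serialize
            # non-strings to JSON before tokenizing, as the old call
            # sites did inline with json.dumps().
            if not isinstance(text, str):
                text = json.dumps(text)
            return len(self.tokenizer(text))

With something like this in place, Commands can call coder.main_model.token_count(...) wherever it needs a count, instead of holding its own tokenizer reference.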
aider/history.py

@@ -1,5 +1,4 @@
 import argparse
-import json

 from aider import models, prompts
 from aider.dump import dump  # noqa: F401
@@ -8,7 +7,7 @@ from aider.sendchat import simple_send_with_retries
 class ChatSummary:
     def __init__(self, model=None, max_tokens=1024):
-        self.tokenizer = model.tokenizer
+        self.token_count = model.token_count
         self.max_tokens = max_tokens
         self.model = model
@@ -20,7 +19,7 @@ class ChatSummary:
     def tokenize(self, messages):
         sized = []
         for msg in messages:
-            tokens = len(self.tokenizer(json.dumps(msg)))
+            tokens = self.token_count(msg)
             sized.append((tokens, msg))
         return sized
@@ -60,7 +59,7 @@ class ChatSummary:
         summary = self.summarize_all(head)

         tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
-        summary_tokens = len(self.tokenizer(json.dumps(summary)))
+        summary_tokens = self.token_count(summary)

         result = summary + tail
         if summary_tokens + tail_tokens < self.max_tokens:

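One practical consequence of storing model.token_count as a callable: ChatSummary no longer cares what a tokenizer is, only that something can count tokens. That makes it easy to drive with a stand-in model, for example in a test. A hedged sketch, assuming the aider.history import path and using a hypothetical FakeModel with a crude word-count rule:

    from aider.history import ChatSummary  # assumed import path

    class FakeModel:
        def token_count(self, msg):
            # Crude stand-in: count whitespace-separated words,
            # not real tokens.
            return len(str(msg).split())

    summarizer = ChatSummary(model=FakeModel(), max_tokens=1024)
    sized = summarizer.tokenize([dict(role="user", content="hello world")])
    # sized is a list of (token_count, message) tuples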
aider/repomap.py

@@ -52,7 +52,7 @@ class RepoMap:
         self.max_map_tokens = map_tokens
-        self.tokenizer = main_model.tokenizer
+        self.token_count = main_model.token_count
         self.repo_content_prefix = repo_content_prefix

     def get_repo_map(self, chat_files, other_files):
@@ -89,9 +89,6 @@ class RepoMap:
         return repo_content

-    def token_count(self, string):
-        return len(self.tokenizer(string))
-
     def get_rel_fname(self, fname):
         return os.path.relpath(fname, self.root)
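Note why the deleted token_count() helper needs no matching call-site edits: assigning main_model.token_count in __init__ stores a bound method on the instance, so existing self.token_count(...) calls inside RepoMap resolve to the model's method with exactly the same syntax as the old local helper. A toy illustration with stand-in classes, not aider's real ones:

    class MainModel:
        def token_count(self, text):
            # Toy tokenizer: counts whitespace-separated words.
            return len(text.split())

    class RepoMap:
        def __init__(self, main_model):
            # Stores a bound method; callable as self.token_count(text),
            # identical in shape to the deleted instance method.
            self.token_count = main_model.token_count

    repo_map = RepoMap(MainModel())
    print(repo_map.token_count("def foo(): pass"))  # -> 3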