mirror of https://github.com/Aider-AI/aider.git
synced 2025-05-30 01:04:59 +00:00
refactor tokenizer
This commit is contained in:
parent 6f1cebc4c2
commit 547ae142ba
3 changed files with 4 additions and 9 deletions
aider/commands.py
@@ -27,7 +27,6 @@ class Commands:
             voice_language = None
 
         self.voice_language = voice_language
-        self.tokenizer = coder.main_model.tokenizer
 
     def cmd_web(self, args):
         "Use headless selenium to scrape a webpage and add the content to the chat"
aider/history.py
@@ -1,5 +1,4 @@
 import argparse
-import json
 
 from aider import models, prompts
 from aider.dump import dump  # noqa: F401
@@ -8,7 +7,7 @@ from aider.sendchat import simple_send_with_retries
 
 class ChatSummary:
     def __init__(self, model=None, max_tokens=1024):
-        self.tokenizer = model.tokenizer
+        self.token_count = model.token_count
         self.max_tokens = max_tokens
         self.model = model
 
@@ -20,7 +19,7 @@ class ChatSummary:
     def tokenize(self, messages):
         sized = []
         for msg in messages:
-            tokens = len(self.tokenizer(json.dumps(msg)))
+            tokens = self.token_count(msg)
             sized.append((tokens, msg))
         return sized
 
@@ -60,7 +59,7 @@ class ChatSummary:
             summary = self.summarize_all(head)
 
         tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
-        summary_tokens = len(self.tokenizer(json.dumps(summary)))
+        summary_tokens = self.token_count(summary)
 
         result = summary + tail
         if summary_tokens + tail_tokens < self.max_tokens:
aider/repomap.py
@@ -52,7 +52,7 @@ class RepoMap:
 
         self.max_map_tokens = map_tokens
 
-        self.tokenizer = main_model.tokenizer
+        self.token_count = main_model.token_count
         self.repo_content_prefix = repo_content_prefix
 
     def get_repo_map(self, chat_files, other_files):
@@ -89,9 +89,6 @@ class RepoMap:
 
         return repo_content
 
-    def token_count(self, string):
-        return len(self.tokenizer(string))
-
     def get_rel_fname(self, fname):
         return os.path.relpath(fname, self.root)
 
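The model-side helper that these call sites now delegate to is not part of this diff. Below is a minimal sketch of what a token_count method on the model might look like, assuming the model still owns a tiktoken encoder and that dict or list arguments are JSON-serialized before encoding; the Model class shape, the default model name, and the tiktoken calls are assumptions for illustration, not code taken from this commit.

import json

import tiktoken


class Model:
    def __init__(self, name="gpt-3.5-turbo"):
        self.name = name
        # Assumption: the model keeps a tiktoken encoder internally,
        # so callers no longer need to touch the tokenizer directly.
        self.tokenizer = tiktoken.encoding_for_model(name)

    def token_count(self, messages):
        # Accept a plain string (RepoMap passes source text) or a
        # message dict / list of dicts (ChatSummary passes messages),
        # mirroring the call sites changed above.
        if not isinstance(messages, str):
            messages = json.dumps(messages)
        return len(self.tokenizer.encode(messages))

A helper along these lines is why RepoMap.token_count and the per-caller len(self.tokenizer(json.dumps(...))) expressions can be deleted: token counting lives in one place on the model, and callers only pass in whatever they want measured.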