choose whether/which token map to use based on tokenized size

This commit is contained in:
Paul Gauthier 2023-05-23 15:40:32 -07:00
parent 2238900b34
commit 7c1112ab20
2 changed files with 47 additions and 27 deletions

View file

@@ -10,6 +10,7 @@ from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from pathlib import Path
import tiktoken
import git
import openai
@@ -74,6 +75,7 @@ class Coder:
self.pretty = pretty
self.show_diffs = show_diffs
self.tokenizer = tiktoken.encoding_for_model(self.main_model)
def find_common_root(self):
if self.abs_fnames:
@@ -169,37 +171,25 @@ class Coder:
files_content += self.get_files_content()
all_content += files_content
if self.repo is not None:
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
if other_files:
if self.use_ctags:
files_listing = get_tags_map(other_files)
ctags_msg = " with selected ctags info"
else:
files_listing = "\n".join(
self.get_rel_fname(ofn) for ofn in sorted(other_files)
)
ctags_msg = ""
res = self.choose_files_listing()
if res:
files_listing, ctags_msg = res
if self.abs_fnames:
other = "other "
else:
other = ""
if self.abs_fnames:
other = "other "
else:
other = ""
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
from .dump import dump
if all_content:
all_content += "\n\n"
dump(len(repo_content))
if all_content:
all_content += "\n\n"
all_content += repo_content
all_content += repo_content
files_messages = [
dict(role="user", content=all_content),
@@ -212,6 +202,35 @@ class Coder:
return files_messages
def choose_files_listing(self, max_map_tokens=2048):
    """Pick the richest listing of the repo's other files that fits a token budget.

    "Other files" are the paths returned by ``self.get_all_abs_files()`` minus
    those already in ``self.abs_fnames``. Candidates are tried in order:

    1. the ctags-based map from ``get_tags_map`` (only when ``self.use_ctags``
       is truthy);
    2. the plain listing from ``self.get_simple_files_map``.

    The first candidate whose ``self.token_count`` is strictly below
    ``max_map_tokens`` is used.

    Args:
        max_map_tokens: Exclusive upper bound on the listing's token count.
            The default of 2048 was chosen as 1/4 of gpt-4's context window.

    Returns:
        A ``(files_listing, ctags_msg)`` tuple, where ``ctags_msg`` is
        ``" with selected ctags info"`` for the ctags map and ``""`` for the
        plain listing. Returns ``None`` when there is no repo, there are no
        other files, or no candidate fits within the budget.
    """
    if not self.repo:
        return None

    other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
    if not other_files:
        return None

    if self.use_ctags:
        files_listing = get_tags_map(other_files)
        if self.token_count(files_listing) < max_map_tokens:
            return files_listing, " with selected ctags info"
        # The ctags map is too large; fall through to the plain listing.

    files_listing = self.get_simple_files_map(other_files)
    if self.token_count(files_listing) < max_map_tokens:
        return files_listing, ""

    # Even the plain listing exceeds the budget: send no repo map at all.
    return None
def get_simple_files_map(self, other_files):
    """Return a newline-separated listing of ``other_files``.

    The paths are sorted first, then each one is converted with
    ``self.get_rel_fname`` and the results are joined with ``"\\n"``.
    An empty input yields an empty string.
    """
    rel_names = [self.get_rel_fname(path) for path in sorted(other_files)]
    return "\n".join(rel_names)
def token_count(self, string):
    """Return the number of tokens ``self.tokenizer.encode`` yields for ``string``."""
    tokens = self.tokenizer.encode(string)
    return len(tokens)
def run(self):
self.done_messages = []
self.cur_messages = []

View file

@@ -23,3 +23,4 @@ urllib3==2.0.2
wcwidth==0.2.6
yarl==1.9.2
pytest==7.3.1
tiktoken==0.4.0