mentioned_fnames

This commit is contained in:
Paul Gauthier 2024-05-14 11:52:32 -07:00
parent 73e6949287
commit e5616c0247

View file

@ -55,11 +55,13 @@ class RepoMap:
self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix
def get_repo_map(self, chat_files, other_files):
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None):
if self.max_map_tokens <= 0:
return
if not other_files:
return
if not mentioned_fnames:
mentioned_fnames = set()
max_map_tokens = self.max_map_tokens
if not chat_files:
@ -67,7 +69,9 @@ class RepoMap:
max_map_tokens *= 8
try:
files_listing = self.get_ranked_tags_map(chat_files, other_files, max_map_tokens)
files_listing = self.get_ranked_tags_map(
chat_files, other_files, max_map_tokens, mentioned_fnames
)
except RecursionError:
self.io.tool_error("Disabling repo map, git repo too large?")
self.max_map_tokens = 0
@ -213,7 +217,7 @@ class RepoMap:
line=-1,
)
def get_ranked_tags(self, chat_fnames, other_fnames):
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames):
defines = defaultdict(set)
references = defaultdict(list)
definitions = defaultdict(set)
@ -249,6 +253,9 @@ class RepoMap:
personalization[rel_fname] = 1.0
chat_rel_fnames.add(rel_fname)
if fname in mentioned_fnames:
personalization[rel_fname] = 1.0
tags = list(self.get_tags(fname, rel_fname))
if tags is None:
continue
@ -265,6 +272,7 @@ class RepoMap:
##
# dump(defines)
# dump(references)
# dump(personalization)
if not references:
references = dict((k, list(v)) for k, v in defines.items())
@ -332,13 +340,17 @@ class RepoMap:
return ranked_tags
def get_ranked_tags_map(self, chat_fnames, other_fnames=None, max_map_tokens=None):
def get_ranked_tags_map(
self, chat_fnames, other_fnames=None, max_map_tokens=None, mentioned_fnames=None
):
if not other_fnames:
other_fnames = list()
if not max_map_tokens:
max_map_tokens = self.max_map_tokens
if not mentioned_fnames:
mentioned_fnames = set()
ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames)
ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames)
num_tags = len(ranked_tags)
lower_bound = 0