mentioned_fnames

This commit is contained in:
Paul Gauthier 2024-05-14 11:52:32 -07:00
parent 73e6949287
commit e5616c0247

View file

@ -55,11 +55,13 @@ class RepoMap:
self.token_count = main_model.token_count self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix self.repo_content_prefix = repo_content_prefix
def get_repo_map(self, chat_files, other_files): def get_repo_map(self, chat_files, other_files, mentioned_fnames=None):
if self.max_map_tokens <= 0: if self.max_map_tokens <= 0:
return return
if not other_files: if not other_files:
return return
if not mentioned_fnames:
mentioned_fnames = set()
max_map_tokens = self.max_map_tokens max_map_tokens = self.max_map_tokens
if not chat_files: if not chat_files:
@ -67,7 +69,9 @@ class RepoMap:
max_map_tokens *= 8 max_map_tokens *= 8
try: try:
files_listing = self.get_ranked_tags_map(chat_files, other_files, max_map_tokens) files_listing = self.get_ranked_tags_map(
chat_files, other_files, max_map_tokens, mentioned_fnames
)
except RecursionError: except RecursionError:
self.io.tool_error("Disabling repo map, git repo too large?") self.io.tool_error("Disabling repo map, git repo too large?")
self.max_map_tokens = 0 self.max_map_tokens = 0
@ -213,7 +217,7 @@ class RepoMap:
line=-1, line=-1,
) )
def get_ranked_tags(self, chat_fnames, other_fnames): def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames):
defines = defaultdict(set) defines = defaultdict(set)
references = defaultdict(list) references = defaultdict(list)
definitions = defaultdict(set) definitions = defaultdict(set)
@ -249,6 +253,9 @@ class RepoMap:
personalization[rel_fname] = 1.0 personalization[rel_fname] = 1.0
chat_rel_fnames.add(rel_fname) chat_rel_fnames.add(rel_fname)
if fname in mentioned_fnames:
personalization[rel_fname] = 1.0
tags = list(self.get_tags(fname, rel_fname)) tags = list(self.get_tags(fname, rel_fname))
if tags is None: if tags is None:
continue continue
@ -265,6 +272,7 @@ class RepoMap:
## ##
# dump(defines) # dump(defines)
# dump(references) # dump(references)
# dump(personalization)
if not references: if not references:
references = dict((k, list(v)) for k, v in defines.items()) references = dict((k, list(v)) for k, v in defines.items())
@ -332,13 +340,17 @@ class RepoMap:
return ranked_tags return ranked_tags
def get_ranked_tags_map(self, chat_fnames, other_fnames=None, max_map_tokens=None): def get_ranked_tags_map(
self, chat_fnames, other_fnames=None, max_map_tokens=None, mentioned_fnames=None
):
if not other_fnames: if not other_fnames:
other_fnames = list() other_fnames = list()
if not max_map_tokens: if not max_map_tokens:
max_map_tokens = self.max_map_tokens max_map_tokens = self.max_map_tokens
if not mentioned_fnames:
mentioned_fnames = set()
ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames) ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames)
num_tags = len(ranked_tags) num_tags = len(ranked_tags)
lower_bound = 0 lower_bound = 0