mentioned_idents

This commit is contained in:
Paul Gauthier 2024-05-14 12:18:48 -07:00
parent e5616c0247
commit dea93bbfd9

View file

@ -55,13 +55,15 @@ class RepoMap:
self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None):
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None):
if self.max_map_tokens <= 0:
return
if not other_files:
return
if not mentioned_fnames:
mentioned_fnames = set()
if not mentioned_idents:
mentioned_idents = set()
max_map_tokens = self.max_map_tokens
if not chat_files:
@ -70,7 +72,7 @@ class RepoMap:
try:
files_listing = self.get_ranked_tags_map(
chat_files, other_files, max_map_tokens, mentioned_fnames
chat_files, other_files, max_map_tokens, mentioned_fnames, mentioned_idents
)
except RecursionError:
self.io.tool_error("Disabling repo map, git repo too large?")
@ -217,7 +219,7 @@ class RepoMap:
line=-1,
)
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames):
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents):
defines = defaultdict(set)
references = defaultdict(list)
definitions = defaultdict(set)
@ -229,6 +231,10 @@ class RepoMap:
fnames = sorted(fnames)
# Default personalization for unspecified files is 1/num_nodes
# https://networkx.org/documentation/stable/_modules/networkx/algorithms/link_analysis/pagerank_alg.html#pagerank
personalize = 10 / len(fnames)
if self.cache_missing or True:
fnames = tqdm(fnames)
self.cache_missing = False
@ -250,11 +256,11 @@ class RepoMap:
rel_fname = self.get_rel_fname(fname)
if fname in chat_fnames:
personalization[rel_fname] = 1.0
personalization[rel_fname] = personalize
chat_rel_fnames.add(rel_fname)
if fname in mentioned_fnames:
personalization[rel_fname] = 1.0
personalization[rel_fname] = personalize
tags = list(self.get_tags(fname, rel_fname))
if tags is None:
@ -283,11 +289,15 @@ class RepoMap:
for ident in idents:
definers = defines[ident]
if ident in mentioned_idents:
mul = 10
else:
mul = 1
for referencer, num_refs in Counter(references[ident]).items():
for definer in definers:
# if referencer == definer:
# continue
G.add_edge(referencer, definer, weight=num_refs, ident=ident)
G.add_edge(referencer, definer, weight=mul * num_refs, ident=ident)
if not references:
pass
@ -341,7 +351,12 @@ class RepoMap:
return ranked_tags
def get_ranked_tags_map(
self, chat_fnames, other_fnames=None, max_map_tokens=None, mentioned_fnames=None
self,
chat_fnames,
other_fnames=None,
max_map_tokens=None,
mentioned_fnames=None,
mentioned_idents=None,
):
if not other_fnames:
other_fnames = list()
@ -349,8 +364,12 @@ class RepoMap:
max_map_tokens = self.max_map_tokens
if not mentioned_fnames:
mentioned_fnames = set()
if not mentioned_idents:
mentioned_idents = set()
ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames)
ranked_tags = self.get_ranked_tags(
chat_fnames, other_fnames, mentioned_fnames, mentioned_idents
)
num_tags = len(ranked_tags)
lower_bound = 0