diff --git a/aider/repomap.py b/aider/repomap.py index 63270d169..e18bf836c 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -55,13 +55,15 @@ class RepoMap: self.token_count = main_model.token_count self.repo_content_prefix = repo_content_prefix - def get_repo_map(self, chat_files, other_files, mentioned_fnames=None): + def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): if self.max_map_tokens <= 0: return if not other_files: return if not mentioned_fnames: mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() max_map_tokens = self.max_map_tokens if not chat_files: @@ -70,7 +72,7 @@ class RepoMap: try: files_listing = self.get_ranked_tags_map( - chat_files, other_files, max_map_tokens, mentioned_fnames + chat_files, other_files, max_map_tokens, mentioned_fnames, mentioned_idents ) except RecursionError: self.io.tool_error("Disabling repo map, git repo too large?") @@ -217,7 +219,7 @@ class RepoMap: line=-1, ) - def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames): + def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): defines = defaultdict(set) references = defaultdict(list) definitions = defaultdict(set) @@ -229,6 +231,10 @@ class RepoMap: fnames = sorted(fnames) + # Default personalization for unspecified files is 1/num_nodes + # https://networkx.org/documentation/stable/_modules/networkx/algorithms/link_analysis/pagerank_alg.html#pagerank + personalize = 10 / len(fnames) + if self.cache_missing or True: fnames = tqdm(fnames) self.cache_missing = False @@ -250,11 +256,11 @@ class RepoMap: rel_fname = self.get_rel_fname(fname) if fname in chat_fnames: - personalization[rel_fname] = 1.0 + personalization[rel_fname] = personalize chat_rel_fnames.add(rel_fname) if fname in mentioned_fnames: - personalization[rel_fname] = 1.0 + personalization[rel_fname] = personalize tags = list(self.get_tags(fname, rel_fname)) if tags is None: @@ -283,11 +289,15 @@ class RepoMap: for ident in idents: definers = defines[ident] + if ident in mentioned_idents: + mul = 10 + else: + mul = 1 for referencer, num_refs in Counter(references[ident]).items(): for definer in definers: # if referencer == definer: # continue - G.add_edge(referencer, definer, weight=num_refs, ident=ident) + G.add_edge(referencer, definer, weight=mul * num_refs, ident=ident) if not references: pass @@ -341,7 +351,12 @@ class RepoMap: return ranked_tags def get_ranked_tags_map( - self, chat_fnames, other_fnames=None, max_map_tokens=None, mentioned_fnames=None + self, + chat_fnames, + other_fnames=None, + max_map_tokens=None, + mentioned_fnames=None, + mentioned_idents=None, ): if not other_fnames: other_fnames = list() @@ -349,8 +364,12 @@ class RepoMap: max_map_tokens = self.max_map_tokens if not mentioned_fnames: mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() - ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames) + ranked_tags = self.get_ranked_tags( + chat_fnames, other_fnames, mentioned_fnames, mentioned_idents + ) num_tags = len(ranked_tags) lower_bound = 0