diff --git a/aider/repomap.py b/aider/repomap.py index ec76755e0..63270d169 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -55,11 +55,13 @@ class RepoMap: self.token_count = main_model.token_count self.repo_content_prefix = repo_content_prefix - def get_repo_map(self, chat_files, other_files): + def get_repo_map(self, chat_files, other_files, mentioned_fnames=None): if self.max_map_tokens <= 0: return if not other_files: return + if not mentioned_fnames: + mentioned_fnames = set() max_map_tokens = self.max_map_tokens if not chat_files: @@ -67,7 +69,9 @@ class RepoMap: max_map_tokens *= 8 try: - files_listing = self.get_ranked_tags_map(chat_files, other_files, max_map_tokens) + files_listing = self.get_ranked_tags_map( + chat_files, other_files, max_map_tokens, mentioned_fnames + ) except RecursionError: self.io.tool_error("Disabling repo map, git repo too large?") self.max_map_tokens = 0 @@ -213,7 +217,7 @@ class RepoMap: line=-1, ) - def get_ranked_tags(self, chat_fnames, other_fnames): + def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames): defines = defaultdict(set) references = defaultdict(list) definitions = defaultdict(set) @@ -249,6 +253,9 @@ class RepoMap: personalization[rel_fname] = 1.0 chat_rel_fnames.add(rel_fname) + if fname in mentioned_fnames: + personalization[rel_fname] = 1.0 + tags = list(self.get_tags(fname, rel_fname)) if tags is None: continue @@ -265,6 +272,7 @@ class RepoMap: ## # dump(defines) # dump(references) + # dump(personalization) if not references: references = dict((k, list(v)) for k, v in defines.items()) @@ -332,13 +340,17 @@ class RepoMap: return ranked_tags - def get_ranked_tags_map(self, chat_fnames, other_fnames=None, max_map_tokens=None): + def get_ranked_tags_map( + self, chat_fnames, other_fnames=None, max_map_tokens=None, mentioned_fnames=None + ): if not other_fnames: other_fnames = list() if not max_map_tokens: max_map_tokens = self.max_map_tokens + if not mentioned_fnames: + mentioned_fnames = set() - ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames) + ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames) num_tags = len(ranked_tags) lower_bound = 0