import colorsys
import os
import random
import sys
import warnings
from collections import Counter, defaultdict, namedtuple
from importlib import resources
from pathlib import Path

import networkx as nx
from diskcache import Cache
from grep_ast import TreeContext, filename_to_lang
from pygments.lexers import guess_lexer_for_filename
from pygments.token import Token
from pygments.util import ClassNotFound
from tqdm import tqdm

# tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning)
from tree_sitter_languages import get_language, get_parser  # noqa: E402

from aider.dump import dump  # noqa: F402,E402

Tag = namedtuple("Tag", "rel_fname fname line name kind".split())


class RepoMap:
    CACHE_VERSION = 3
    TAGS_CACHE_DIR = f".aider.tags.cache.v{CACHE_VERSION}"

    cache_missing = False

    # Class-level set, shared across instances, so each problem file is only
    # warned about once per process.
    warned_files = set()

    def __init__(
        self,
        map_tokens=1024,
        root=None,
        main_model=None,
        io=None,
        repo_content_prefix=None,
        verbose=False,
    ):
        self.io = io
        self.verbose = verbose

        if not root:
            root = os.getcwd()
        self.root = root

        self.load_tags_cache()

        self.max_map_tokens = map_tokens

        # main_model must provide a token_count(text) method
        self.token_count = main_model.token_count
        self.repo_content_prefix = repo_content_prefix

    def get_repo_map(self, chat_files, other_files):
        if self.max_map_tokens <= 0:
            return
        if not other_files:
            return

        max_map_tokens = self.max_map_tokens

        if not chat_files:
            # with no code in the chat, give a bigger view of the entire repo
            max_map_tokens *= 8

        try:
            files_listing = self.get_ranked_tags_map(chat_files, other_files, max_map_tokens)
        except RecursionError:
            self.io.tool_error("Disabling repo map, git repo too large?")
            self.max_map_tokens = 0
            return

        if not files_listing:
            return

        num_tokens = self.token_count(files_listing)
        if self.verbose:
            self.io.tool_output(f"Repo-map: {num_tokens/1024:.1f} k-tokens")  # noqa: E226, E231

        if chat_files:
            other = "other "
        else:
            other = ""

        if self.repo_content_prefix:
            repo_content = self.repo_content_prefix.format(other=other)
        else:
            repo_content = ""

        repo_content += files_listing

        return repo_content

    def get_rel_fname(self, fname):
        return os.path.relpath(fname, self.root)

    def split_path(self, path):
        path = os.path.relpath(path, self.root)
        return [path + ":"]

    def load_tags_cache(self):
        path = Path(self.root) / self.TAGS_CACHE_DIR
        if not path.exists():
            self.cache_missing = True
        self.TAGS_CACHE = Cache(path)

    def save_tags_cache(self):
        pass

    def get_mtime(self, fname):
        try:
            return os.path.getmtime(fname)
        except FileNotFoundError:
            self.io.tool_error(f"File not found error: {fname}")

    def get_tags(self, fname, rel_fname):
        # Check if the file is in the cache and if the modification time has not changed
        file_mtime = self.get_mtime(fname)
        if file_mtime is None:
            return []

        cache_key = fname
        if cache_key in self.TAGS_CACHE and self.TAGS_CACHE[cache_key]["mtime"] == file_mtime:
            return self.TAGS_CACHE[cache_key]["data"]

        # miss!
        data = list(self.get_tags_raw(fname, rel_fname))

        # Update the cache
        self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
        self.save_tags_cache()
        return data
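    # Illustrative cache-entry shape (an assumption for clarity, not extra
    # behavior): each TAGS_CACHE value pairs the file's mtime with its tags,
    # e.g.
    #
    #     self.TAGS_CACHE["/repo/app.py"] == {
    #         "mtime": 1700000000.0,
    #         "data": [
    #             Tag(rel_fname="app.py", fname="/repo/app.py",
    #                 line=10, name="main", kind="def"),
    #         ],
    #     }
    #
    # The path, mtime, and tag values above are made up; only the structure
    # mirrors what get_tags() stores.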
    def get_tags_raw(self, fname, rel_fname):
        lang = filename_to_lang(fname)
        if not lang:
            return

        language = get_language(lang)
        parser = get_parser(lang)

        # Load the tags queries
        try:
            scm_fname = resources.files(__package__).joinpath(
                "queries", f"tree-sitter-{lang}-tags.scm"
            )
        except KeyError:
            return
        query_scm = scm_fname
        if not query_scm.exists():
            return
        query_scm = query_scm.read_text()

        code = self.io.read_text(fname)
        if not code:
            return
        tree = parser.parse(bytes(code, "utf-8"))

        # Run the tags queries
        query = language.query(query_scm)
        captures = query.captures(tree.root_node)

        captures = list(captures)

        saw = set()
        for node, tag in captures:
            if tag.startswith("name.definition."):
                kind = "def"
            elif tag.startswith("name.reference."):
                kind = "ref"
            else:
                continue

            saw.add(kind)

            result = Tag(
                rel_fname=rel_fname,
                fname=fname,
                name=node.text.decode("utf-8"),
                kind=kind,
                line=node.start_point[0],
            )

            yield result

        if "ref" in saw:
            return
        if "def" not in saw:
            return

        # We saw defs, without any refs
        # Some tags files only provide defs (cpp, for example)
        # Use pygments to backfill refs
        try:
            lexer = guess_lexer_for_filename(fname, code)
        except ClassNotFound:
            return

        tokens = list(lexer.get_tokens(code))
        tokens = [token[1] for token in tokens if token[0] in Token.Name]

        for token in tokens:
            yield Tag(
                rel_fname=rel_fname,
                fname=fname,
                name=token,
                kind="ref",
                line=-1,
            )
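    # For reference, the bundled queries are standard tree-sitter tags files.
    # A minimal sketch of what a queries/tree-sitter-python-tags.scm pattern
    # can look like (illustrative, not the shipped file):
    #
    #     (function_definition
    #       name: (identifier) @name.definition.function)
    #
    #     (call
    #       function: (identifier) @name.reference.call)
    #
    # get_tags_raw() keys off only the "name.definition." and
    # "name.reference." capture-name prefixes.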
    def get_ranked_tags(self, chat_fnames, other_fnames):
        defines = defaultdict(set)
        references = defaultdict(list)
        definitions = defaultdict(set)

        personalization = dict()

        fnames = set(chat_fnames).union(set(other_fnames))
        chat_rel_fnames = set()

        fnames = sorted(fnames)

        # Show a progress bar the first time through, while the tags cache
        # is being populated.
        if self.cache_missing:
            fnames = tqdm(fnames)
        self.cache_missing = False

        for fname in fnames:
            if not Path(fname).is_file():
                if fname not in self.warned_files:
                    if Path(fname).exists():
                        self.io.tool_error(
                            f"Repo-map can't include {fname}, it is not a normal file"
                        )
                    else:
                        self.io.tool_error(f"Repo-map can't include {fname}, it no longer exists")

                self.warned_files.add(fname)
                continue

            # dump(fname)
            rel_fname = self.get_rel_fname(fname)

            if fname in chat_fnames:
                personalization[rel_fname] = 1.0
                chat_rel_fnames.add(rel_fname)

            tags = list(self.get_tags(fname, rel_fname))

            for tag in tags:
                if tag.kind == "def":
                    defines[tag.name].add(rel_fname)
                    key = (rel_fname, tag.name)
                    definitions[key].add(tag)

                if tag.kind == "ref":
                    references[tag.name].append(rel_fname)

        # dump(defines)
        # dump(references)

        # If the repo has no references at all, treat each definition as a
        # self-reference so every file still participates in the graph.
        if not references:
            references = dict((k, list(v)) for k, v in defines.items())

        idents = set(defines.keys()).intersection(set(references.keys()))

        G = nx.MultiDiGraph()

        for ident in idents:
            definers = defines[ident]
            for referencer, num_refs in Counter(references[ident]).items():
                for definer in definers:
                    # if referencer == definer:
                    #     continue
                    G.add_edge(referencer, definer, weight=num_refs, ident=ident)

        if personalization:
            pers_args = dict(personalization=personalization, dangling=personalization)
        else:
            pers_args = dict()

        try:
            ranked = nx.pagerank(G, weight="weight", **pers_args)
        except ZeroDivisionError:
            return []

        # distribute the rank from each source node, across all of its out edges
        ranked_definitions = defaultdict(float)
        for src in G.nodes:
            src_rank = ranked[src]
            total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True))
            # dump(src, src_rank, total_weight)
            for _src, dst, data in G.out_edges(src, data=True):
                data["rank"] = src_rank * data["weight"] / total_weight
                ident = data["ident"]
                ranked_definitions[(dst, ident)] += data["rank"]

        ranked_tags = []
        ranked_definitions = sorted(ranked_definitions.items(), reverse=True, key=lambda x: x[1])

        # dump(ranked_definitions)

        for (fname, ident), rank in ranked_definitions:
            # print(f"{rank:.03f} {fname} {ident}")
            if fname in chat_rel_fnames:
                continue
            ranked_tags += list(definitions.get((fname, ident), []))

        rel_other_fnames_without_tags = set(self.get_rel_fname(fname) for fname in other_fnames)

        fnames_already_included = set(rt[0] for rt in ranked_tags)

        top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True)
        for rank, fname in top_rank:
            if fname in rel_other_fnames_without_tags:
                rel_other_fnames_without_tags.remove(fname)
            if fname not in fnames_already_included:
                ranked_tags.append((fname,))

        for fname in rel_other_fnames_without_tags:
            ranked_tags.append((fname,))

        return ranked_tags

    def get_ranked_tags_map(self, chat_fnames, other_fnames=None, max_map_tokens=None):
        if not other_fnames:
            other_fnames = list()
        if not max_map_tokens:
            max_map_tokens = self.max_map_tokens

        ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames)
        # dump(ranked_tags)

        num_tags = len(ranked_tags)
        lower_bound = 0
        upper_bound = num_tags
        best_tree = None
        best_tree_tokens = 0

        chat_rel_fnames = [self.get_rel_fname(fname) for fname in chat_fnames]

        # Binary search for the longest prefix of ranked_tags whose rendered
        # tree still fits within max_map_tokens.
        while lower_bound <= upper_bound:
            middle = (lower_bound + upper_bound) // 2
            tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames)
            num_tokens = self.token_count(tree)
            # dump(lower_bound, middle, upper_bound, num_tokens)

            if num_tokens < max_map_tokens and num_tokens > best_tree_tokens:
                best_tree = tree
                best_tree_tokens = num_tokens

            if num_tokens < max_map_tokens:
                lower_bound = middle + 1
            else:
                upper_bound = middle - 1

        return best_tree

    def to_tree(self, tags, chat_rel_fnames):
        if not tags:
            return ""

        tags = [tag for tag in tags if tag[0] not in chat_rel_fnames]
        tags = sorted(tags)

        cur_fname = None
        context = None
        output = ""

        # add a bogus tag at the end so we trip the this_fname != cur_fname...
        dummy_tag = (None,)
        for tag in tags + [dummy_tag]:
            this_rel_fname = tag[0]

            # ... here ... to output the final real entry in the list
            if this_rel_fname != cur_fname:
                if context:
                    context.add_context()
                    output += "\n"
                    output += cur_fname + ":\n"
                    output += context.format()
                    context = None
                elif cur_fname:
                    output += "\n" + cur_fname + "\n"

                if type(tag) is Tag:
                    code = self.io.read_text(tag.fname) or ""
                    if not code.endswith("\n"):
                        code += "\n"

                    context = TreeContext(
                        tag.rel_fname,
                        code,
                        color=False,
                        line_number=False,
                        child_context=False,
                        last_line=False,
                        margin=0,
                        mark_lois=False,
                        loi_pad=0,
                        # header_max=30,
                        show_top_of_file_parent_scope=False,
                    )
                cur_fname = this_rel_fname

            if context:
                context.add_lines_of_interest([tag.line])

        # truncate long lines, in case we get minified js or something else crazy
        output = "".join([line[:100] for line in output.splitlines(keepends=True)])

        return output
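# A minimal, self-contained sketch (not used by RepoMap) of the ranking
# approach in get_ranked_tags(): referencer -> definer edges weighted by
# reference counts, with personalization boosting files already in the chat.
# The file names, weights, and idents below are invented for illustration.
def _pagerank_sketch():
    G = nx.MultiDiGraph()
    # main.py references utils.py three times and db.py once; db.py also
    # references utils.py.
    G.add_edge("main.py", "utils.py", weight=3, ident="helper")
    G.add_edge("main.py", "db.py", weight=1, ident="connect")
    G.add_edge("db.py", "utils.py", weight=1, ident="helper")

    # Pretend main.py is in the chat, so rank flows outward from it.
    personalization = {"main.py": 1.0}
    ranked = nx.pagerank(
        G, weight="weight", personalization=personalization, dangling=personalization
    )
    # utils.py ends up ranked above db.py, since more weighted references
    # point at it.
    return ranked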
def find_src_files(directory):
    if not os.path.isdir(directory):
        return [directory]

    src_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            src_files.append(os.path.join(root, file))
    return src_files


def get_random_color():
    hue = random.random()
    r, g, b = [int(x * 255) for x in colorsys.hsv_to_rgb(hue, 1, 0.75)]
    res = f"#{r:02x}{g:02x}{b:02x}"  # noqa: E231
    return res


if __name__ == "__main__":
    chat_fnames = []
    other_fnames = []
    for fname in sys.argv[1:]:
        if Path(fname).is_dir():
            chat_fnames += find_src_files(fname)
        else:
            chat_fnames.append(fname)

    # NOTE: token counting requires a main_model; see RepoMap.__init__
    rm = RepoMap(root=".")
    repo_map = rm.get_ranked_tags_map(chat_fnames, other_fnames)

    dump(len(repo_map))
    print(repo_map)
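# Example invocation (illustrative; the paths are made up, and RepoMap needs
# a main_model that provides token_count for this to run end to end):
#
#     python -m aider.repomap src/main.py src/utils.py
#
# which prints a ranked, token-budgeted map of the given files.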