diff --git a/aider/repomap.py b/aider/repomap.py
index fffa58ec3..c82b944dd 100644
--- a/aider/repomap.py
+++ b/aider/repomap.py
@@ -2,18 +2,22 @@ import colorsys
 import os
 import random
 import sys
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, namedtuple
 from pathlib import Path
 
 import networkx as nx
 import tiktoken
 from diskcache import Cache
 from tqdm import tqdm
+from tree_sitter_languages import get_language, get_parser
 
 from aider import models
+from aider.parsers import filename_to_lang
 
 from .dump import dump  # noqa: F402
 
+Tag = namedtuple("Tag", "fname line name kind".split())
+
 
 def to_tree(tags):
     if not tags:
@@ -38,7 +42,7 @@ def to_tree(tags):
         indent = tab * num_common
         rest = tag[num_common:]
         for item in rest:
-            output += indent + item + "\n"
+            output += indent + str(item) + "\n"
             indent += tab
         last = tag
 
@@ -147,6 +151,49 @@ class RepoMap:
         except FileNotFoundError:
             self.io.tool_error(f"File not found error: {fname}")
 
+    def get_tags(self, fname):
+        lang = filename_to_lang(fname)
+        if not lang:
+            return
+
+        language = get_language(lang)
+        parser = get_parser(lang)
+
+        # Load the tags queries
+        scm_fname = (
+            f"/Users/gauthier/tmp/py-tree-sitter-languages/queries/tree-sitter-{lang}-tags.scm"
+        )
+        query_scm = Path(scm_fname)
+        if not query_scm.exists():
+            return
+        query_scm = query_scm.read_text()
+
+        code = Path(fname).read_text()
+        tree = parser.parse(bytes(code, "utf8"))
+
+        # Run the tags queries
+        query = language.query(query_scm)
+        captures = query.captures(tree.root_node)
+
+        captures = list(captures)
+
+        for node, tag in captures:
+            if tag.startswith("name.definition."):
+                kind = "def"
+            elif tag.startswith("name.reference."):
+                kind = "ref"
+            else:
+                continue
+
+            result = Tag(
+                name=node.text.decode("utf-8"),  # TODO: encoding?
+                kind=kind,
+                fname=fname,
+                line=node.start_point[0],
+            )
+
+            yield result
+
     def get_ranked_tags(self, chat_fnames, other_fnames):
         defines = defaultdict(set)
         references = defaultdict(list)
@@ -175,36 +222,21 @@ class RepoMap:
                 personalization[rel_fname] = 1.0
                 chat_rel_fnames.add(rel_fname)
 
-            # TODO
-            data = []
+            tags = self.get_tags(fname)
+            if tags is None:
+                continue
 
-            for tag in data:
-                ident = tag["name"]
-                defines[ident].add(rel_fname)
+            for tag in tags:
+                if tag.kind == "def":
+                    defines[tag.name].add(rel_fname)
+                    key = (rel_fname, tag.name)
+                    definitions[key].add(tag)
 
-                scope = tag.get("scope")
-                kind = tag.get("kind")
-                name = tag.get("name")
-                signature = tag.get("signature")
+                if tag.kind == "ref":
+                    references[tag.name].append(rel_fname)
 
-                last = name
-                if signature:
-                    last += " " + signature
-
-                res = [rel_fname]
-                if scope:
-                    res.append(scope)
-                res += [kind, last]
-
-                key = (rel_fname, ident)
-                definitions[key].add(tuple(res))
-                # definitions[key].add((rel_fname,))
-
-            # TODO
-            idents = []
-            for ident in idents:
-                # dump("ref", fname, ident)
-                references[ident].append(rel_fname)
+        dump(definitions)
+        dump(references)
 
         idents = set(defines.keys()).intersection(set(references.keys()))
 