roughed in tree-sitter

This commit is contained in:
Paul Gauthier 2023-08-20 14:11:06 -07:00
parent de0cfe4d39
commit 189446b04e

View file

@ -2,18 +2,22 @@ import colorsys
import os import os
import random import random
import sys import sys
from collections import Counter, defaultdict from collections import Counter, defaultdict, namedtuple
from pathlib import Path from pathlib import Path
import networkx as nx import networkx as nx
import tiktoken import tiktoken
from diskcache import Cache from diskcache import Cache
from tqdm import tqdm from tqdm import tqdm
from tree_sitter_languages import get_language, get_parser
from aider import models from aider import models
from aider.parsers import filename_to_lang
from .dump import dump # noqa: F402 from .dump import dump # noqa: F402
Tag = namedtuple("Tag", "fname line name kind".split())
def to_tree(tags): def to_tree(tags):
if not tags: if not tags:
@ -38,7 +42,7 @@ def to_tree(tags):
indent = tab * num_common indent = tab * num_common
rest = tag[num_common:] rest = tag[num_common:]
for item in rest: for item in rest:
output += indent + item + "\n" output += indent + str(item) + "\n"
indent += tab indent += tab
last = tag last = tag
@ -147,6 +151,49 @@ class RepoMap:
except FileNotFoundError: except FileNotFoundError:
self.io.tool_error(f"File not found error: {fname}") self.io.tool_error(f"File not found error: {fname}")
def get_tags(self, fname):
lang = filename_to_lang(fname)
if not lang:
return
language = get_language(lang)
parser = get_parser(lang)
# Load the tags queries
scm_fname = (
f"/Users/gauthier/tmp/py-tree-sitter-languages/queries/tree-sitter-{lang}-tags.scm"
)
query_scm = Path(scm_fname)
if not query_scm.exists():
return
query_scm = query_scm.read_text()
code = Path(fname).read_text()
tree = parser.parse(bytes(code, "utf8"))
# Run the tags queries
query = language.query(query_scm)
captures = query.captures(tree.root_node)
captures = list(captures)
for node, tag in captures:
if tag.startswith("name.definition."):
kind = "def"
elif tag.startswith("name.reference."):
kind = "ref"
else:
continue
result = Tag(
name=node.text.decode("utf-8"), # TODO: encoding?
kind=kind,
fname=fname,
line=node.start_point[0],
)
yield result
def get_ranked_tags(self, chat_fnames, other_fnames): def get_ranked_tags(self, chat_fnames, other_fnames):
defines = defaultdict(set) defines = defaultdict(set)
references = defaultdict(list) references = defaultdict(list)
@ -175,36 +222,21 @@ class RepoMap:
personalization[rel_fname] = 1.0 personalization[rel_fname] = 1.0
chat_rel_fnames.add(rel_fname) chat_rel_fnames.add(rel_fname)
# TODO tags = self.get_tags(fname)
data = [] if tags is None:
continue
for tag in data: for tag in tags:
ident = tag["name"] if tag.kind == "def":
defines[ident].add(rel_fname) defines[tag.name].add(rel_fname)
key = (rel_fname, tag.name)
definitions[key].add(tag)
scope = tag.get("scope") if tag.kind == "ref":
kind = tag.get("kind") references[tag.name].append(rel_fname)
name = tag.get("name")
signature = tag.get("signature")
last = name dump(definitions)
if signature: dump(references)
last += " " + signature
res = [rel_fname]
if scope:
res.append(scope)
res += [kind, last]
key = (rel_fname, ident)
definitions[key].add(tuple(res))
# definitions[key].add((rel_fname,))
# TODO
idents = []
for ident in idents:
# dump("ref", fname, ident)
references[ident].append(rel_fname)
idents = set(defines.keys()).intersection(set(references.keys())) idents = set(defines.keys()).intersection(set(references.keys()))