This commit is contained in:
Paul Gauthier 2023-08-20 15:11:12 -07:00
parent 189446b04e
commit 057aa0a736

View file

@ -8,6 +8,7 @@ from pathlib import Path
import networkx as nx
import tiktoken
from diskcache import Cache
from grep_ast import TreeContext
from tqdm import tqdm
from tree_sitter_languages import get_language, get_parser
@ -16,7 +17,7 @@ from aider.parsers import filename_to_lang
from .dump import dump # noqa: F402
Tag = namedtuple("Tag", "fname line name kind".split())
Tag = namedtuple("Tag", "fname rel_fname line name kind".split())
def to_tree(tags):
@ -25,26 +26,41 @@ def to_tree(tags):
tags = sorted(tags)
cur_fname = None
context = None
output = ""
last = [None] * len(tags[0])
tab = "\t"
for tag in tags:
tag = list(tag)
if type(tag) is tuple:
this_fname = tag[0]
else:
this_fname = tag.rel_fname
for i in range(len(last) + 1):
if i == len(last):
break
if last[i] != tag[i]:
break
if this_fname != cur_fname:
if context:
context.add_context()
output += cur_fname + ":\n"
output += context.format()
context = None
elif cur_fname:
output += cur_fname + "\n"
num_common = i
if type(tag) is not tuple:
context = TreeContext(
tag.rel_fname,
Path(tag.fname).read_text(), # TODO: encoding
color=False,
line_number=False,
child_context=False,
last_line=False,
margin=0,
mark_lois=False,
header_pad=1,
loi_pad=0,
)
cur_fname = this_fname
indent = tab * num_common
rest = tag[num_common:]
for item in rest:
output += indent + str(item) + "\n"
indent += tab
last = tag
if context:
context.add_lines_of_interest([tag.line])
return output
@ -151,7 +167,7 @@ class RepoMap:
except FileNotFoundError:
self.io.tool_error(f"File not found error: {fname}")
def get_tags(self, fname):
def get_tags(self, fname, rel_fname):
lang = filename_to_lang(fname)
if not lang:
return
@ -168,8 +184,8 @@ class RepoMap:
return
query_scm = query_scm.read_text()
code = Path(fname).read_text()
tree = parser.parse(bytes(code, "utf8"))
code = Path(fname).read_text() # TODO: encoding
tree = parser.parse(bytes(code, "utf-8"))
# Run the tags queries
query = language.query(query_scm)
@ -186,9 +202,10 @@ class RepoMap:
continue
result = Tag(
name=node.text.decode("utf-8"), # TODO: encoding?
kind=kind,
fname=fname,
rel_fname=rel_fname,
name=node.text.decode("utf-8"),
kind=kind,
line=node.start_point[0],
)
@ -222,7 +239,7 @@ class RepoMap:
personalization[rel_fname] = 1.0
chat_rel_fnames.add(rel_fname)
tags = self.get_tags(fname)
tags = self.get_tags(fname, rel_fname)
if tags is None:
continue
@ -235,8 +252,9 @@ class RepoMap:
if tag.kind == "ref":
references[tag.name].append(rel_fname)
dump(definitions)
dump(references)
##
# dump(definitions)
# dump(references)
idents = set(defines.keys()).intersection(set(references.keys()))