This commit is contained in:
Paul Gauthier 2023-08-20 15:11:12 -07:00
parent 189446b04e
commit 057aa0a736

View file

@ -8,6 +8,7 @@ from pathlib import Path
import networkx as nx import networkx as nx
import tiktoken import tiktoken
from diskcache import Cache from diskcache import Cache
from grep_ast import TreeContext
from tqdm import tqdm from tqdm import tqdm
from tree_sitter_languages import get_language, get_parser from tree_sitter_languages import get_language, get_parser
@ -16,7 +17,7 @@ from aider.parsers import filename_to_lang
from .dump import dump # noqa: F402 from .dump import dump # noqa: F402
Tag = namedtuple("Tag", "fname line name kind".split()) Tag = namedtuple("Tag", "fname rel_fname line name kind".split())
def to_tree(tags): def to_tree(tags):
@ -25,26 +26,41 @@ def to_tree(tags):
tags = sorted(tags) tags = sorted(tags)
cur_fname = None
context = None
output = "" output = ""
last = [None] * len(tags[0])
tab = "\t"
for tag in tags: for tag in tags:
tag = list(tag) if type(tag) is tuple:
this_fname = tag[0]
else:
this_fname = tag.rel_fname
for i in range(len(last) + 1): if this_fname != cur_fname:
if i == len(last): if context:
break context.add_context()
if last[i] != tag[i]: output += cur_fname + ":\n"
break output += context.format()
context = None
elif cur_fname:
output += cur_fname + "\n"
num_common = i if type(tag) is not tuple:
context = TreeContext(
tag.rel_fname,
Path(tag.fname).read_text(), # TODO: encoding
color=False,
line_number=False,
child_context=False,
last_line=False,
margin=0,
mark_lois=False,
header_pad=1,
loi_pad=0,
)
cur_fname = this_fname
indent = tab * num_common if context:
rest = tag[num_common:] context.add_lines_of_interest([tag.line])
for item in rest:
output += indent + str(item) + "\n"
indent += tab
last = tag
return output return output
@ -151,7 +167,7 @@ class RepoMap:
except FileNotFoundError: except FileNotFoundError:
self.io.tool_error(f"File not found error: {fname}") self.io.tool_error(f"File not found error: {fname}")
def get_tags(self, fname): def get_tags(self, fname, rel_fname):
lang = filename_to_lang(fname) lang = filename_to_lang(fname)
if not lang: if not lang:
return return
@ -168,8 +184,8 @@ class RepoMap:
return return
query_scm = query_scm.read_text() query_scm = query_scm.read_text()
code = Path(fname).read_text() code = Path(fname).read_text() # TODO: encoding
tree = parser.parse(bytes(code, "utf8")) tree = parser.parse(bytes(code, "utf-8"))
# Run the tags queries # Run the tags queries
query = language.query(query_scm) query = language.query(query_scm)
@ -186,9 +202,10 @@ class RepoMap:
continue continue
result = Tag( result = Tag(
name=node.text.decode("utf-8"), # TODO: encoding?
kind=kind,
fname=fname, fname=fname,
rel_fname=rel_fname,
name=node.text.decode("utf-8"),
kind=kind,
line=node.start_point[0], line=node.start_point[0],
) )
@ -222,7 +239,7 @@ class RepoMap:
personalization[rel_fname] = 1.0 personalization[rel_fname] = 1.0
chat_rel_fnames.add(rel_fname) chat_rel_fnames.add(rel_fname)
tags = self.get_tags(fname) tags = self.get_tags(fname, rel_fname)
if tags is None: if tags is None:
continue continue
@ -235,8 +252,9 @@ class RepoMap:
if tag.kind == "ref": if tag.kind == "ref":
references[tag.name].append(rel_fname) references[tag.name].append(rel_fname)
dump(definitions) ##
dump(references) # dump(definitions)
# dump(references)
idents = set(defines.keys()).intersection(set(references.keys())) idents = set(defines.keys()).intersection(set(references.keys()))