Merge branch 'main' into call-graph

This commit is contained in:
Paul Gauthier 2023-05-26 17:07:17 -07:00
commit 1e1feeaa21
9 changed files with 222 additions and 92 deletions

View file

@ -3,6 +3,7 @@ import json
import sys
import subprocess
import tiktoken
import tempfile
from collections import defaultdict
from aider import prompts, utils
@ -48,14 +49,20 @@ def fname_to_components(fname, with_colon):
class RepoMap:
def __init__(self, use_ctags=True, root=None, main_model="gpt-4"):
ctags_cmd = ["ctags", "--fields=+S", "--extras=-F", "--output-format=json"]
def __init__(self, use_ctags=None, root=None, main_model="gpt-4"):
if not root:
root = os.getcwd()
self.use_ctags = use_ctags
self.tokenizer = tiktoken.encoding_for_model(main_model)
self.root = root
if use_ctags is None:
self.use_ctags = self.check_for_ctags()
else:
self.use_ctags = use_ctags
self.tokenizer = tiktoken.encoding_for_model(main_model)
def get_repo_map(self, chat_files, other_files):
res = self.choose_files_listing(other_files)
if not res:
@ -123,7 +130,7 @@ class RepoMap:
def split_path(self, path):
path = os.path.relpath(path, self.root)
return fname_to_components(path, True)
return [path + ":"]
def run_ctags(self, filename):
# Check if the file is in the cache and if the modification time has not changed
@ -132,7 +139,7 @@ class RepoMap:
if cache_key in TAGS_CACHE and TAGS_CACHE[cache_key]["mtime"] == file_mtime:
return TAGS_CACHE[cache_key]["data"]
cmd = ["ctags", "--fields=+S", "--extras=-F", "--output-format=json", filename]
cmd = self.ctags_cmd + [filename]
output = subprocess.check_output(cmd).decode("utf-8")
output = output.splitlines()
@ -169,6 +176,17 @@ class RepoMap:
return tags
def check_for_ctags(self):
    """Report whether tag extraction works in the current environment.

    Writes a tiny Python file into a temporary directory and hands it to
    ``self.get_tags``. The temporary directory is removed again when the
    probe finishes, whether or not it succeeded.

    Returns:
        bool: True if the probe file was written and ``self.get_tags``
        returned without raising; False if any exception was raised.
    """
    try:
        with tempfile.TemporaryDirectory() as tempdir:
            hello_py = os.path.join(tempdir, "hello.py")
            # Explicit encoding keeps the probe file independent of the
            # platform's default locale encoding.
            with open(hello_py, "w", encoding="utf-8") as f:
                f.write("def hello():\n print('Hello, world!')\n")
            self.get_tags(hello_py)
    except Exception:
        # Best-effort probe: every failure is reported as "not usable"
        # rather than propagated to the caller.
        return False
    return True
def find_py_files(directory):
if not os.path.isdir(directory):
@ -197,6 +215,7 @@ def call_map():
"""
rm = RepoMap()
# res = rm.get_tags_map(fnames)
# print(res)
@ -222,7 +241,7 @@ def call_map():
# dump("ref", fname, ident)
references[ident].append(show_fname)
for ident,fname in defines.items():
for ident, fname in defines.items():
dump(fname, ident)
idents = set(defines.keys()).intersection(set(references.keys()))
@ -256,7 +275,9 @@ def call_map():
ranked = nx.pagerank(G, weight="weight")
# drop low weight edges for plotting
edges_to_remove = [(node1, node2) for node1, node2, data in G.edges(data=True) if data['weight'] < 1]
edges_to_remove = [
(node1, node2) for node1, node2, data in G.edges(data=True) if data["weight"] < 1
]
G.remove_edges_from(edges_to_remove)
# Remove isolated nodes (nodes with no edges)
dump(G.nodes())
@ -272,8 +293,8 @@ def call_map():
dot.node(fname, penwidth=str(pen))
max_w = max(edges.values())
for refs,defs,data in G.edges(data=True):
weight = data['weight']
for refs, defs, data in G.edges(data=True):
weight = data["weight"]
r = random.randint(0, 255)
g = random.randint(0, 255)
@ -286,7 +307,7 @@ def call_map():
print()
print(name)
for ident in sorted(labels[name]):
print('\t', ident)
print("\t", ident)
# print(f"{refs} -{weight}-> {defs}")
top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True)
@ -296,5 +317,6 @@ def call_map():
dot.render("tmp", format="pdf", view=True)
# Invoke call_map() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    call_map()