import colorsys
import json
import os
import random
import subprocess
import sys
import tempfile
from collections import Counter, defaultdict

# import shelve
import networkx as nx
import tiktoken
from pygments.lexers import guess_lexer_for_filename
from pygments.token import Token
from pygments.util import ClassNotFound

from aider import prompts

from .dump import dump  # noqa: F401


def to_tree(tags):
    # guard against the empty list: the binary search in get_ranked_tags_map()
    # can call to_tree([]), which would crash on tags[0] below
    if not tags:
        return ""

    tags = sorted(tags)

    output = ""
    last = [None] * len(tags[0])
    tab = "\t"
    for tag in tags:
        tag = list(tag)

        # find how many leading components this tag shares with the previous one
        for i in range(len(last)):
            if last[i] != tag[i]:
                break

        num_common = i
        indent = tab * num_common
        rest = tag[num_common:]
        for item in rest:
            output += indent + item + "\n"
            indent += tab
        last = tag

    return output
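

# Worked example of to_tree(), added for illustration (not in the original
# source): to_tree([("aider/", "io.py:"), ("aider/", "main.py:")]) prints the
# shared "aider/" component once, then indents each differing component one
# tab deeper (tabs shown here as spaces):
#
#     aider/
#         io.py:
#         main.py: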


def fname_to_components(fname, with_colon):
    path_components = fname.split(os.sep)
    res = [pc + os.sep for pc in path_components[:-1]]
    if with_colon:
        res.append(path_components[-1] + ":")
    else:
        res.append(path_components[-1])
    return res
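

# Example (illustration only, assuming a posix os.sep):
#     fname_to_components("aider/repomap.py", with_colon=True)
#     -> ["aider/", "repomap.py:"]
# The trailing colon distinguishes the file name from directory components.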


class RepoMap:
    ctags_cmd = ["ctags", "--fields=+S", "--extras=-F", "--output-format=json"]
    IDENT_CACHE_FILE = ".aider.ident.cache"
    TAGS_CACHE_FILE = ".aider.tags.cache"

    # token budget for the repo map, a small slice of gpt-4's 8k context window
    max_map_tokens = 512
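
    # Note (added; flag semantics per the universal-ctags docs, which this
    # command assumes): --output-format=json emits one JSON tag object per
    # line and is a universal-ctags feature, and --fields=+S includes each
    # tag's signature when the parser can extract one.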

    def __init__(self, use_ctags=None, root=None, main_model="gpt-4", io=None):
        self.io = io

        if not root:
            root = os.getcwd()
        self.root = root

        self.load_ident_cache()
        self.load_tags_cache()

        if use_ctags is None:
            self.use_ctags = self.check_for_ctags()
        else:
            self.use_ctags = use_ctags

        self.tokenizer = tiktoken.encoding_for_model(main_model)

    def get_repo_map(self, chat_files, other_files):
        res = self.choose_files_listing(chat_files, other_files)
        if not res:
            return

        files_listing, ctags_msg = res

        if chat_files:
            other = "other "
        else:
            other = ""

        repo_content = prompts.repo_content_prefix.format(
            other=other,
            ctags_msg=ctags_msg,
        )
        repo_content += files_listing

        return repo_content

    def choose_files_listing(self, chat_files, other_files):
        if not other_files:
            return

        if self.use_ctags:
            files_listing = self.get_ranked_tags_map(chat_files, other_files)
            num_tokens = self.token_count(files_listing)
            if self.io:
                self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens")
            ctags_msg = " with selected ctags info"
            return files_listing, ctags_msg

        files_listing = self.get_simple_files_map(other_files)
        ctags_msg = ""
        num_tokens = self.token_count(files_listing)
        if self.io:
            self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens")
        if num_tokens < self.max_map_tokens:
            return files_listing, ctags_msg
        # implicitly returns None when even the simple map exceeds the budget

    def get_simple_files_map(self, other_files):
        fnames = []
        for fname in other_files:
            fname = self.get_rel_fname(fname)
            fname = fname_to_components(fname, False)
            fnames.append(fname)

        return to_tree(fnames)

    def token_count(self, string):
        return len(self.tokenizer.encode(string))

    def get_rel_fname(self, fname):
        return os.path.relpath(fname, self.root)

    def split_path(self, path):
        path = os.path.relpath(path, self.root)
        return [path + ":"]

    def run_ctags(self, filename):
        # Check if the file is in the cache and if the modification time has not changed
        file_mtime = os.path.getmtime(filename)
        cache_key = filename
        if cache_key in self.TAGS_CACHE and self.TAGS_CACHE[cache_key]["mtime"] == file_mtime:
            return self.TAGS_CACHE[cache_key]["data"]

        cmd = self.ctags_cmd + [filename]
        output = subprocess.check_output(cmd).decode("utf-8")
        output = output.splitlines()

        data = [json.loads(line) for line in output]

        # Update the cache
        self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data}
        self.save_tags_cache()
        return data
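
    # Illustrative output (assumes universal-ctags): run_ctags() parses one
    # JSON object per line, along the lines of
    #     {"_type": "tag", "name": "hello", "path": "hello.py",
    #      "pattern": "/^def hello():$/", "kind": "function"}
    # plus a "signature" field where --fields=+S applies.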

    def check_for_ctags(self):
        try:
            with tempfile.TemporaryDirectory() as tempdir:
                hello_py = os.path.join(tempdir, "hello.py")
                with open(hello_py, "w") as f:
                    f.write("def hello():\n print('Hello, world!')\n")
                self.run_ctags(hello_py)
        except Exception:
            return False
        return True

    def load_tags_cache(self):
        self.TAGS_CACHE = dict()  # shelve.open(self.TAGS_CACHE_FILE)

    def save_tags_cache(self):
        # self.TAGS_CACHE.sync()
        pass

    def load_ident_cache(self):
        self.IDENT_CACHE = dict()  # shelve.open(self.IDENT_CACHE_FILE)

    def save_ident_cache(self):
        # self.IDENT_CACHE.sync()
        pass

    def get_name_identifiers(self, fname, uniq=True):
        file_mtime = os.path.getmtime(fname)
        cache_key = fname
        if cache_key in self.IDENT_CACHE and self.IDENT_CACHE[cache_key]["mtime"] == file_mtime:
            idents = self.IDENT_CACHE[cache_key]["data"]
        else:
            idents = self.get_name_identifiers_uncached(fname)
            self.IDENT_CACHE[cache_key] = {"mtime": file_mtime, "data": idents}
            self.save_ident_cache()

        if uniq:
            idents = set(idents)
        return idents

    def get_name_identifiers_uncached(self, fname):
        try:
            with open(fname, "r") as f:
                content = f.read()
        except UnicodeDecodeError:
            return list()

        try:
            lexer = guess_lexer_for_filename(fname, content)
        except ClassNotFound:
            return list()

        # lexer.get_tokens_unprocessed() returns (char position in file, token type, token string)
        tokens = list(lexer.get_tokens_unprocessed(content))
        res = [token[2] for token in tokens if token[1] in Token.Name]
        return res
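
    # For reference (added note): pygments token types form a hierarchy, so
    # the `token[1] in Token.Name` test above matches any identifier-like
    # token: Token.Name.Function, Token.Name.Class, and so on.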

    def get_ranked_tags(self, chat_fnames, other_fnames):
        defines = defaultdict(set)
        references = defaultdict(list)
        definitions = defaultdict(set)

        personalization = dict()

        fnames = set(chat_fnames).union(set(other_fnames))
        chat_rel_fnames = set()

        for fname in sorted(fnames):
            # dump(fname)
            rel_fname = os.path.relpath(fname, self.root)

            # bias PageRank toward the files already in the chat
            if fname in chat_fnames:
                personalization[rel_fname] = 1.0
                chat_rel_fnames.add(rel_fname)

            data = self.run_ctags(fname)

            for tag in data:
                ident = tag["name"]
                defines[ident].add(rel_fname)

                scope = tag.get("scope")
                kind = tag.get("kind")
                name = tag.get("name")
                signature = tag.get("signature")

                last = name
                if signature:
                    last += " " + signature

                res = [rel_fname]
                if scope:
                    res.append(scope)
                res += [kind, last]

                key = (rel_fname, ident)
                definitions[key].add(tuple(res))
                # definitions[key].add((rel_fname,))

            idents = self.get_name_identifiers(fname, uniq=False)
            for ident in idents:
                # dump("ref", fname, ident)
                references[ident].append(rel_fname)

        # only rank identifiers that are both defined and referenced somewhere
        idents = set(defines.keys()).intersection(set(references.keys()))

        G = nx.MultiDiGraph()

        # one edge per (referencer, definer) pair, weighted by reference count
        for ident in idents:
            definers = defines[ident]
            for referencer, num_refs in Counter(references[ident]).items():
                for definer in definers:
                    if referencer == definer:
                        continue
                    G.add_edge(referencer, definer, weight=num_refs, ident=ident)

        if personalization:
            pers_args = dict(personalization=personalization, dangling=personalization)
        else:
            pers_args = dict()

        ranked = nx.pagerank(G, weight="weight", **pers_args)

        # debug: print the PageRank of each node
        # top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True)
        # for rank, node in top_rank:
        #     print(f"{rank:.03f} {node}")

        # distribute the rank from each source node, across all of its out edges
        ranked_definitions = defaultdict(float)
        for src in G.nodes:
            src_rank = ranked[src]
            total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True))
            # dump(src, src_rank, total_weight)
            for _src, dst, data in G.out_edges(src, data=True):
                data["rank"] = src_rank * data["weight"] / total_weight
                ident = data["ident"]
                ranked_definitions[(dst, ident)] += data["rank"]

        ranked_tags = []
        ranked_definitions = sorted(ranked_definitions.items(), reverse=True, key=lambda x: x[1])
        for (fname, ident), rank in ranked_definitions:
            # print(f"{rank:.03f} {fname} {ident}")
            if fname in chat_rel_fnames:
                continue
            ranked_tags += list(definitions.get((fname, ident), []))

        return ranked_tags
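
    # Sketch of the ranking idea (illustrative file names): if a chat file
    # app.py references "helper", which is defined in util.py, the graph gets
    # an edge app.py -> util.py.  PageRank personalized toward the chat files
    # pushes rank onto util.py, so (util.py, "helper") sorts near the front of
    # ranked_tags, while definitions in the chat files themselves are skipped.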

    def get_ranked_tags_map(self, chat_fnames, other_fnames=None):
        if not other_fnames:
            other_fnames = list()

        ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames)
        num_tags = len(ranked_tags)

        # binary search for the longest prefix of ranked_tags whose rendered
        # tree still fits within max_map_tokens
        lower_bound = 0
        upper_bound = num_tags
        best_tree = None

        while lower_bound <= upper_bound:
            middle = (lower_bound + upper_bound) // 2
            tree = to_tree(ranked_tags[:middle])
            num_tokens = self.token_count(tree)
            # dump(middle, num_tokens)

            if num_tokens < self.max_map_tokens:
                best_tree = tree
                lower_bound = middle + 1
            else:
                upper_bound = middle - 1

        return best_tree


def find_py_files(directory):
    if not os.path.isdir(directory):
        return [directory]

    py_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith(".py"):
                py_files.append(os.path.join(root, file))
    return py_files


def get_random_color():
    hue = random.random()
    r, g, b = [int(x * 255) for x in colorsys.hsv_to_rgb(hue, 1, 0.75)]
    res = f"#{r:02x}{g:02x}{b:02x}"
    return res


# Command-line entry point: print a ranked repo map for the given files/dirs
if __name__ == "__main__":
    fnames = []
    for dname in sys.argv[1:]:
        fnames += find_py_files(dname)

    fnames = sorted(fnames)

    root = os.path.commonpath(fnames)

    rm = RepoMap(root=root)
    repo_map = rm.get_ranked_tags_map(fnames)

    dump(len(repo_map))
    print(repo_map)