show progress of the repo map

This commit is contained in:
Paul Gauthier 2024-08-05 17:39:21 -03:00
parent 249c85e20f
commit c160a5fa65
2 changed files with 28 additions and 4 deletions

View file

@ -2,7 +2,6 @@ import base64
import itertools
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime
from pathlib import Path
@ -41,7 +40,7 @@ class Spinner:
print(f" {self.text} {next(self.io.spinner_chars)}", end="\r", flush=True)
def end(self):
traceback.print_stack()
# traceback.print_stack()
print(" " * (len(self.text) + 3))

View file

@ -246,7 +246,9 @@ class RepoMap:
line=-1,
)
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents):
def get_ranked_tags(
self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents, progress=None
):
import networkx as nx
defines = defaultdict(set)
@ -269,6 +271,9 @@ class RepoMap:
self.cache_missing = False
for fname in fnames:
if progress:
progress()
if not Path(fname).is_file():
if fname not in self.warned_files:
if Path(fname).exists():
@ -296,6 +301,9 @@ class RepoMap:
continue
for tag in tags:
if progress:
progress()
if tag.kind == "def":
defines[tag.name].add(rel_fname)
key = (rel_fname, tag.name)
@ -317,6 +325,9 @@ class RepoMap:
G = nx.MultiDiGraph()
for ident in idents:
if progress:
progress()
definers = defines[ident]
if ident in mentioned_idents:
mul = 10
@ -352,6 +363,9 @@ class RepoMap:
# distribute the rank from each source node, across all of its out edges
ranked_definitions = defaultdict(float)
for src in G.nodes:
if progress:
progress()
src_rank = ranked[src]
total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True))
# dump(src, src_rank, total_weight)
@ -404,10 +418,18 @@ class RepoMap:
if not mentioned_idents:
mentioned_idents = set()
spin = self.io.spinner("Preparing repo map")
ranked_tags = self.get_ranked_tags(
chat_fnames, other_fnames, mentioned_fnames, mentioned_idents
chat_fnames,
other_fnames,
mentioned_fnames,
mentioned_idents,
progress=spin.step,
)
spin.step()
num_tags = len(ranked_tags)
lower_bound = 0
upper_bound = num_tags
@ -430,6 +452,8 @@ class RepoMap:
middle = min(max_map_tokens // 50, num_tags)
while lower_bound <= upper_bound:
spin.step()
tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames)
num_tokens = self.token_count(tree)
@ -448,6 +472,7 @@ class RepoMap:
middle = (lower_bound + upper_bound) // 2
spin.end()
return best_tree
tree_cache = dict()