Show progress while preparing the repo map

This commit is contained in:
Paul Gauthier 2024-08-05 17:39:21 -03:00
parent 249c85e20f
commit c160a5fa65
2 changed files with 28 additions and 4 deletions

View file

@ -2,7 +2,6 @@ import base64
import itertools import itertools
import os import os
import time import time
import traceback
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -41,7 +40,7 @@ class Spinner:
print(f" {self.text} {next(self.io.spinner_chars)}", end="\r", flush=True) print(f" {self.text} {next(self.io.spinner_chars)}", end="\r", flush=True)
def end(self): def end(self):
traceback.print_stack() # traceback.print_stack()
print(" " * (len(self.text) + 3)) print(" " * (len(self.text) + 3))

View file

@ -246,7 +246,9 @@ class RepoMap:
line=-1, line=-1,
) )
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): def get_ranked_tags(
self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents, progress=None
):
import networkx as nx import networkx as nx
defines = defaultdict(set) defines = defaultdict(set)
@ -269,6 +271,9 @@ class RepoMap:
self.cache_missing = False self.cache_missing = False
for fname in fnames: for fname in fnames:
if progress:
progress()
if not Path(fname).is_file(): if not Path(fname).is_file():
if fname not in self.warned_files: if fname not in self.warned_files:
if Path(fname).exists(): if Path(fname).exists():
@ -296,6 +301,9 @@ class RepoMap:
continue continue
for tag in tags: for tag in tags:
if progress:
progress()
if tag.kind == "def": if tag.kind == "def":
defines[tag.name].add(rel_fname) defines[tag.name].add(rel_fname)
key = (rel_fname, tag.name) key = (rel_fname, tag.name)
@ -317,6 +325,9 @@ class RepoMap:
G = nx.MultiDiGraph() G = nx.MultiDiGraph()
for ident in idents: for ident in idents:
if progress:
progress()
definers = defines[ident] definers = defines[ident]
if ident in mentioned_idents: if ident in mentioned_idents:
mul = 10 mul = 10
@ -352,6 +363,9 @@ class RepoMap:
# distribute the rank from each source node, across all of its out edges # distribute the rank from each source node, across all of its out edges
ranked_definitions = defaultdict(float) ranked_definitions = defaultdict(float)
for src in G.nodes: for src in G.nodes:
if progress:
progress()
src_rank = ranked[src] src_rank = ranked[src]
total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True))
# dump(src, src_rank, total_weight) # dump(src, src_rank, total_weight)
@ -404,10 +418,18 @@ class RepoMap:
if not mentioned_idents: if not mentioned_idents:
mentioned_idents = set() mentioned_idents = set()
spin = self.io.spinner("Preparing repo map")
ranked_tags = self.get_ranked_tags( ranked_tags = self.get_ranked_tags(
chat_fnames, other_fnames, mentioned_fnames, mentioned_idents chat_fnames,
other_fnames,
mentioned_fnames,
mentioned_idents,
progress=spin.step,
) )
spin.step()
num_tags = len(ranked_tags) num_tags = len(ranked_tags)
lower_bound = 0 lower_bound = 0
upper_bound = num_tags upper_bound = num_tags
@ -430,6 +452,8 @@ class RepoMap:
middle = min(max_map_tokens // 50, num_tags) middle = min(max_map_tokens // 50, num_tags)
while lower_bound <= upper_bound: while lower_bound <= upper_bound:
spin.step()
tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames) tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames)
num_tokens = self.token_count(tree) num_tokens = self.token_count(tree)
@ -448,6 +472,7 @@ class RepoMap:
middle = (lower_bound + upper_bound) // 2 middle = (lower_bound + upper_bound) // 2
spin.end()
return best_tree return best_tree
tree_cache = dict() tree_cache = dict()