From c160a5fa65ac8018d17765061bf08ea1ad1fe6d5 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Mon, 5 Aug 2024 17:39:21 -0300 Subject: [PATCH] show progress of the repo map --- aider/io.py | 3 +-- aider/repomap.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/aider/io.py b/aider/io.py index 643378fdb..a6ad16a07 100644 --- a/aider/io.py +++ b/aider/io.py @@ -2,7 +2,6 @@ import base64 import itertools import os import time -import traceback from collections import defaultdict from datetime import datetime from pathlib import Path @@ -41,7 +40,7 @@ class Spinner: print(f" {self.text} {next(self.io.spinner_chars)}", end="\r", flush=True) def end(self): - traceback.print_stack() + # traceback.print_stack() print(" " * (len(self.text) + 3)) diff --git a/aider/repomap.py b/aider/repomap.py index ccbaf94b0..85858c7fe 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -246,7 +246,9 @@ class RepoMap: line=-1, ) - def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): + def get_ranked_tags( + self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents, progress=None + ): import networkx as nx defines = defaultdict(set) @@ -269,6 +271,9 @@ class RepoMap: self.cache_missing = False for fname in fnames: + if progress: + progress() + if not Path(fname).is_file(): if fname not in self.warned_files: if Path(fname).exists(): @@ -296,6 +301,9 @@ class RepoMap: continue for tag in tags: + if progress: + progress() + if tag.kind == "def": defines[tag.name].add(rel_fname) key = (rel_fname, tag.name) @@ -317,6 +325,9 @@ class RepoMap: G = nx.MultiDiGraph() for ident in idents: + if progress: + progress() + definers = defines[ident] if ident in mentioned_idents: mul = 10 @@ -352,6 +363,9 @@ class RepoMap: # distribute the rank from each source node, across all of its out edges ranked_definitions = defaultdict(float) for src in G.nodes: + if progress: + progress() + src_rank = ranked[src] total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) # dump(src, src_rank, total_weight) @@ -404,10 +418,18 @@ class RepoMap: if not mentioned_idents: mentioned_idents = set() + spin = self.io.spinner("Preparing repo map") + ranked_tags = self.get_ranked_tags( - chat_fnames, other_fnames, mentioned_fnames, mentioned_idents + chat_fnames, + other_fnames, + mentioned_fnames, + mentioned_idents, + progress=spin.step, ) + spin.step() + num_tags = len(ranked_tags) lower_bound = 0 upper_bound = num_tags @@ -430,6 +452,8 @@ class RepoMap: middle = min(max_map_tokens // 50, num_tags) while lower_bound <= upper_bound: + spin.step() + tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames) num_tokens = self.token_count(tree) @@ -448,6 +472,7 @@ class RepoMap: middle = (lower_bound + upper_bound) // 2 + spin.end() return best_tree tree_cache = dict()