From 5e0ff7627ebe855f5b24308d24de8da54f6c0214 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Wed, 3 Jul 2024 13:17:04 -0300 Subject: [PATCH] Defer loading of networkx --- aider/repomap.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/aider/repomap.py b/aider/repomap.py index 872424b62..e7b8cf93b 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -8,7 +8,6 @@ from collections import Counter, defaultdict, namedtuple from importlib import resources from pathlib import Path -import networkx as nx from diskcache import Cache from grep_ast import TreeContext, filename_to_lang from pygments.lexers import guess_lexer_for_filename @@ -32,6 +31,7 @@ class RepoMap: cache_missing = False warned_files = set() + nx = None def __init__( self, @@ -58,6 +58,15 @@ class RepoMap: self.token_count = main_model.token_count self.repo_content_prefix = repo_content_prefix + def import_nx(self): + """Import nx as needed, to avoid 250ms on launch""" + + if self.nx: + return + import networkx as nx + + self.nx = nx + def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): if self.max_map_tokens <= 0: return @@ -230,6 +239,8 @@ class RepoMap: ) def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): + self.import_nx() + defines = defaultdict(set) references = defaultdict(list) definitions = defaultdict(set) @@ -295,7 +306,7 @@ class RepoMap: idents = set(defines.keys()).intersection(set(references.keys())) - G = nx.MultiDiGraph() + G = self.nx.MultiDiGraph() for ident in idents: definers = defines[ident] @@ -326,7 +337,7 @@ class RepoMap: pers_args = dict() try: - ranked = nx.pagerank(G, weight="weight", **pers_args) + ranked = self.nx.pagerank(G, weight="weight", **pers_args) except ZeroDivisionError: return []