Defer loading of networkx

This commit is contained in:
Paul Gauthier 2024-07-03 13:17:04 -03:00
parent b454579cd6
commit 5e0ff7627e

View file

@ -8,7 +8,6 @@ from collections import Counter, defaultdict, namedtuple
from importlib import resources
from pathlib import Path
import networkx as nx
from diskcache import Cache
from grep_ast import TreeContext, filename_to_lang
from pygments.lexers import guess_lexer_for_filename
@ -32,6 +31,7 @@ class RepoMap:
cache_missing = False
warned_files = set()
nx = None
def __init__(
self,
@ -58,6 +58,15 @@ class RepoMap:
self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix
def import_nx(self):
"""Import nx as needed, to avoid 250ms on launch"""
if self.nx:
return
import networkx as nx
self.nx = nx
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None):
if self.max_map_tokens <= 0:
return
@ -230,6 +239,8 @@ class RepoMap:
)
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents):
self.import_nx()
defines = defaultdict(set)
references = defaultdict(list)
definitions = defaultdict(set)
@ -295,7 +306,7 @@ class RepoMap:
idents = set(defines.keys()).intersection(set(references.keys()))
G = nx.MultiDiGraph()
G = self.nx.MultiDiGraph()
for ident in idents:
definers = defines[ident]
@ -326,7 +337,7 @@ class RepoMap:
pers_args = dict()
try:
ranked = nx.pagerank(G, weight="weight", **pers_args)
ranked = self.nx.pagerank(G, weight="weight", **pers_args)
except ZeroDivisionError:
return []