Defer loading of networkx

This commit is contained in:
Paul Gauthier 2024-07-03 13:17:04 -03:00
parent b454579cd6
commit 5e0ff7627e

View file

@ -8,7 +8,6 @@ from collections import Counter, defaultdict, namedtuple
from importlib import resources from importlib import resources
from pathlib import Path from pathlib import Path
import networkx as nx
from diskcache import Cache from diskcache import Cache
from grep_ast import TreeContext, filename_to_lang from grep_ast import TreeContext, filename_to_lang
from pygments.lexers import guess_lexer_for_filename from pygments.lexers import guess_lexer_for_filename
@ -32,6 +31,7 @@ class RepoMap:
cache_missing = False cache_missing = False
warned_files = set() warned_files = set()
nx = None
def __init__( def __init__(
self, self,
@ -58,6 +58,15 @@ class RepoMap:
self.token_count = main_model.token_count self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix self.repo_content_prefix = repo_content_prefix
def import_nx(self):
"""Import nx as needed, to avoid 250ms on launch"""
if self.nx:
return
import networkx as nx
self.nx = nx
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None):
if self.max_map_tokens <= 0: if self.max_map_tokens <= 0:
return return
@ -230,6 +239,8 @@ class RepoMap:
) )
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents):
self.import_nx()
defines = defaultdict(set) defines = defaultdict(set)
references = defaultdict(list) references = defaultdict(list)
definitions = defaultdict(set) definitions = defaultdict(set)
@ -295,7 +306,7 @@ class RepoMap:
idents = set(defines.keys()).intersection(set(references.keys())) idents = set(defines.keys()).intersection(set(references.keys()))
G = nx.MultiDiGraph() G = self.nx.MultiDiGraph()
for ident in idents: for ident in idents:
definers = defines[ident] definers = defines[ident]
@ -326,7 +337,7 @@ class RepoMap:
pers_args = dict() pers_args = dict()
try: try:
ranked = nx.pagerank(G, weight="weight", **pers_args) ranked = self.nx.pagerank(G, weight="weight", **pers_args)
except ZeroDivisionError: except ZeroDivisionError:
return [] return []