Use a thread to import slow modules in the background

This commit is contained in:
Paul Gauthier 2024-07-03 13:55:29 -03:00
parent b3f7f0a250
commit b5cd5f0e23
2 changed files with 17 additions and 12 deletions

View file

@ -2,6 +2,7 @@ import configparser
import os
import re
import sys
import threading
from pathlib import Path
import git
@ -536,6 +537,8 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
return 1
return
threading.Thread(target=load_slow_imports).start()
while True:
try:
coder.run()
@ -545,6 +548,17 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.show_announcements()
def load_slow_imports():
# These imports are deferred in various ways to
# improve startup time.
# This func is called in a thread to load them in the background
# while we wait for the user to type their first message.
import httpx # noqa: F401
import litellm # noqa: F401
import networkx # noqa: F401
import numpy # noqa: F401
if __name__ == "__main__":
status = main()
sys.exit(status)

View file

@ -58,15 +58,6 @@ class RepoMap:
self.token_count = main_model.token_count
self.repo_content_prefix = repo_content_prefix
def import_nx(self):
"""Import nx as needed, to avoid 250ms on launch"""
if self.nx:
return
import networkx as nx
self.nx = nx
def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None):
if self.max_map_tokens <= 0:
return
@ -239,7 +230,7 @@ class RepoMap:
)
def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents):
self.import_nx()
import networkx as nx
defines = defaultdict(set)
references = defaultdict(list)
@ -306,7 +297,7 @@ class RepoMap:
idents = set(defines.keys()).intersection(set(references.keys()))
G = self.nx.MultiDiGraph()
G = nx.MultiDiGraph()
for ident in idents:
definers = defines[ident]
@ -337,7 +328,7 @@ class RepoMap:
pers_args = dict()
try:
ranked = self.nx.pagerank(G, weight="weight", **pers_args)
ranked = nx.pagerank(G, weight="weight", **pers_args)
except ZeroDivisionError:
return []