From b5cd5f0e237dc1439f7d13be91f6b31cc612e37e Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Wed, 3 Jul 2024 13:55:29 -0300 Subject: [PATCH] Use a thread to import slow modules in the background --- aider/main.py | 14 ++++++++++++++ aider/repomap.py | 15 +++------------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/aider/main.py b/aider/main.py index a7bdf4932..25e417e74 100644 --- a/aider/main.py +++ b/aider/main.py @@ -2,6 +2,7 @@ import configparser import os import re import sys +import threading from pathlib import Path import git @@ -536,6 +537,8 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F return 1 return + threading.Thread(target=load_slow_imports).start() + while True: try: coder.run() @@ -545,6 +548,17 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F coder.show_announcements() +def load_slow_imports(): + # These imports are deferred in various ways to + # improve startup time. + # This func is called in a thread to load them in the background + # while we wait for the user to type their first message. + import httpx # noqa: F401 + import litellm # noqa: F401 + import networkx # noqa: F401 + import numpy # noqa: F401 + + if __name__ == "__main__": status = main() sys.exit(status) diff --git a/aider/repomap.py b/aider/repomap.py index e7b8cf93b..554286ad9 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -58,15 +58,6 @@ class RepoMap: self.token_count = main_model.token_count self.repo_content_prefix = repo_content_prefix - def import_nx(self): - """Import nx as needed, to avoid 250ms on launch""" - - if self.nx: - return - import networkx as nx - - self.nx = nx - def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): if self.max_map_tokens <= 0: return @@ -239,7 +230,7 @@ class RepoMap: ) def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): - self.import_nx() + import networkx as nx defines = defaultdict(set) references = defaultdict(list) @@ -306,7 +297,7 @@ class RepoMap: idents = set(defines.keys()).intersection(set(references.keys())) - G = self.nx.MultiDiGraph() + G = nx.MultiDiGraph() for ident in idents: definers = defines[ident] @@ -337,7 +328,7 @@ class RepoMap: pers_args = dict() try: - ranked = self.nx.pagerank(G, weight="weight", **pers_args) + ranked = nx.pagerank(G, weight="weight", **pers_args) except ZeroDivisionError: return []