From e307be1a9c9b5e5fdfa148a79370ecae56496bec Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Tue, 9 Jul 2024 16:39:54 +0100 Subject: [PATCH] refactored into get_requirements --- docker/Dockerfile | 1 + requirements.in | 4 +-- setup.py | 74 +++++++++++++++++++++++++---------------------- 3 files changed, 42 insertions(+), 37 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 10ae6c530..c3299903e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -9,5 +9,6 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ pip install --no-cache-dir /aider && \ rm -rf /aider + WORKDIR /app ENTRYPOINT ["aider"] diff --git a/requirements.in b/requirements.in index 4c2bde243..21e6ef197 100644 --- a/requirements.in +++ b/requirements.in @@ -2,8 +2,8 @@ # pip-compile requirements.in --upgrade # -# Install the cpu-only version of torch ---extra-index-url https://download.pytorch.org/whl/cpu +# To install with the CPU version of torch, because the GPU versions are huge +---extra-index-url https://download.pytorch.org/whl/cpu configargparse GitPython diff --git a/setup.py b/setup.py index 6aec90c89..b629300c9 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,38 @@ import re import subprocess import sys -from tempfile import TemporaryDirectory from setuptools import find_packages, setup -with open("requirements.txt") as f: - requirements = f.read().splitlines() - -# Find the torch requirement and remove it from the list -torch = next((req for req in requirements if req.startswith("torch==")), None) -if torch: - requirements.remove(torch) -else: - torch = "torch==2.2.2" # Fallback if not found in requirements.txt - from aider import __version__ from aider.help_pats import exclude_website_pats -with open("README.md", "r", encoding="utf-8") as f: - long_description = f.read() - long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description) - # long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description) -# Debug: Print discovered packages -packages = find_packages(exclude=["benchmark"]) + ["aider.website"] -print("Discovered packages:", packages) -pytorch_url = None +# Find the torch requirement and replace it with the CPU only version, +# because the GPU versions are huge +def get_requirements(): + with open("requirements.txt") as f: + requirements = f.read().splitlines() + + requirements = [line for line in requirements if not line.startswith("---extra-index-url")] + + torch = next((req for req in requirements if req.startswith("torch==")), None) + if not torch: + return requirements + + pytorch_url = None -with TemporaryDirectory(prefix="pytorch_download_") as temp_dir: cmd = [ sys.executable, "-m", "pip", - "download", + "install", torch, "--no-deps", - "--dest", - temp_dir, + "--dry-run", + # "--no-cache-dir", + # "--dest", + # temp_dir, "--index-url", "https://download.pytorch.org/whl/cpu", ] @@ -45,30 +40,39 @@ with TemporaryDirectory(prefix="pytorch_download_") as temp_dir: try: process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) for line in process.stdout: - print(line, end='') # Print each line of output url_match = re.search(r"Downloading (https://download\.pytorch\.org/[^\s]+\.whl)", line) if url_match: pytorch_url = url_match.group(1) - url_match = re.search(r"Using cached (https://download\.pytorch\.org/[^\s]+\.whl)", line) - if url_match: - pytorch_url = url_match.group(1) + if pytorch_url: - print(f"PyTorch URL: {pytorch_url}") process.terminate() # Terminate the subprocess break + process.wait() # Wait for the process to finish except subprocess.CalledProcessError as e: print(f"Error running pip download: {e}") -if pytorch_url: - requirements = [f"torch @ {pytorch_url}"] + requirements -else: - print("PyTorch URL not found in the output") - requirements = [torch] + requirements + # print(pytorch_url) + # sys.exit() -#print(requirements) + if pytorch_url: + requirements.remove(torch) + requirements = [f"torch @ {pytorch_url}"] + requirements -#sys.exit() + return requirements + + +requirements = get_requirements() + +# README +with open("README.md", "r", encoding="utf-8") as f: + long_description = f.read() + long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description) + # long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description) + +# Discover packages, plus the website +packages = find_packages(exclude=["benchmark"]) + ["aider.website"] +print("Discovered packages:", packages) setup( name="aider-chat",