refactored into get_requirements

This commit is contained in:
Paul Gauthier 2024-07-09 16:39:54 +01:00
parent 2af9876b76
commit e307be1a9c
3 changed files with 42 additions and 37 deletions

View file

@ -9,5 +9,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \ rm -rf /var/lib/apt/lists/* && \
pip install --no-cache-dir /aider && \ pip install --no-cache-dir /aider && \
rm -rf /aider rm -rf /aider
WORKDIR /app WORKDIR /app
ENTRYPOINT ["aider"] ENTRYPOINT ["aider"]

View file

@ -2,8 +2,8 @@
# pip-compile requirements.in --upgrade # pip-compile requirements.in --upgrade
# #
# Install the cpu-only version of torch # To install with the CPU version of torch, because the GPU versions are huge
--extra-index-url https://download.pytorch.org/whl/cpu ---extra-index-url https://download.pytorch.org/whl/cpu
configargparse configargparse
GitPython GitPython

View file

@ -1,43 +1,38 @@
import re import re
import subprocess import subprocess
import sys import sys
from tempfile import TemporaryDirectory
from setuptools import find_packages, setup from setuptools import find_packages, setup
with open("requirements.txt") as f:
requirements = f.read().splitlines()
# Find the torch requirement and remove it from the list
torch = next((req for req in requirements if req.startswith("torch==")), None)
if torch:
requirements.remove(torch)
else:
torch = "torch==2.2.2" # Fallback if not found in requirements.txt
from aider import __version__ from aider import __version__
from aider.help_pats import exclude_website_pats from aider.help_pats import exclude_website_pats
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description)
# long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description)
# Debug: Print discovered packages # Find the torch requirement and replace it with the CPU only version,
packages = find_packages(exclude=["benchmark"]) + ["aider.website"] # because the GPU versions are huge
print("Discovered packages:", packages) def get_requirements():
pytorch_url = None with open("requirements.txt") as f:
requirements = f.read().splitlines()
requirements = [line for line in requirements if not line.startswith("---extra-index-url")]
torch = next((req for req in requirements if req.startswith("torch==")), None)
if not torch:
return requirements
pytorch_url = None
with TemporaryDirectory(prefix="pytorch_download_") as temp_dir:
cmd = [ cmd = [
sys.executable, sys.executable,
"-m", "-m",
"pip", "pip",
"download", "install",
torch, torch,
"--no-deps", "--no-deps",
"--dest", "--dry-run",
temp_dir, # "--no-cache-dir",
# "--dest",
# temp_dir,
"--index-url", "--index-url",
"https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/cpu",
] ]
@ -45,30 +40,39 @@ with TemporaryDirectory(prefix="pytorch_download_") as temp_dir:
try: try:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
for line in process.stdout: for line in process.stdout:
print(line, end='') # Print each line of output
url_match = re.search(r"Downloading (https://download\.pytorch\.org/[^\s]+\.whl)", line) url_match = re.search(r"Downloading (https://download\.pytorch\.org/[^\s]+\.whl)", line)
if url_match: if url_match:
pytorch_url = url_match.group(1) pytorch_url = url_match.group(1)
url_match = re.search(r"Using cached (https://download\.pytorch\.org/[^\s]+\.whl)", line)
if url_match:
pytorch_url = url_match.group(1)
if pytorch_url: if pytorch_url:
print(f"PyTorch URL: {pytorch_url}")
process.terminate() # Terminate the subprocess process.terminate() # Terminate the subprocess
break break
process.wait() # Wait for the process to finish process.wait() # Wait for the process to finish
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Error running pip download: {e}") print(f"Error running pip download: {e}")
if pytorch_url: # print(pytorch_url)
requirements = [f"torch @ {pytorch_url}"] + requirements # sys.exit()
else:
print("PyTorch URL not found in the output")
requirements = [torch] + requirements
#print(requirements) if pytorch_url:
requirements.remove(torch)
requirements = [f"torch @ {pytorch_url}"] + requirements
#sys.exit() return requirements
requirements = get_requirements()
# README
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description)
# long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description)
# Discover packages, plus the website
packages = find_packages(exclude=["benchmark"]) + ["aider.website"]
print("Discovered packages:", packages)
setup( setup(
name="aider-chat", name="aider-chat",