refactored into get_requirements

This commit is contained in:
Paul Gauthier 2024-07-09 16:39:54 +01:00
parent 2af9876b76
commit e307be1a9c
3 changed files with 42 additions and 37 deletions

View file

@ -9,5 +9,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
pip install --no-cache-dir /aider && \
rm -rf /aider
WORKDIR /app
ENTRYPOINT ["aider"]

View file

@ -2,8 +2,8 @@
# pip-compile requirements.in --upgrade
#
# Install the cpu-only version of torch
--extra-index-url https://download.pytorch.org/whl/cpu
# To install with the CPU version of torch, because the GPU versions are huge
---extra-index-url https://download.pytorch.org/whl/cpu
configargparse
GitPython

View file

@ -1,43 +1,38 @@
import re
import subprocess
import sys
from tempfile import TemporaryDirectory
from setuptools import find_packages, setup
with open("requirements.txt") as f:
requirements = f.read().splitlines()
# Find the torch requirement and remove it from the list
torch = next((req for req in requirements if req.startswith("torch==")), None)
if torch:
requirements.remove(torch)
else:
torch = "torch==2.2.2" # Fallback if not found in requirements.txt
from aider import __version__
from aider.help_pats import exclude_website_pats
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description)
# long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description)
# Debug: Print discovered packages
packages = find_packages(exclude=["benchmark"]) + ["aider.website"]
print("Discovered packages:", packages)
pytorch_url = None
# Find the torch requirement and replace it with the CPU only version,
# because the GPU versions are huge
def get_requirements():
with open("requirements.txt") as f:
requirements = f.read().splitlines()
requirements = [line for line in requirements if not line.startswith("---extra-index-url")]
torch = next((req for req in requirements if req.startswith("torch==")), None)
if not torch:
return requirements
pytorch_url = None
with TemporaryDirectory(prefix="pytorch_download_") as temp_dir:
cmd = [
sys.executable,
"-m",
"pip",
"download",
"install",
torch,
"--no-deps",
"--dest",
temp_dir,
"--dry-run",
# "--no-cache-dir",
# "--dest",
# temp_dir,
"--index-url",
"https://download.pytorch.org/whl/cpu",
]
@ -45,30 +40,39 @@ with TemporaryDirectory(prefix="pytorch_download_") as temp_dir:
try:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
for line in process.stdout:
print(line, end='') # Print each line of output
url_match = re.search(r"Downloading (https://download\.pytorch\.org/[^\s]+\.whl)", line)
if url_match:
pytorch_url = url_match.group(1)
url_match = re.search(r"Using cached (https://download\.pytorch\.org/[^\s]+\.whl)", line)
if url_match:
pytorch_url = url_match.group(1)
if pytorch_url:
print(f"PyTorch URL: {pytorch_url}")
process.terminate() # Terminate the subprocess
break
process.wait() # Wait for the process to finish
except subprocess.CalledProcessError as e:
print(f"Error running pip download: {e}")
if pytorch_url:
# print(pytorch_url)
# sys.exit()
if pytorch_url:
requirements.remove(torch)
requirements = [f"torch @ {pytorch_url}"] + requirements
else:
print("PyTorch URL not found in the output")
requirements = [torch] + requirements
#print(requirements)
return requirements
#sys.exit()
requirements = get_requirements()
# README
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
long_description = re.sub(r"\n!\[.*\]\(.*\)", "", long_description)
# long_description = re.sub(r"\n- \[.*\]\(.*\)", "", long_description)
# Discover packages, plus the website
packages = find_packages(exclude=["benchmark"]) + ["aider.website"]
print("Discovered packages:", packages)
setup(
name="aider-chat",