#!/usr/bin/env bash
#
# Create a Python virtual environment in ./.venv (relative to the directory
# two levels above this script), install the packages listed in
# requirements.txt into it, create the classifier output directories, and
# finally report whether torch is importable.
#
# Environment variables (all optional):
#   PYTHON_BIN                Interpreter used to create the venv and to probe
#                             for torch/torchvision versions (default: python3).
#   USE_SYSTEM_SITE_PACKAGES  When "1", the venv is created with
#                             --system-site-packages (default: 1).
#   SKIP_TORCH_INSTALL        When "1", top-level torch/torchvision lines are
#                             removed from requirements.txt before installing,
#                             and any detected torch/torchvision versions are
#                             pinned through a pip constraints file (default: 1).
#                             Any other value installs requirements.txt as-is.

set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$ROOT_DIR"

PYTHON_BIN="${PYTHON_BIN:-python3}"
USE_SYSTEM_SITE_PACKAGES="${USE_SYSTEM_SITE_PACKAGES:-1}"
SKIP_TORCH_INSTALL="${SKIP_TORCH_INSTALL:-1}"

# Temp files are tracked globally so the EXIT trap can remove them on every
# exit path, including a failed pip install under `set -e`.
filtered_requirements=""
constraints_file=""

cleanup() {
  [[ -z "$filtered_requirements" ]] || rm -f -- "$filtered_requirements"
  [[ -z "$constraints_file" ]] || rm -f -- "$constraints_file"
}
trap cleanup EXIT

VENV_ARGS=()
if [[ "$USE_SYSTEM_SITE_PACKAGES" == "1" ]]; then
  VENV_ARGS+=(--system-site-packages)
fi

# The ${arr[@]+...} form expands to nothing for an empty array. A plain
# "${VENV_ARGS[@]}" is reported as an unbound variable by `set -u` on
# bash releases older than 4.4.
"$PYTHON_BIN" -m venv ${VENV_ARGS[@]+"${VENV_ARGS[@]}"} .venv

# shellcheck source=/dev/null
source .venv/bin/activate

python -m pip install --upgrade pip setuptools wheel

# Record the torch/torchvision versions that are importable before any
# requirements are installed. Import failures are expected when the packages
# are absent, so stderr is discarded and the variable is simply left empty.
# NOTE(review): the venv is already active here, so a bare PYTHON_BIN such as
# "python3" presumably resolves to the venv interpreter via PATH; system
# packages would then only be visible with --system-site-packages. Confirm
# this is the intended probe target.
TORCH_VERSION="$("$PYTHON_BIN" -c 'import torch; print(torch.__version__)' 2>/dev/null || true)"
TV_VERSION="$("$PYTHON_BIN" -c 'import torchvision; print(torchvision.__version__)' 2>/dev/null || true)"

if [[ "$SKIP_TORCH_INSTALL" == "1" ]]; then
  filtered_requirements="$(mktemp)"
  constraints_file="$(mktemp)"

  # grep exits 1 when no line survives the filter (requirements.txt lists
  # only torch/torchvision). That is not an error; only status >= 2 is.
  grep_status=0
  grep -Ev '^(torch|torchvision)([<>=!~].*)?$' requirements.txt \
    > "$filtered_requirements" || grep_status=$?
  if (( grep_status > 1 )); then
    printf 'error: could not filter requirements.txt (grep exit %d)\n' \
      "$grep_status" >&2
    exit "$grep_status"
  fi

  # Pin the detected torch/torchvision versions so transitive deps
  # (e.g. facenet-pytorch) can't downgrade them.
  if [[ -n "$TORCH_VERSION" ]]; then
    printf 'torch==%s\n' "$TORCH_VERSION" >> "$constraints_file"
  fi
  if [[ -n "$TV_VERSION" ]]; then
    printf 'torchvision==%s\n' "$TV_VERSION" >> "$constraints_file"
  fi

  if [[ -s "$filtered_requirements" ]]; then
    python -m pip install -c "$constraints_file" -r "$filtered_requirements"
  else
    printf 'note: nothing left to install after filtering torch/torchvision\n' >&2
  fi
else
  python -m pip install -r requirements.txt
fi

mkdir -p \
  classifier/outputs/logs \
  classifier/outputs/models \
  classifier/outputs/analysis \
  classifier/outputs/figures \
  classifier/outputs/pipeline

# Final sanity report: print the torch version and CUDA availability as seen
# from inside the venv. A failed import is reported but is not fatal.
python - <<'PY'
try:
    import torch
    print(f"torch={torch.__version__} cuda_available={torch.cuda.is_available()}")
except Exception as exc:
    print(f"torch check failed: {exc}")
PY