# DRL_PROJ/pipeline/remote.py
# SSH / rsync helpers for driving a remote training instance.
# (Last commit: Johnny Fernandes bb3dfb92d5 "Clean state", 2026-04-30 01:25:39 +01:00;
#  181 lines, 7.0 KiB, Python)
import shlex
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
class RemoteCommandError(RuntimeError):
    """Raised when a remote SSH or rsync invocation exits non-zero."""
# Connection details for an SSH target (Vast instance)
@dataclass(slots=True)
class SshTarget:
user: str
host: str
port: int
private_key: Path
# user@host string used by SSH and rsync
@property
def destination(self) -> str:
return f"{self.user}@{self.host}"
# Base SSH flags: identity, port, keep-alive, and auto-accept host keys
def ssh_base_command(self) -> list[str]:
return [
"ssh",
"-i", str(self.private_key),
"-p", str(self.port),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=accept-new",
"-o", "ServerAliveInterval=30",
"-o", "ServerAliveCountMax=10",
"-o", "TCPKeepAlive=yes",
self.destination,
]
# SSH flags formatted as a string for rsync's -e option
def rsync_ssh_opts(self) -> str:
return " ".join(self.ssh_base_command()[:-1])
# ── low-level subprocess helpers ─────────────────────────────────────
def _run(command: list[str]) -> None:
    """Run a local subprocess, raising RemoteCommandError on non-zero exit.

    Only stderr is captured (stdout streams through to the terminal); at
    most the last 500 characters of stderr are appended to the error.
    """
    proc = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True)
    if proc.returncode == 0:
        return
    message = f"Command failed ({proc.returncode}): {' '.join(command)}"
    tail = proc.stderr.strip()[-500:] if proc.stderr else ""
    if tail:
        message = f"{message}\n{tail}"
    raise RemoteCommandError(message)
# ── SSH ───────────────────────────────────────────────────────────────
def wait_for_ssh(target: SshTarget, *, timeout_seconds: int, poll_interval_seconds: int) -> None:
    """Poll the target with a no-op SSH command until it accepts connections.

    Retries every *poll_interval_seconds*; raises RemoteCommandError once
    *timeout_seconds* have elapsed without a successful connection.
    """
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        probe = subprocess.run(
            target.ssh_base_command() + ["true"],
            check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )
        if probe.returncode == 0:
            return
        seconds_left = int(deadline - time.time())
        print(f" SSH not ready yet, retrying... ({seconds_left}s remaining)")
        time.sleep(poll_interval_seconds)
    raise RemoteCommandError(
        f"Timed out waiting for SSH on {target.host}:{target.port} after {timeout_seconds}s."
    )
def ssh(target: SshTarget, command: str) -> None:
    """Execute a single command on the remote host via SSH, raising on failure."""
    argv = target.ssh_base_command()
    argv.append(command)
    _run(argv)
def ssh_output(target: SshTarget, command: str) -> str:
    """Execute a remote command and return its captured stdout.

    Raises RemoteCommandError (including the trimmed stderr) when the
    remote command or the SSH transport exits non-zero.
    """
    proc = subprocess.run(
        target.ssh_base_command() + [command],
        check=False, capture_output=True, text=True,
    )
    if proc.returncode == 0:
        return proc.stdout
    raise RemoteCommandError(
        f"Command failed ({proc.returncode}): {command}\n{proc.stderr.strip()}"
    )
# ── detached process management ───────────────────────────────────────
def run_detached(target: SshTarget, command: str, log_path: str) -> int:
    """Start *command* on the remote host under nohup and return its PID.

    The process redirects stdout/stderr to *log_path* so it survives SSH
    drops; the PID is taken from the last line the remote shell echoes.
    """
    wrapped = (
        f"nohup bash -c {shlex.quote(command)} "
        f"> {shlex.quote(log_path)} 2>&1 & echo $!"
    )
    pid_line = ssh_output(target, wrapped).strip().splitlines()[-1]
    return int(pid_line)
def is_process_alive(target: SshTarget, pid: int) -> bool:
    """Return True if the remote PID is alive, False if ``kill -0`` rejects it.

    An SSH-level failure (exit 255) is re-raised so callers can apply
    reconnect logic instead of mistaking a dropped link for a dead process.
    """
    try:
        ssh(target, f"kill -0 {pid}")
    except RemoteCommandError as exc:
        if "Command failed (255)" in str(exc):
            raise  # transport failure, not a verdict on the process
        return False  # kill -0 rejected the PID: process is gone
    return True
def read_remote_file(target: SshTarget, path: str) -> str | None:
    """Return a remote file's stripped contents, or None if absent or empty."""
    contents = ssh_output(target, f"cat {shlex.quote(path)} 2>/dev/null || true")
    stripped = contents.strip()
    return stripped if stripped else None
def tail_log(target: SshTarget, log_path: str) -> subprocess.Popen:
    """Spawn a persistent remote ``tail -f`` for real-time log streaming."""
    tail_cmd = f"tail -f -n +1 {shlex.quote(log_path)}"
    return subprocess.Popen(
        target.ssh_base_command() + [tail_cmd],
        stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True,
    )
# ── rsync ─────────────────────────────────────────────────────────────
def rsync_project(local_root: Path, target: SshTarget, remote_root: str, exclude_file: Path) -> None:
    """Mirror the project directory to the remote host.

    Uses ``--delete`` so files removed locally disappear remotely, while
    honouring the patterns listed in *exclude_file*.
    """
    command = [
        "rsync", "-az", "--info=progress2", "--delete",
        "--exclude-from", str(exclude_file),
        "-e", target.rsync_ssh_opts(),
        f"{local_root}/",
        f"{target.destination}:{remote_root}/",
    ]
    _run(command)
def rsync_file(local_file: Path, target: SshTarget, remote_dir: str) -> None:
    """Upload one local file into a remote directory (resumable via --partial)."""
    flags = ["rsync", "-az", "--partial", "--info=progress2",
             "-e", target.rsync_ssh_opts()]
    destination = f"{target.destination}:{remote_dir}/"
    _run(flags + [str(local_file), destination])
def fetch_directory(target: SshTarget, remote_dir: str, local_dir: Path) -> None:
    """Download a remote directory, resuming partial files, then verify.

    Counts remote regular files up front; raises RemoteCommandError if
    fewer files exist locally after the transfer (interrupted or corrupted
    download).
    """
    local_dir.mkdir(parents=True, exist_ok=True)
    remote_base = remote_dir.rstrip("/")
    # Count remote files first so completeness can be checked afterwards.
    expected = int(
        ssh_output(target, f"find {remote_base} -type f 2>/dev/null | wc -l").strip()
    )
    _run([
        "rsync", "-az", "--partial", "--append-verify", "--info=progress2",
        "-e", target.rsync_ssh_opts(),
        f"{target.destination}:{remote_base}/",
        f"{local_dir}/",
    ])
    actual = sum(1 for entry in local_dir.rglob("*") if entry.is_file())
    if actual < expected:
        # Fewer files than expected: the transfer did not complete.
        raise RemoteCommandError(
            f"Download incomplete: got {actual} files, expected {expected}. "
            f"Remote: {remote_dir} Local: {local_dir}"
        )
    print(f" Verified: {actual} files downloaded successfully")
# ── misc ──────────────────────────────────────────────────────────────
def shell_join(parts: list[str]) -> str:
    """Quote each argument and join with spaces for safe remote execution."""
    quoted = [shlex.quote(part) for part in parts]
    return " ".join(quoted)