181 lines
7.0 KiB
Python
181 lines
7.0 KiB
Python
import shlex
|
|
import subprocess
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
|
|
class RemoteCommandError(RuntimeError):
    """Signal that a remote SSH/rsync operation failed or could not be confirmed.

    Raised by this module's helpers for non-zero exits of ssh/rsync commands,
    for timing out while waiting for SSH, and for incomplete downloads.
    """
|
|
|
|
|
|
# Connection details for an SSH target (Vast instance)
|
|
@dataclass(slots=True)
|
|
class SshTarget:
|
|
user: str
|
|
host: str
|
|
port: int
|
|
private_key: Path
|
|
|
|
# user@host string used by SSH and rsync
|
|
@property
|
|
def destination(self) -> str:
|
|
return f"{self.user}@{self.host}"
|
|
|
|
# Base SSH flags: identity, port, keep-alive, and auto-accept host keys
|
|
def ssh_base_command(self) -> list[str]:
|
|
return [
|
|
"ssh",
|
|
"-i", str(self.private_key),
|
|
"-p", str(self.port),
|
|
"-o", "BatchMode=yes",
|
|
"-o", "StrictHostKeyChecking=accept-new",
|
|
"-o", "ServerAliveInterval=30",
|
|
"-o", "ServerAliveCountMax=10",
|
|
"-o", "TCPKeepAlive=yes",
|
|
self.destination,
|
|
]
|
|
|
|
# SSH flags formatted as a string for rsync's -e option
|
|
def rsync_ssh_opts(self) -> str:
|
|
return " ".join(self.ssh_base_command()[:-1])
|
|
|
|
|
|
# ── low-level subprocess helpers ─────────────────────────────────────
|
|
|
|
def _run(command: list[str]) -> None:
    """Run *command* locally; raise RemoteCommandError if it exits non-zero.

    stdout is not redirected, so it goes to this process's stdout (e.g.
    rsync progress stays visible). stderr is captured, and its last 500
    characters (after stripping) are appended to the error message.
    """
    proc = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True)
    if proc.returncode == 0:
        return
    header = f"Command failed ({proc.returncode}): {' '.join(command)}"
    tail = proc.stderr.strip()[-500:] if proc.stderr else ""
    if tail:
        raise RemoteCommandError(f"{header}\n{tail}")
    raise RemoteCommandError(header)
|
|
|
|
|
|
# ── SSH ───────────────────────────────────────────────────────────────
|
|
|
|
def wait_for_ssh(target: SshTarget, *, timeout_seconds: int, poll_interval_seconds: int) -> None:
    """Block until *target* accepts an SSH connection or *timeout_seconds* elapse.

    Each probe runs ``true`` on the remote host. At least one probe is always
    made, even when *timeout_seconds* is 0. Between failed probes the function
    sleeps *poll_interval_seconds*, capped so it never sleeps past the deadline.
    The deadline uses a monotonic clock, so wall-clock adjustments cannot
    shorten or extend the wait.

    Raises:
        RemoteCommandError: if no probe succeeds before the deadline.
    """
    deadline = time.monotonic() + timeout_seconds
    probe = target.ssh_base_command() + ["true"]
    while True:
        # Bound each probe: a port that accepts TCP but never completes the SSH
        # handshake would otherwise block far past the deadline. 30 s is ample
        # for a healthy handshake; the 5 s floor still lets a final probe run.
        probe_timeout = max(5.0, min(30.0, deadline - time.monotonic()))
        try:
            result = subprocess.run(
                probe,
                check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                timeout=probe_timeout,
            )
        except subprocess.TimeoutExpired:
            pass  # hung probe counts as "not ready yet"
        else:
            if result.returncode == 0:
                return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        print(f" SSH not ready yet, retrying... ({int(remaining)}s remaining)")
        time.sleep(min(poll_interval_seconds, remaining))
    raise RemoteCommandError(
        f"Timed out waiting for SSH on {target.host}:{target.port} after {timeout_seconds}s."
    )
|
|
|
|
|
|
def ssh(target: SshTarget, command: str) -> None:
    """Run *command* on *target* over SSH; raise RemoteCommandError on failure.

    *command* is passed to ssh as a single argument, which the remote shell
    interprets. Quote any untrusted parts (e.g. with ``shell_join``).
    """
    argv = [*target.ssh_base_command(), command]
    _run(argv)
|
|
|
|
|
|
def ssh_output(target: SshTarget, command: str) -> str:
    """Run *command* on *target* over SSH and return its captured stdout.

    Raises:
        RemoteCommandError: if the command (or ssh itself) exits non-zero.
            The message contains the command and its full stripped stderr.
    """
    completed = subprocess.run(
        [*target.ssh_base_command(), command],
        capture_output=True,
        text=True,
        check=False,
    )
    if completed.returncode:
        stderr_text = completed.stderr.strip()
        raise RemoteCommandError(
            f"Command failed ({completed.returncode}): {command}\n{stderr_text}"
        )
    return completed.stdout
|
|
|
|
|
|
# ── detached process management ───────────────────────────────────────
|
|
|
|
def run_detached(target: SshTarget, command: str, log_path: str) -> int:
    """Start *command* on *target* in the background and return its remote PID.

    The command runs as ``nohup bash -c <command>``, with stdout and stderr
    appended to *log_path*, so it keeps running after the SSH session ends.
    stdin is redirected from /dev/null. Otherwise the background job would
    inherit the SSH channel's stdin, which carries this process's own stdin
    (it is not redirected by ``ssh_output``).

    Raises:
        RemoteCommandError: if ssh fails or no PID can be parsed from its output.
    """
    inner = (
        f"nohup bash -c {shlex.quote(command)} "
        f"< /dev/null > {shlex.quote(log_path)} 2>&1 & echo $!"
    )
    output = ssh_output(target, inner)
    lines = output.strip().splitlines()
    try:
        # `echo $!` prints the backgrounded job's PID as the last output line.
        return int(lines[-1])
    except (IndexError, ValueError) as exc:
        raise RemoteCommandError(
            f"Could not parse PID from detached launch output: {output!r}"
        ) from exc
|
|
|
|
|
|
def is_process_alive(target: SshTarget, pid: int) -> bool:
    """Return True if *pid* is running on *target*, False if it is gone.

    Probes with ``kill -0``. ssh exits with 255 when it cannot connect or the
    session fails, so that status raises RemoteCommandError rather than
    returning False. Callers can then reconnect and retry instead of mistaking
    a dropped connection for a finished process. Any other non-zero status is
    treated as "process not found".

    The exit code is checked directly instead of substring-matching an
    exception message, so this decision cannot silently flip if the error
    message format changes.

    NOTE(review): ``kill -0`` also fails when the PID exists but belongs to a
    process we may not signal; presumably fine when the job was started by the
    same remote user — confirm.

    Raises:
        RemoteCommandError: on SSH failure (exit status 255).
    """
    command = target.ssh_base_command() + [f"kill -0 {int(pid)}"]
    result = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True)
    if result.returncode == 0:
        return True
    if result.returncode == 255:
        snippet = result.stderr.strip()[-500:] if result.stderr else ""
        msg = f"Command failed ({result.returncode}): {' '.join(command)}"
        raise RemoteCommandError(f"{msg}\n{snippet}" if snippet else msg)
    return False  # kill -0 failed: process not found
|
|
|
|
|
|
def read_remote_file(target: SshTarget, path: str) -> str | None:
    """Return the whitespace-stripped contents of remote *path*, or None.

    None is returned when the file is missing or unreadable: cat's errors are
    discarded and ``|| true`` masks its exit status. None is also returned when
    the file is empty or contains only whitespace. SSH connection failures
    still raise RemoteCommandError via ``ssh_output``.
    """
    remote_cmd = f"cat {shlex.quote(path)} 2>/dev/null || true"
    contents = ssh_output(target, remote_cmd).strip()
    if not contents:
        return None
    return contents
|
|
|
|
|
|
def tail_log(target: SshTarget, log_path: str) -> subprocess.Popen:
    """Stream remote *log_path* from its first line; return the ssh process.

    Read lines from the returned process's ``stdout`` (text mode); stderr is
    discarded. tail never exits on its own, so terminate the process when done.

    ``tail -F`` (follow by name, retrying) is used instead of ``tail -f``, so
    the stream does not end immediately if the log does not exist yet. This
    matters e.g. right after ``run_detached``, whose log redirection is
    performed by the backgrounded shell.
    """
    cmd = target.ssh_base_command() + [f"tail -F -n +1 {shlex.quote(log_path)}"]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True)
|
|
|
|
|
|
# ── rsync ─────────────────────────────────────────────────────────────
|
|
|
|
def rsync_project(local_root: Path, target: SshTarget, remote_root: str, exclude_file: Path) -> None:
    """Mirror the contents of *local_root* into *remote_root* on *target*.

    Trailing slashes on both sides make rsync copy the directory's contents
    rather than nesting the directory itself. ``--delete`` removes remote
    files that no longer exist locally. Paths matched by *exclude_file* are
    neither sent nor deleted, because rsync does not delete excluded files
    without ``--delete-excluded``.

    Raises:
        ValueError: if *remote_root* is empty or only slashes. With
            ``--delete`` that would mirror the project over the remote
            filesystem root.
        RemoteCommandError: if rsync exits non-zero.
    """
    if not remote_root.rstrip("/"):
        raise ValueError(f"Refusing to sync with --delete into remote root: {remote_root!r}")
    _run([
        "rsync", "-az", "--info=progress2", "--delete",
        "--exclude-from", str(exclude_file),
        "-e", target.rsync_ssh_opts(),
        f"{local_root}/",
        f"{target.destination}:{remote_root}/",
    ])
|
|
|
|
|
|
def rsync_file(local_file: Path, target: SshTarget, remote_dir: str) -> None:
    """Copy *local_file* into the remote directory *remote_dir* on *target*.

    ``--partial`` keeps a partially transferred file if the copy is
    interrupted, so a retry can reuse it instead of starting over.

    Raises:
        RemoteCommandError: if rsync exits non-zero.
    """
    remote_spec = f"{target.destination}:{remote_dir}/"
    argv = [
        "rsync", "-az", "--partial", "--info=progress2",
        "-e", target.rsync_ssh_opts(),
        str(local_file),
        remote_spec,
    ]
    _run(argv)
|
|
|
|
|
|
def _quote_remote_path(path: str) -> str:
    """Shell-quote a remote path while keeping a leading ``~`` / ``~/`` expandable."""
    if path == "~":
        return '"$HOME"'
    if path.startswith("~/"):
        return '"$HOME"/' + shlex.quote(path[2:])
    return shlex.quote(path)


def fetch_directory(target: SshTarget, remote_dir: str, local_dir: Path) -> None:
    """Download *remote_dir* into *local_dir* and check that no files are missing.

    The remote regular-file count (``find -type f``) is taken before the
    transfer. Afterwards the number of files under *local_dir* must be at
    least that count. ``--partial --append-verify`` lets an interrupted
    download resume by appending to partially received files.

    NOTE(review): per the rsync docs, ``--append``-style transfers skip files
    whose local copy is already at least as large as the remote one, so a
    remote file that changed without growing is not re-fetched. Confirm
    this is intended.
    NOTE(review): the local count includes files that already existed in
    *local_dir*, and ``Path.is_file()`` follows symlinks while
    ``find -type f`` does not. The check is therefore a lower-bound sanity
    check, not an exact match.

    Raises:
        RemoteCommandError: if ssh/rsync fail or fewer files arrive than expected.
    """
    local_dir.mkdir(parents=True, exist_ok=True)
    remote_src = remote_dir.rstrip("/")

    # Count remote files first so completeness can be verified afterwards.
    # The path is quoted because it is interpolated into a remote shell
    # command. An empty remote_src means the filesystem root, matching the
    # rsync source below.
    count_cmd = f"find {_quote_remote_path(remote_src or '/')} -type f 2>/dev/null | wc -l"
    remote_count = int(ssh_output(target, count_cmd).strip())

    _run([
        "rsync", "-az", "--partial", "--append-verify", "--info=progress2",
        "-e", target.rsync_ssh_opts(),
        f"{target.destination}:{remote_src}/",
        f"{local_dir}/",
    ])

    local_count = sum(1 for p in local_dir.rglob("*") if p.is_file())
    # Fewer local files than remote ones means a corrupted or interrupted transfer.
    if local_count < remote_count:
        raise RemoteCommandError(
            f"Download incomplete: got {local_count} files, expected {remote_count}. "
            f"Remote: {remote_dir} Local: {local_dir}"
        )
    print(f" Verified: {local_count} files downloaded successfully")
|
|
|
|
|
|
# ── misc ──────────────────────────────────────────────────────────────
|
|
|
|
def shell_join(parts: list[str]) -> str:
    """Quote each element of *parts* for a POSIX shell and join them with spaces.

    The result is suitable as the command string passed to ``ssh`` or
    ``run_detached``, with each element arriving as exactly one argument.
    """
    return shlex.join(parts)
|