import shlex import subprocess import time from dataclasses import dataclass from pathlib import Path # Error raised when a remote SSH or rsync command fails class RemoteCommandError(RuntimeError): pass # Connection details for an SSH target (Vast instance) @dataclass(slots=True) class SshTarget: user: str host: str port: int private_key: Path # user@host string used by SSH and rsync @property def destination(self) -> str: return f"{self.user}@{self.host}" # Base SSH flags: identity, port, keep-alive, and auto-accept host keys def ssh_base_command(self) -> list[str]: return [ "ssh", "-i", str(self.private_key), "-p", str(self.port), "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=accept-new", "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=10", "-o", "TCPKeepAlive=yes", self.destination, ] # SSH flags formatted as a string for rsync's -e option def rsync_ssh_opts(self) -> str: return " ".join(self.ssh_base_command()[:-1]) # ── low-level subprocess helpers ───────────────────────────────────── # Run a local subprocess, raising RemoteCommandError on non-zero exit def _run(command: list[str]) -> None: result = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True) if result.returncode != 0: snippet = result.stderr.strip()[-500:] if result.stderr else "" msg = f"Command failed ({result.returncode}): {' '.join(command)}" raise RemoteCommandError(f"{msg}\n{snippet}" if snippet else msg) # ── SSH ─────────────────────────────────────────────────────────────── # Repeatedly attempt an SSH connection until the target accepts it or the deadline passes def wait_for_ssh(target: SshTarget, *, timeout_seconds: int, poll_interval_seconds: int) -> None: deadline = time.time() + timeout_seconds while time.time() < deadline: result = subprocess.run( target.ssh_base_command() + ["true"], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) if result.returncode == 0: return remaining = int(deadline - time.time()) print(f" SSH not ready yet, retrying... ({remaining}s remaining)") time.sleep(poll_interval_seconds) raise RemoteCommandError( f"Timed out waiting for SSH on {target.host}:{target.port} after {timeout_seconds}s." ) # Execute a single command on the remote host via SSH def ssh(target: SshTarget, command: str) -> None: _run(target.ssh_base_command() + [command]) # Execute a remote command and return its stdout def ssh_output(target: SshTarget, command: str) -> str: result = subprocess.run( target.ssh_base_command() + [command], check=False, capture_output=True, text=True, ) if result.returncode != 0: raise RemoteCommandError( f"Command failed ({result.returncode}): {command}\n{result.stderr.strip()}" ) return result.stdout # ── detached process management ─────────────────────────────────────── # Launch a remote command under nohup so it survives SSH drops; return PID def run_detached(target: SshTarget, command: str, log_path: str) -> int: inner = f"nohup bash -c {shlex.quote(command)} > {shlex.quote(log_path)} 2>&1 & echo $!" output = ssh_output(target, inner) return int(output.strip().splitlines()[-1]) # Return True if the remote process is running # Raise RemoteCommandError on SSH failure (exit 255) so callers can apply # reconnect logic and only return False when kill -0 confirms the PID is gone def is_process_alive(target: SshTarget, pid: int) -> bool: try: ssh(target, f"kill -0 {pid}") return True except RemoteCommandError as exc: if "Command failed (255)" in str(exc): raise # SSH failure, let caller handle return False # process not found # Read a remote file's contents; returns None if the file does not exist def read_remote_file(target: SshTarget, path: str) -> str | None: return ssh_output(target, f"cat {shlex.quote(path)} 2>/dev/null || true").strip() or None # Open a persistent tail -f on a remote file for real-time streaming def tail_log(target: SshTarget, log_path: str) -> subprocess.Popen: cmd = target.ssh_base_command() + [f"tail -f -n +1 {shlex.quote(log_path)}"] return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True) # ── rsync ───────────────────────────────────────────────────────────── # Upload the entire project directory, respecting the exclude list def rsync_project(local_root: Path, target: SshTarget, remote_root: str, exclude_file: Path) -> None: _run([ "rsync", "-az", "--info=progress2", "--delete", "--exclude-from", str(exclude_file), "-e", target.rsync_ssh_opts(), f"{local_root}/", f"{target.destination}:{remote_root}/", ]) # Upload a single file to a remote directory def rsync_file(local_file: Path, target: SshTarget, remote_dir: str) -> None: _run([ "rsync", "-az", "--partial", "--info=progress2", "-e", target.rsync_ssh_opts(), str(local_file), f"{target.destination}:{remote_dir}/", ]) # Download a remote directory, verify file count, and resume partial transfers def fetch_directory(target: SshTarget, remote_dir: str, local_dir: Path) -> None: local_dir.mkdir(parents=True, exist_ok=True) # Count remote files first so we can verify completeness after transfer remote_count = int(ssh_output( target, f"find {remote_dir.rstrip('/')} -type f 2>/dev/null | wc -l" ).strip()) _run([ "rsync", "-az", "--partial", "--append-verify", "--info=progress2", "-e", target.rsync_ssh_opts(), f"{target.destination}:{remote_dir.rstrip('/')}/", f"{local_dir}/", ]) local_count = sum(1 for p in local_dir.rglob("*") if p.is_file()) # Raise if fewer files arrived than expected (corrupted / interrupted transfer) if local_count < remote_count: raise RemoteCommandError( f"Download incomplete: got {local_count} files, expected {remote_count}. " f"Remote: {remote_dir} Local: {local_dir}" ) print(f" Verified: {local_count} files downloaded successfully") # ── misc ────────────────────────────────────────────────────────────── # Quote and join shell arguments for safe remote execution def shell_join(parts: list[str]) -> str: return " ".join(shlex.quote(p) for p in parts)