From 50adf7316975973fcfa57e6b7f3e2f2d50aa4f9e Mon Sep 17 00:00:00 2001 From: Vincent Caux-Brisebois Date: Mon, 27 Apr 2026 19:02:01 +0000 Subject: [PATCH 1/2] Adding qemu vm driver support with GPU pass-through Signed-off-by: Vincent Caux-Brisebois --- Cargo.lock | 14 + crates/openshell-cli/src/main.rs | 8 + crates/openshell-cli/src/run.rs | 4 + .../sandbox_create_lifecycle_integration.rs | 5 + .../openshell-driver-kubernetes/src/driver.rs | 1 + crates/openshell-driver-podman/src/driver.rs | 1 + crates/openshell-driver-vm/Cargo.toml | 3 + crates/openshell-driver-vm/README.md | 18 +- .../scripts/openshell-vm-sandbox-init.sh | 104 +- crates/openshell-driver-vm/src/driver.rs | 382 +++++- crates/openshell-driver-vm/src/gpu.rs | 316 +++++ crates/openshell-driver-vm/src/lib.rs | 3 +- crates/openshell-driver-vm/src/main.rs | 51 +- crates/openshell-driver-vm/src/runtime.rs | 548 +++++++- crates/openshell-driver-vm/start.sh | 31 +- crates/openshell-server/src/compute/mod.rs | 2 + crates/openshell-vfio/Cargo.toml | 23 + crates/openshell-vfio/src/lib.rs | 1110 +++++++++++++++++ proto/compute_driver.proto | 6 + proto/openshell.proto | 4 + tasks/vm.toml | 2 +- 21 files changed, 2533 insertions(+), 103 deletions(-) create mode 100644 crates/openshell-driver-vm/src/gpu.rs create mode 100644 crates/openshell-vfio/Cargo.toml create mode 100644 crates/openshell-vfio/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index e58168a8a..0e59eb64f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3210,8 +3210,11 @@ dependencies = [ "miette", "nix", "openshell-core", + "openshell-vfio", "polling", "prost-types", + "serde", + "serde_json", "tar", "tokio", "tokio-stream", @@ -3407,6 +3410,17 @@ dependencies = [ "url", ] +[[package]] +name = "openshell-vfio" +version = "0.0.0" +dependencies = [ + "serde", + "serde_json", + "tempfile", + "thiserror 2.0.18", + "tracing", +] + [[package]] name = "openshell-vm" version = "0.0.0" diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index df5fb0061..385399312 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1138,6 +1138,11 @@ enum SandboxCommands { #[arg(long)] gpu: bool, + /// Target a specific GPU by PCI address (e.g. "0000:2d:00.0") or index (e.g. "0", "1"). + /// Only valid with --gpu. When omitted with --gpu, the first available GPU is assigned. + #[arg(long, requires = "gpu")] + gpu_device: Option, + /// SSH destination for remote bootstrap (e.g., user@hostname). /// Only used when no cluster exists yet; ignored if a cluster is /// already active. @@ -2307,6 +2312,7 @@ async fn main() -> Result<()> { no_keep, editor, gpu, + gpu_device, remote, ssh_key, providers, @@ -2402,6 +2408,7 @@ async fn main() -> Result<()> { upload_spec.as_ref(), keep, gpu, + gpu_device.as_deref(), editor, remote.as_deref(), ssh_key.as_deref(), @@ -2425,6 +2432,7 @@ async fn main() -> Result<()> { upload_spec.as_ref(), keep, gpu, + gpu_device.as_deref(), editor, remote.as_deref(), ssh_key.as_deref(), diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 769c37748..21cb40b4e 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1923,6 +1923,7 @@ pub async fn sandbox_create_with_bootstrap( upload: Option<&(String, Option, bool)>, keep: bool, gpu: bool, + gpu_device: Option<&str>, editor: Option, remote: Option<&str>, ssh_key: Option<&str>, @@ -1954,6 +1955,7 @@ pub async fn sandbox_create_with_bootstrap( upload, keep, gpu, + gpu_device, editor, remote, ssh_key, @@ -2010,6 +2012,7 @@ pub async fn sandbox_create( upload: Option<&(String, Option, bool)>, keep: bool, gpu: bool, + gpu_device: Option<&str>, editor: Option, remote: Option<&str>, ssh_key: Option<&str>, @@ -2117,6 +2120,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { gpu: requested_gpu, + gpu_device: gpu_device.unwrap_or_default().to_string(), policy, providers: configured_providers, template, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 79d482fdb..6b78dab9a 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -574,6 +574,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -615,6 +616,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -659,6 +661,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -703,6 +706,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -744,6 +748,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(8080)), diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 444e0f55d..a3f90457c 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -150,6 +150,7 @@ impl KubernetesComputeDriver { driver_version: openshell_core::VERSION.to_string(), default_image: self.config.default_image.clone(), supports_gpu: self.has_gpu_capacity().await.unwrap_or(false), + gpu_count: 0, }) } diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index e4b017002..dff95a532 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -161,6 +161,7 @@ impl PodmanComputeDriver { driver_version: openshell_core::VERSION.to_string(), default_image: self.config.default_image.clone(), supports_gpu, + gpu_count: 0, }) } diff --git a/crates/openshell-driver-vm/Cargo.toml b/crates/openshell-driver-vm/Cargo.toml index b4d92b0fc..04f4e9fc5 100644 --- a/crates/openshell-driver-vm/Cargo.toml +++ b/crates/openshell-driver-vm/Cargo.toml @@ -20,6 +20,7 @@ path = "src/main.rs" [dependencies] openshell-core = { path = "../openshell-core" } +openshell-vfio = { path = "../openshell-vfio" } tokio = { workspace = true } tonic = { workspace = true, features = ["transport"] } @@ -32,6 +33,8 @@ tracing = { workspace = true } tracing-subscriber = { workspace = true } miette = { workspace = true } url = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } libc = "0.2" libloading = "0.8" tar = "0.4" diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 44f326c3b..39be02676 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -36,9 +36,17 @@ mise run gateway:vm ``` First run takes a few minutes while `mise run vm:setup` stages libkrun/libkrunfw/gvproxy and `mise run vm:rootfs -- --base` builds the embedded rootfs. Subsequent runs are cached. To keep the Unix socket path under macOS `SUN_LEN`, `mise run gateway:vm` and `start.sh` default the state dir to `/tmp/openshell-vm-driver-dev-$USER-port-$PORT/` (SQLite DB + per-sandbox rootfs + `compute-driver.sock`) unless `OPENSHELL_VM_DRIVER_STATE_DIR` is set. -The wrapper also prints the recommended gateway name (`vm-driver-port-$PORT` by default) plus the exact repo-local `scripts/bin/openshell gateway add` and `scripts/bin/openshell gateway select` commands to use from another terminal. This avoids accidentally hitting an older `openshell` binary elsewhere on your `PATH`. +The wrapper auto-registers the gateway with the CLI (`gateway destroy` + `gateway add`) so no manual registration step is needed. When running under `sudo`, it uses `sudo -u $SUDO_USER` for the registration so the config is written under the invoking user's home directory. Re-runs are idempotent. It also exports `OPENSHELL_DRIVER_DIR=$PWD/target/debug` before starting the gateway so local dev runs use the freshly built `openshell-driver-vm` instead of an older installed copy from `~/.local/libexec/openshell` or `/usr/local/libexec`. +For GPU passthrough (VFIO), pass `-- --gpu` and run with root privileges: + +```shell +sudo -E env "PATH=$PATH" mise run gateway:vm -- --gpu +``` + +See [`architecture/vm-gpu-sandbox-guide.md`](../../architecture/vm-gpu-sandbox-guide.md) for full GPU prerequisites and usage. + Override via environment: ```shell @@ -129,13 +137,11 @@ See [`openshell-gateway --help`](../openshell-server/src/cli.rs) for the full fl ## Verifying the gateway -In another terminal: +The gateway is auto-registered by `start.sh`. In another terminal: ```shell -export OPENSHELL_GATEWAY_URL=http://127.0.0.1:8080 -cargo run -p openshell-cli -- gateway register local --url $OPENSHELL_GATEWAY_URL --no-tls -cargo run -p openshell-cli -- sandbox create --name demo -cargo run -p openshell-cli -- sandbox connect demo +scripts/bin/openshell sandbox create --name demo +scripts/bin/openshell sandbox connect demo ``` First sandbox takes 10–30 seconds to boot (rootfs extraction + libkrun + guest init). Subsequent creates reuse the prepared sandbox rootfs. diff --git a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh index e449003f9..5e3227d11 100644 --- a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh +++ b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh @@ -3,13 +3,35 @@ # SPDX-License-Identifier: Apache-2.0 # Minimal init for sandbox VMs. Runs as PID 1 inside the guest, mounts the -# essential filesystems, configures gvproxy networking when present, then -# execs the OpenShell sandbox supervisor. +# essential filesystems, configures networking (gvproxy DHCP or TAP static), +# optionally loads NVIDIA GPU drivers, then execs the OpenShell sandbox +# supervisor. set -euo pipefail +# Source QEMU-injected environment variables if present +if [ -f /srv/openshell-env.sh ]; then + source /srv/openshell-env.sh +fi + BOOT_START=$(date +%s%3N 2>/dev/null || date +%s) GVPROXY_GATEWAY_IP="192.168.127.1" +GATEWAY_IP="$GVPROXY_GATEWAY_IP" + +# Parse kernel cmdline for GPU and TAP networking parameters +GPU_ENABLED="${GPU_ENABLED:-false}" +VM_NET_IP="${VM_NET_IP:-}" +VM_NET_GW="${VM_NET_GW:-}" +VM_NET_DNS="${VM_NET_DNS:-}" + +for param in $(cat /proc/cmdline 2>/dev/null || true); do + case "$param" in + GPU_ENABLED=*) GPU_ENABLED="${param#GPU_ENABLED=}" ;; + VM_NET_IP=*) VM_NET_IP="${param#VM_NET_IP=}" ;; + VM_NET_GW=*) VM_NET_GW="${param#VM_NET_GW=}" ;; + VM_NET_DNS=*) VM_NET_DNS="${param#VM_NET_DNS=}" ;; + esac +done ts() { local now @@ -82,7 +104,7 @@ ensure_host_gateway_aliases() { : > "$hosts_tmp" fi - printf '%s host.openshell.internal\n' "$GVPROXY_GATEWAY_IP" >> "$hosts_tmp" + printf '%s host.openshell.internal\n' "$GATEWAY_IP" >> "$hosts_tmp" cat "$hosts_tmp" > /etc/hosts rm -f "$hosts_tmp" } @@ -107,7 +129,7 @@ rewrite_openshell_endpoint_if_needed() { return 0 fi - for candidate in host.openshell.internal host.containers.internal host.docker.internal "$GVPROXY_GATEWAY_IP"; do + for candidate in host.openshell.internal host.containers.internal host.docker.internal "$GATEWAY_IP"; do if [ "$candidate" = "$host" ]; then continue fi @@ -126,6 +148,47 @@ rewrite_openshell_endpoint_if_needed() { ts "WARNING: could not reach OpenShell endpoint ${host}:${port}" } +setup_gpu() { + ts "GPU_ENABLED=true — initializing GPU passthrough" + + if ! command -v modprobe >/dev/null 2>&1; then + ts "FATAL: modprobe not found; cannot load nvidia kernel modules" + return 1 + fi + + # Stage GSP firmware from virtiofs to tmpfs to avoid slow FUSE reads + # during module load. The kernel's firmware_class.path= cmdline param + # points here initially for early request_firmware calls. + if [ -d /lib/firmware/nvidia ]; then + ts "staging GPU firmware to tmpfs" + mkdir -p /run/firmware/nvidia + cp -a /lib/firmware/nvidia/* /run/firmware/nvidia/ 2>/dev/null || true + if [ -e /sys/module/firmware_class/parameters/path ]; then + echo /run/firmware > /sys/module/firmware_class/parameters/path + fi + fi + + ts "loading nvidia kernel modules" + modprobe nvidia || { ts "FATAL: modprobe nvidia failed"; return 1; } + modprobe nvidia_uvm 2>/dev/null || true + modprobe nvidia_modeset 2>/dev/null || true + + # Free the tmpfs firmware copy now that modules are loaded + rm -rf /run/firmware 2>/dev/null || true + + if command -v nvidia-smi >/dev/null 2>&1; then + ts "validating nvidia-smi" + if nvidia-smi; then + ts "GPU initialization successful" + else + ts "FATAL: nvidia-smi failed" + return 1 + fi + else + ts "WARNING: nvidia-smi not found in rootfs; skipping GPU validation" + fi +} + mount -t proc proc /proc 2>/dev/null & mount -t sysfs sysfs /sys 2>/dev/null & mount -t tmpfs tmpfs /tmp 2>/dev/null & @@ -146,7 +209,36 @@ chown sandbox:sandbox /sandbox 2>/dev/null || true hostname openshell-sandbox-vm 2>/dev/null || true ip link set lo up 2>/dev/null || true -if ip link show eth0 >/dev/null 2>&1; then +# GPU initialization (before networking so nvidia-smi output is visible early) +if [ "${GPU_ENABLED}" = "true" ]; then + setup_gpu || ts "WARNING: GPU init failed; continuing without GPU" +fi + +# Networking: use TAP static config if VM_NET_IP is set (QEMU path), +# otherwise fall back to gvproxy DHCP on eth0 (libkrun path). +if [ -n "${VM_NET_IP}" ] && [ -n "${VM_NET_GW}" ]; then + ts "configuring TAP networking (static ${VM_NET_IP} gw ${VM_NET_GW})" + GATEWAY_IP="${VM_NET_GW}" + + if ip link show eth0 >/dev/null 2>&1; then + ip link set eth0 up 2>/dev/null || true + ip addr add "${VM_NET_IP}/30" dev eth0 2>/dev/null || true + ip route add default via "${VM_NET_GW}" 2>/dev/null || true + elif ip link show ens3 >/dev/null 2>&1; then + ip link set ens3 up 2>/dev/null || true + ip addr add "${VM_NET_IP}/30" dev ens3 2>/dev/null || true + ip route add default via "${VM_NET_GW}" 2>/dev/null || true + fi + + if [ -n "${VM_NET_DNS}" ]; then + echo "nameserver ${VM_NET_DNS}" > /etc/resolv.conf + elif [ ! -s /etc/resolv.conf ]; then + echo "nameserver 8.8.8.8" > /etc/resolv.conf + echo "nameserver 8.8.4.4" >> /etc/resolv.conf + fi + + ensure_host_gateway_aliases +elif ip link show eth0 >/dev/null 2>&1; then ts "detected eth0 (gvproxy networking)" ip link set eth0 up 2>/dev/null || true @@ -193,7 +285,7 @@ DHCP_SCRIPT ensure_host_gateway_aliases else - ts "WARNING: eth0 not found; supervisor will start without guest egress" + ts "WARNING: no network interface found; supervisor will start without guest egress" fi export HOME=/sandbox diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index d649a585a..1ec53f4ed 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -1,6 +1,9 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use crate::gpu::{ + GpuInventory, SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name, +}; use crate::rootfs::{extract_sandbox_rootfs_to, sandbox_guest_init_path}; use futures::Stream; use nix::errno::Errno; @@ -16,7 +19,9 @@ use openshell_core::proto::compute::v1::{ WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, }; +use openshell_vfio::SysfsRoot; use std::collections::{HashMap, HashSet}; +use std::net::Ipv4Addr; use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; use std::pin::Pin; @@ -62,6 +67,9 @@ pub struct VmDriverConfig { pub guest_tls_ca: Option, pub guest_tls_cert: Option, pub guest_tls_key: Option, + pub gpu_enabled: bool, + pub gpu_mem_mib: u32, + pub gpu_vcpus: u8, } impl Default for VmDriverConfig { @@ -79,6 +87,9 @@ impl Default for VmDriverConfig { guest_tls_ca: None, guest_tls_cert: None, guest_tls_key: None, + gpu_enabled: false, + gpu_mem_mib: 8192, + gpu_vcpus: 4, } } } @@ -167,14 +178,18 @@ struct SandboxRecord { snapshot: Sandbox, state_dir: PathBuf, process: Arc>, + gpu_bdf: Option, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct VmDriver { config: VmDriverConfig, launcher_bin: PathBuf, registry: Arc>>, events: broadcast::Sender, + gpu_inventory: Option>>, + gpu_count: u32, + subnet_allocator: Arc>, } impl VmDriver { @@ -185,6 +200,11 @@ impl VmDriver { validate_openshell_endpoint(&config.openshell_endpoint)?; let _ = config.tls_paths()?; + #[cfg(target_os = "linux")] + if config.gpu_enabled { + check_gpu_privileges()?; + } + let state_root = config.state_dir.join("sandboxes"); tokio::fs::create_dir_all(&state_root) .await @@ -202,12 +222,30 @@ impl VmDriver { .map_err(|err| format!("failed to resolve vm driver executable: {err}"))? }; + let (gpu_inventory, gpu_count) = if config.gpu_enabled { + let sysfs = SysfsRoot::system(); + let inventory = GpuInventory::new(sysfs, &config.state_dir); + let count = inventory.gpu_count(); + tracing::info!(gpu_count = count, "GPU inventory initialized"); + (Some(Arc::new(std::sync::Mutex::new(inventory))), count) + } else { + (None, 0) + }; + + let subnet_allocator = Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))); + let (events, _) = broadcast::channel(WATCH_BUFFER); Ok(Self { config, launcher_bin, registry: Arc::new(Mutex::new(HashMap::new())), events, + gpu_inventory, + gpu_count, + subnet_allocator, }) } @@ -217,21 +255,26 @@ impl VmDriver { driver_name: DRIVER_NAME.to_string(), driver_version: openshell_core::VERSION.to_string(), default_image: String::new(), - supports_gpu: false, + supports_gpu: self.gpu_inventory.is_some(), + gpu_count: self.gpu_count, } } pub async fn validate_sandbox(&self, sandbox: &Sandbox) -> Result<(), Status> { - validate_vm_sandbox(sandbox) + validate_vm_sandbox(sandbox, self.config.gpu_enabled) } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result { - validate_vm_sandbox(sandbox)?; + validate_vm_sandbox(sandbox, self.config.gpu_enabled)?; if self.registry.lock().await.contains_key(&sandbox.id) { return Err(Status::already_exists("sandbox already exists")); } + let spec = sandbox.spec.as_ref(); + let is_gpu = spec.is_some_and(|s| s.gpu); + let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); + let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id); let rootfs = state_dir.join("rootfs"); @@ -256,21 +299,30 @@ impl VmDriver { })?; } + let gpu_bdf = if is_gpu { + let inventory = self + .gpu_inventory + .as_ref() + .ok_or_else(|| Status::internal("GPU inventory not initialized"))?; + let assignment = inventory + .lock() + .map_err(|e| Status::internal(format!("GPU inventory lock poisoned: {e}")))? + .assign(&sandbox.id, gpu_device) + .map_err(|e| Status::failed_precondition(e))?; + tracing::info!( + sandbox_id = %sandbox.id, + bdf = %assignment.bdf, + gpu_name = %assignment.name, + iommu_group = assignment.iommu_group, + "assigned GPU to sandbox" + ); + Some(assignment.bdf) + } else { + None + }; + let console_output = state_dir.join("rootfs-console.log"); let mut command = Command::new(&self.launcher_bin); - // Intentionally DO NOT set kill_on_drop(true). On a signal-driven - // driver exit (SIGKILL, SIGTERM without a handler, panic), - // tokio's Drop is racy with the launcher's procguard-initiated - // cleanup: if kill_on_drop SIGKILLs the launcher first, its - // cleanup callback never gets to SIGTERM gvproxy, and gvproxy is - // reparented to init as an orphan. Instead the whole cleanup - // cascade runs via procguard: - // driver exits → launcher's kqueue (macOS) or PR_SET_PDEATHSIG - // (Linux) fires → launcher kills gvproxy + libkrun fork → - // launcher exits → its own children die under pdeathsig. - // The explicit Drop path in VmProcess::terminate_vm_process still - // handles voluntary `delete_sandbox` teardown cleanly, where we - // do want SIGTERM + wait + SIGKILL semantics. command.stdin(Stdio::null()); command.stdout(Stdio::inherit()); command.stderr(Stdio::inherit()); @@ -278,21 +330,81 @@ impl VmDriver { command.arg("--vm-rootfs").arg(&rootfs); command.arg("--vm-exec").arg(sandbox_guest_init_path()); command.arg("--vm-workdir").arg("/"); - command.arg("--vm-vcpus").arg(self.config.vcpus.to_string()); - command - .arg("--vm-mem-mib") - .arg(self.config.mem_mib.to_string()); - command - .arg("--vm-krun-log-level") - .arg(self.config.krun_log_level.to_string()); command.arg("--vm-console-output").arg(&console_output); - for env in build_guest_environment(sandbox, &self.config) { + + // Compute the endpoint override before building the env so + // there is a single OPENSHELL_ENDPOINT value in the env list. + let endpoint_override = if gpu_bdf.is_some() { + let subnet = match self + .subnet_allocator + .lock() + .map_err(|e| Status::internal(format!("subnet allocator lock poisoned: {e}"))) + .and_then(|mut alloc| { + alloc + .allocate(&sandbox.id) + .map_err(|e| Status::failed_precondition(e)) + }) { + Ok(s) => s, + Err(err) => { + self.release_gpu_and_subnet(&sandbox.id); + let _ = tokio::fs::remove_dir_all(&state_dir).await; + return Err(err); + } + }; + let vsock_cid = allocate_vsock_cid(); + let mac = mac_from_sandbox_id(&sandbox.id); + let mac_str = format!( + "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + mac[0], mac[1], mac[2], mac[3], mac[4], mac[5] + ); + let tap = tap_device_name(&sandbox.id); + + let tap_endpoint = guest_visible_openshell_endpoint_for_tap( + &self.config.openshell_endpoint, + &subnet.host_ip.to_string(), + ); + + command.arg("--vm-backend").arg("qemu"); + command + .arg("--vm-vcpus") + .arg(self.config.gpu_vcpus.to_string()); + command + .arg("--vm-mem-mib") + .arg(self.config.gpu_mem_mib.to_string()); + command + .arg("--vm-krun-log-level") + .arg(self.config.krun_log_level.to_string()); + command.arg("--vm-gpu-bdf").arg(gpu_bdf.as_ref().unwrap()); + command.arg("--vm-tap-device").arg(&tap); + command + .arg("--vm-guest-ip") + .arg(subnet.guest_ip.to_string()); + command.arg("--vm-host-ip").arg(subnet.host_ip.to_string()); + command.arg("--vm-vsock-cid").arg(vsock_cid.to_string()); + command.arg("--vm-guest-mac").arg(&mac_str); + + Some(tap_endpoint) + } else { + command.arg("--vm-vcpus").arg(self.config.vcpus.to_string()); + command + .arg("--vm-mem-mib") + .arg(self.config.mem_mib.to_string()); + command + .arg("--vm-krun-log-level") + .arg(self.config.krun_log_level.to_string()); + None + }; + + for env in build_guest_environment(sandbox, &self.config, endpoint_override.as_deref()) { command.arg("--vm-env").arg(env); } let child = match command.spawn() { Ok(child) => child, Err(err) => { + if gpu_bdf.is_some() { + self.release_gpu_and_subnet(&sandbox.id); + } let _ = tokio::fs::remove_dir_all(&state_dir).await; return Err(Status::internal(format!( "failed to launch vm helper '{}': {err}", @@ -314,6 +426,7 @@ impl VmDriver { snapshot: snapshot.clone(), state_dir: state_dir.clone(), process: process.clone(), + gpu_bdf: gpu_bdf.clone(), }, ); } @@ -338,21 +451,31 @@ impl VmDriver { let record = { let registry = self.registry.lock().await; if let Some((id, record)) = registry.get_key_value(sandbox_id) { - Some((id.clone(), record.state_dir.clone(), record.process.clone())) + Some(( + id.clone(), + record.state_dir.clone(), + record.process.clone(), + record.gpu_bdf.clone(), + )) } else { let matched_id = registry .iter() .find(|(_, record)| record.snapshot.name == sandbox_name) .map(|(id, _)| id.clone()); matched_id.and_then(|id| { - registry - .get(&id) - .map(|record| (id, record.state_dir.clone(), record.process.clone())) + registry.get(&id).map(|record| { + ( + id, + record.state_dir.clone(), + record.process.clone(), + record.gpu_bdf.clone(), + ) + }) }) } }; - let Some((record_id, state_dir, process)) = record else { + let Some((record_id, state_dir, process, gpu_bdf)) = record else { return Ok(DeleteSandboxResponse { deleted: false }); }; @@ -371,6 +494,10 @@ impl VmDriver { .map_err(|err| Status::internal(format!("failed to stop vm: {err}")))?; } + if gpu_bdf.is_some() { + self.release_gpu_and_subnet(&record_id); + } + if let Err(err) = tokio::fs::remove_dir_all(&state_dir).await && err.kind() != std::io::ErrorKind::NotFound { @@ -417,6 +544,17 @@ impl VmDriver { snapshots } + fn release_gpu_and_subnet(&self, sandbox_id: &str) { + if let Some(ref inventory) = self.gpu_inventory { + if let Ok(mut inv) = inventory.lock() { + inv.release(sandbox_id); + } + } + if let Ok(mut alloc) = self.subnet_allocator.lock() { + alloc.release(sandbox_id); + } + } + /// Watch the launcher child process and surface errors as driver /// conditions. /// @@ -487,6 +625,16 @@ impl VmDriver { sandbox_id.clone(), platform_event("vm", "Warning", "ProcessExited", message), ); + let has_gpu = { + let registry = self.registry.lock().await; + registry + .get(&sandbox_id) + .and_then(|r| r.gpu_bdf.as_ref()) + .is_some() + }; + if has_gpu { + self.release_gpu_and_subnet(&sandbox_id); + } return; } @@ -678,16 +826,35 @@ impl ComputeDriver for VmDriver { } } -fn validate_vm_sandbox(sandbox: &Sandbox) -> Result<(), Status> { +#[cfg(target_os = "linux")] +#[allow(unsafe_code)] +fn check_gpu_privileges() -> Result<(), String> { + if unsafe { libc::geteuid() } != 0 { + return Err( + "GPU support requires root privileges for VFIO bind/unbind and TAP networking. \ + Run with sudo or ensure CAP_SYS_ADMIN + CAP_NET_ADMIN capabilities are set." + .to_string(), + ); + } + Ok(()) +} + +fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { let spec = sandbox .spec .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu { + + if spec.gpu && !gpu_enabled { return Err(Status::failed_precondition( - "vm sandboxes do not support gpu=true", + "GPU support is not enabled on this driver; start with --gpu", )); } + + if !spec.gpu && !spec.gpu_device.is_empty() { + return Err(Status::invalid_argument("gpu_device requires gpu=true")); + } + if let Some(template) = spec.template.as_ref() { if !template.image.is_empty() { return Err(Status::failed_precondition( @@ -744,7 +911,25 @@ fn guest_visible_openshell_endpoint(endpoint: &str) -> String { endpoint.to_string() } -fn build_guest_environment(sandbox: &Sandbox, config: &VmDriverConfig) -> Vec { +fn guest_visible_openshell_endpoint_for_tap(endpoint: &str, host_ip: &str) -> String { + let Ok(mut url) = Url::parse(endpoint) else { + return endpoint.to_string(); + }; + if url.set_host(Some(host_ip)).is_ok() { + url.to_string() + } else { + endpoint.to_string() + } +} + +fn build_guest_environment( + sandbox: &Sandbox, + config: &VmDriverConfig, + endpoint_override: Option<&str>, +) -> Vec { + let openshell_endpoint = endpoint_override + .map(String::from) + .unwrap_or_else(|| guest_visible_openshell_endpoint(&config.openshell_endpoint)); let mut environment = HashMap::from([ ("HOME".to_string(), "/root".to_string()), ( @@ -752,10 +937,7 @@ fn build_guest_environment(sandbox: &Sandbox, config: &VmDriverConfig) -> Vec Vec i64 { #[cfg(test)] mod tests { use super::*; + use crate::gpu::{SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name}; use openshell_core::proto::compute::v1::{ DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, }; @@ -945,7 +1132,7 @@ mod tests { use tonic::Code; #[test] - fn validate_vm_sandbox_rejects_gpu() { + fn validate_vm_sandbox_rejects_gpu_when_not_enabled() { let sandbox = Sandbox { spec: Some(SandboxSpec { gpu: true, @@ -953,9 +1140,38 @@ mod tests { }), ..Default::default() }; - let err = validate_vm_sandbox(&sandbox).expect_err("gpu should be rejected"); + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("gpu should be rejected when not enabled"); assert_eq!(err.code(), Code::FailedPrecondition); - assert!(err.message().contains("gpu")); + assert!(err.message().contains("GPU support is not enabled")); + } + + #[test] + fn validate_vm_sandbox_accepts_gpu_when_enabled() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + gpu: true, + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("gpu should be accepted when enabled"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + gpu: false, + gpu_device: "0000:2d:00.0".to_string(), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("gpu_device without gpu should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu_device requires gpu=true")); } #[test] @@ -979,7 +1195,8 @@ mod tests { }), ..Default::default() }; - let err = validate_vm_sandbox(&sandbox).expect_err("platform config should be rejected"); + let err = + validate_vm_sandbox(&sandbox, false).expect_err("platform config should be rejected"); assert_eq!(err.code(), Code::FailedPrecondition); assert!(err.message().contains("platform_config")); } @@ -1019,7 +1236,7 @@ mod tests { ..Default::default() }; - let env = build_guest_environment(&sandbox, &config); + let env = build_guest_environment(&sandbox, &config, None); assert!(env.contains(&"HOME=/root".to_string())); assert!(env.contains(&format!( "OPENSHELL_ENDPOINT=http://{GVPROXY_GATEWAY_IP}:8080/" @@ -1028,6 +1245,39 @@ mod tests { assert!(env.contains(&format!( "OPENSHELL_SSH_SOCKET_PATH={GUEST_SSH_SOCKET_PATH}" ))); + assert!( + env.contains(&"OPENSHELL_SSH_HANDSHAKE_SECRET=secret".to_string()), + "SSH handshake secret must be passed to the guest" + ); + } + + #[test] + fn build_guest_environment_uses_endpoint_override_for_tap() { + let config = VmDriverConfig { + openshell_endpoint: "http://127.0.0.1:8080".to_string(), + ssh_handshake_secret: "secret".to_string(), + ..Default::default() + }; + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + name: "sandbox-123".to_string(), + spec: Some(SandboxSpec::default()), + ..Default::default() + }; + + let env = build_guest_environment(&sandbox, &config, Some("http://10.0.128.1:8080")); + assert!( + env.contains(&"OPENSHELL_ENDPOINT=http://10.0.128.1:8080".to_string()), + "TAP endpoint override must replace the default" + ); + let endpoint_count = env + .iter() + .filter(|e| e.starts_with("OPENSHELL_ENDPOINT=")) + .count(); + assert_eq!( + endpoint_count, 1, + "must have exactly one OPENSHELL_ENDPOINT" + ); } #[test] @@ -1085,7 +1335,7 @@ mod tests { ..Default::default() }; - let env = build_guest_environment(&sandbox, &config); + let env = build_guest_environment(&sandbox, &config, None); assert!(env.contains(&format!("OPENSHELL_TLS_CA={GUEST_TLS_CA_PATH}"))); assert!(env.contains(&format!("OPENSHELL_TLS_CERT={GUEST_TLS_CERT_PATH}"))); assert!(env.contains(&format!("OPENSHELL_TLS_KEY={GUEST_TLS_KEY_PATH}"))); @@ -1111,6 +1361,12 @@ mod tests { launcher_bin: PathBuf::from("openshell-driver-vm"), registry: Arc::new(Mutex::new(HashMap::new())), events, + gpu_inventory: None, + gpu_count: 0, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), }; let base = unique_temp_dir(); @@ -1233,6 +1489,43 @@ mod tests { let _ = std::fs::remove_dir_all(base); } + #[test] + fn subnet_allocator_assigns_and_releases() { + let mut alloc = SubnetAllocator::new(Ipv4Addr::new(10, 0, 128, 0), 17); + let s1 = alloc.allocate("sandbox-1").unwrap(); + assert_eq!(s1.host_ip, Ipv4Addr::new(10, 0, 128, 1)); + assert_eq!(s1.guest_ip, Ipv4Addr::new(10, 0, 128, 2)); + assert_eq!(s1.prefix_len, 30); + + let s2 = alloc.allocate("sandbox-2").unwrap(); + assert_ne!(s1.host_ip, s2.host_ip); + + alloc.release("sandbox-1"); + let s3 = alloc.allocate("sandbox-3").unwrap(); + assert!(s3.host_ip != s2.host_ip); + } + + #[test] + fn tap_device_name_fits_ifnamsiz() { + let name = tap_device_name("sandbox-abc-def-ghi"); + assert!(name.len() <= 15); + assert!(name.starts_with("vmtap-")); + } + + #[test] + fn mac_address_is_locally_administered() { + let mac = mac_from_sandbox_id("test-sandbox"); + assert_eq!(mac[0] & 0x02, 0x02); + assert_eq!(mac[0] & 0x01, 0x00); + } + + #[test] + fn vsock_cid_monotonically_increases() { + let cid1 = allocate_vsock_cid(); + let cid2 = allocate_vsock_cid(); + assert!(cid2 > cid1); + } + fn unique_temp_dir() -> PathBuf { static COUNTER: AtomicU64 = AtomicU64::new(0); let nanos = SystemTime::now() @@ -1280,6 +1573,7 @@ mod tests { snapshot: sandbox, state_dir, process, + gpu_bdf: None, }, ); } diff --git a/crates/openshell-driver-vm/src/gpu.rs b/crates/openshell-driver-vm/src/gpu.rs new file mode 100644 index 000000000..9089a166b --- /dev/null +++ b/crates/openshell-driver-vm/src/gpu.rs @@ -0,0 +1,316 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_vfio::{ + GpuBindGuard, GpuBindState, GpuBinding, GpuInfo, SysfsRoot, prepare_gpu_for_passthrough, + probe_host_nvidia_vfio_readiness, reconcile_stale_bindings, validate_bdf, +}; +use std::collections::HashMap; +use std::net::Ipv4Addr; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU32, Ordering}; + +/// Tracks available GPUs and their assignment to sandboxes. +pub struct GpuInventory { + slots: Vec, + sysfs: SysfsRoot, + state_path: PathBuf, +} + +struct GpuSlot { + info: GpuInfo, + assigned_to: Option, + bind_guard: Option, +} + +impl GpuInventory { + pub fn new(sysfs: SysfsRoot, state_dir: &Path) -> Self { + let state_path = state_dir.join("gpu-bindings.json"); + + let restored = reconcile_stale_bindings(&sysfs, &state_path); + for bdf in &restored { + tracing::info!(bdf = %bdf, "restored stale GPU binding from previous crash"); + } + + let gpus = probe_host_nvidia_vfio_readiness(&sysfs); + let slots = gpus + .into_iter() + .map(|info| GpuSlot { + info, + assigned_to: None, + bind_guard: None, + }) + .collect(); + + Self { + slots, + sysfs, + state_path, + } + } + + pub fn gpu_count(&self) -> u32 { + self.slots.len() as u32 + } + + pub fn available_count(&self) -> u32 { + self.slots + .iter() + .filter(|s| s.assigned_to.is_none()) + .count() as u32 + } + + /// Assign a GPU to a sandbox. Returns the assignment details including BDF. + pub fn assign(&mut self, sandbox_id: &str, gpu_device: &str) -> Result { + let slot_idx = if gpu_device.is_empty() { + self.slots + .iter() + .position(|s| s.assigned_to.is_none()) + .ok_or_else(|| "all GPUs are currently assigned to other sandboxes".to_string())? + } else if let Ok(idx) = gpu_device.parse::() { + if idx >= self.slots.len() { + return Err(format!( + "GPU index {idx} out of range (have {} GPUs)", + self.slots.len() + )); + } + if self.slots[idx].assigned_to.is_some() { + return Err(format!( + "GPU at index {idx} ({}) is already assigned to another sandbox", + self.slots[idx].info.bdf + )); + } + idx + } else { + validate_bdf(gpu_device).map_err(|e| e.to_string())?; + let idx = self + .slots + .iter() + .position(|s| s.info.bdf == gpu_device) + .ok_or_else(|| format!("GPU {gpu_device} not found in inventory"))?; + if self.slots[idx].assigned_to.is_some() { + return Err(format!( + "GPU {gpu_device} is already assigned to another sandbox" + )); + } + idx + }; + + let bdf = self.slots[slot_idx].info.bdf.clone(); + let guard = prepare_gpu_for_passthrough(&self.sysfs, &bdf) + .map_err(|e| format!("failed to prepare GPU {bdf} for passthrough: {e}"))?; + + self.slots[slot_idx].assigned_to = Some(sandbox_id.to_string()); + self.slots[slot_idx].bind_guard = Some(guard); + self.persist_state(); + + Ok(GpuAssignment { + bdf, + name: self.slots[slot_idx].info.name.clone(), + iommu_group: self.slots[slot_idx].info.iommu_group, + }) + } + + /// Release a GPU assignment. The `GpuBindGuard` is dropped, restoring the GPU. + pub fn release(&mut self, sandbox_id: &str) { + if let Some(slot) = self + .slots + .iter_mut() + .find(|s| s.assigned_to.as_deref() == Some(sandbox_id)) + { + let bdf = slot.info.bdf.clone(); + slot.assigned_to = None; + slot.bind_guard.take(); + self.persist_state(); + tracing::info!(bdf = %bdf, sandbox_id = %sandbox_id, "released GPU assignment"); + } + } + + fn persist_state(&self) { + let bindings: Vec = self + .slots + .iter() + .filter_map(|s| { + s.assigned_to.as_ref().map(|id| GpuBinding { + bdf: s.info.bdf.clone(), + sandbox_id: id.clone(), + bound_at_ms: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |d| d.as_millis() as i64), + }) + }) + .collect(); + let state = GpuBindState { bindings }; + if let Err(err) = state.save(&self.state_path) { + tracing::warn!(error = %err, "failed to persist GPU bind state"); + } + } +} + +pub struct GpuAssignment { + pub bdf: String, + pub name: String, + pub iommu_group: u32, +} + +// --------------------------------------------------------------------------- +// Subnet allocation for per-sandbox TAP networking +// --------------------------------------------------------------------------- + +/// Allocates /30 subnets from a pool for per-sandbox TAP networking. +pub struct SubnetAllocator { + base: Ipv4Addr, + prefix_len: u8, + next_offset: u32, + allocated: HashMap, +} + +pub struct SubnetAllocation { + pub host_ip: Ipv4Addr, + pub guest_ip: Ipv4Addr, + pub prefix_len: u8, + pub offset: u32, +} + +static NEXT_VSOCK_CID: AtomicU32 = AtomicU32::new(3); + +impl SubnetAllocator { + pub fn new(base: Ipv4Addr, prefix_len: u8) -> Self { + Self { + base, + prefix_len, + next_offset: 0, + allocated: HashMap::new(), + } + } + + pub fn allocate(&mut self, sandbox_id: &str) -> Result { + let pool_size = 1u32 << (32 - self.prefix_len); + let max_subnets = pool_size / 4; + + if self.allocated.len() as u32 >= max_subnets { + return Err("subnet pool exhausted".to_string()); + } + + while self + .allocated + .values() + .any(|a| a.offset == self.next_offset) + { + self.next_offset = (self.next_offset + 1) % max_subnets; + } + + let base_u32 = u32::from(self.base); + let subnet_base = base_u32 + (self.next_offset * 4); + let host_ip = Ipv4Addr::from(subnet_base + 1); + let guest_ip = Ipv4Addr::from(subnet_base + 2); + + let allocation = SubnetAllocation { + host_ip, + guest_ip, + prefix_len: 30, + offset: self.next_offset, + }; + + self.allocated.insert(sandbox_id.to_string(), allocation); + self.next_offset = (self.next_offset + 1) % max_subnets; + + let alloc = &self.allocated[sandbox_id]; + Ok(SubnetAllocation { + host_ip: alloc.host_ip, + guest_ip: alloc.guest_ip, + prefix_len: alloc.prefix_len, + offset: alloc.offset, + }) + } + + pub fn release(&mut self, sandbox_id: &str) { + self.allocated.remove(sandbox_id); + } +} + +pub fn allocate_vsock_cid() -> u32 { + NEXT_VSOCK_CID.fetch_add(1, Ordering::Relaxed) +} + +/// Generate a locally-administered MAC from sandbox ID using FNV-1a. +pub fn mac_from_sandbox_id(sandbox_id: &str) -> [u8; 6] { + let mut hash: u64 = 0xcbf29ce484222325; + for byte in sandbox_id.as_bytes() { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(0x100000001b3); + } + let bytes = hash.to_le_bytes(); + let mut mac = [bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]]; + mac[0] = (mac[0] & 0xFE) | 0x02; + mac +} + +/// TAP device name from sandbox ID (fits `IFNAMSIZ=16`). +pub fn tap_device_name(sandbox_id: &str) -> String { + let end = sandbox_id.len().min(8); + let end = sandbox_id.floor_char_boundary(end); + let prefix = &sandbox_id[..end]; + format!("vmtap-{prefix}") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn subnet_allocator_assigns_sequential_blocks() { + let mut alloc = SubnetAllocator::new(Ipv4Addr::new(10, 0, 128, 0), 17); + + let s1 = alloc.allocate("sandbox-1").unwrap(); + assert_eq!(s1.host_ip, Ipv4Addr::new(10, 0, 128, 1)); + assert_eq!(s1.guest_ip, Ipv4Addr::new(10, 0, 128, 2)); + assert_eq!(s1.prefix_len, 30); + + let s2 = alloc.allocate("sandbox-2").unwrap(); + assert_eq!(s2.host_ip, Ipv4Addr::new(10, 0, 128, 5)); + assert_eq!(s2.guest_ip, Ipv4Addr::new(10, 0, 128, 6)); + } + + #[test] + fn subnet_allocator_recycles_after_release() { + let mut alloc = SubnetAllocator::new(Ipv4Addr::new(10, 0, 128, 0), 17); + + let _s1 = alloc.allocate("sandbox-1").unwrap(); + let _s2 = alloc.allocate("sandbox-2").unwrap(); + alloc.release("sandbox-1"); + + let s3 = alloc.allocate("sandbox-3").unwrap(); + assert_eq!(s3.host_ip, Ipv4Addr::new(10, 0, 128, 9)); + } + + #[test] + fn tap_device_name_truncates_long_ids() { + assert_eq!(tap_device_name("abc"), "vmtap-abc"); + assert_eq!(tap_device_name("abcdefghijklmnop"), "vmtap-abcdefgh"); + } + + #[test] + fn mac_from_sandbox_id_sets_locally_administered_bit() { + let mac = mac_from_sandbox_id("sandbox-123"); + assert_eq!(mac[0] & 0x02, 0x02, "locally-administered bit must be set"); + assert_eq!(mac[0] & 0x01, 0x00, "multicast bit must be clear"); + } + + #[test] + fn mac_from_sandbox_id_deterministic() { + let mac1 = mac_from_sandbox_id("sandbox-x"); + let mac2 = mac_from_sandbox_id("sandbox-x"); + assert_eq!(mac1, mac2); + + let mac3 = mac_from_sandbox_id("sandbox-y"); + assert_ne!(mac1, mac3); + } + + #[test] + fn vsock_cid_increments() { + let cid1 = allocate_vsock_cid(); + let cid2 = allocate_vsock_cid(); + assert_eq!(cid2, cid1 + 1); + } +} diff --git a/crates/openshell-driver-vm/src/lib.rs b/crates/openshell-driver-vm/src/lib.rs index 772db47b3..e4f9e1299 100644 --- a/crates/openshell-driver-vm/src/lib.rs +++ b/crates/openshell-driver-vm/src/lib.rs @@ -4,9 +4,10 @@ pub mod driver; mod embedded_runtime; mod ffi; +pub mod gpu; pub mod procguard; mod rootfs; mod runtime; pub use driver::{VmDriver, VmDriverConfig}; -pub use runtime::{VM_RUNTIME_DIR_ENV, VmLaunchConfig, configured_runtime_dir, run_vm}; +pub use runtime::{VM_RUNTIME_DIR_ENV, VmBackend, VmLaunchConfig, configured_runtime_dir, run_vm}; diff --git a/crates/openshell-driver-vm/src/main.rs b/crates/openshell-driver-vm/src/main.rs index 5a675e78a..38311f745 100644 --- a/crates/openshell-driver-vm/src/main.rs +++ b/crates/openshell-driver-vm/src/main.rs @@ -5,10 +5,7 @@ use clap::Parser; use miette::{IntoDiagnostic, Result}; use openshell_core::VERSION; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; -use openshell_driver_vm::{ - VM_RUNTIME_DIR_ENV, VmDriver, VmDriverConfig, VmLaunchConfig, configured_runtime_dir, - procguard, run_vm, -}; +use openshell_driver_vm::{VmBackend, VmDriver, VmDriverConfig, VmLaunchConfig, procguard, run_vm}; use std::net::SocketAddr; use std::path::PathBuf; use tokio::net::UnixListener; @@ -93,6 +90,36 @@ struct Args { #[arg(long, env = "OPENSHELL_VM_DRIVER_MEM_MIB", default_value_t = 2048)] mem_mib: u32, + + #[arg(long, env = "OPENSHELL_VM_GPU")] + gpu: bool, + + #[arg(long, env = "OPENSHELL_VM_GPU_MEM_MIB", default_value_t = 8192)] + gpu_mem_mib: u32, + + #[arg(long, env = "OPENSHELL_VM_GPU_VCPUS", default_value_t = 4)] + gpu_vcpus: u8, + + #[arg(long, hide = true)] + vm_backend: Option, + + #[arg(long, hide = true)] + vm_gpu_bdf: Option, + + #[arg(long, hide = true)] + vm_tap_device: Option, + + #[arg(long, hide = true)] + vm_guest_ip: Option, + + #[arg(long, hide = true)] + vm_host_ip: Option, + + #[arg(long, hide = true)] + vm_vsock_cid: Option, + + #[arg(long, hide = true)] + vm_guest_mac: Option, } #[tokio::main] @@ -146,6 +173,9 @@ async fn main() -> Result<()> { guest_tls_ca: args.guest_tls_ca, guest_tls_cert: args.guest_tls_cert, guest_tls_key: args.guest_tls_key, + gpu_enabled: args.gpu, + gpu_mem_mib: args.gpu_mem_mib, + gpu_vcpus: args.gpu_vcpus, }) .await .map_err(|err| miette::miette!("{err}"))?; @@ -193,6 +223,12 @@ fn build_vm_launch_config(args: &Args) -> std::result::Result VmBackend::Qemu, + Some("libkrun") | None => VmBackend::Libkrun, + Some(other) => return Err(format!("unknown VM backend: {other}")), + }; + Ok(VmLaunchConfig { rootfs, vcpus: args.vm_vcpus, @@ -203,6 +239,13 @@ fn build_vm_launch_config(args: &Args) -> std::result::Result, + pub tap_device: Option, + pub guest_ip: Option, + pub host_ip: Option, + pub vsock_cid: Option, + pub guest_mac: Option, } pub fn run_vm(config: &VmLaunchConfig) -> Result<(), String> { + match config.backend { + VmBackend::Qemu => run_qemu_vm(config), + VmBackend::Libkrun => run_libkrun_vm(config), + } +} + +fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { + let gpu_bdf = config + .gpu_bdf + .as_deref() + .ok_or("gpu_bdf is required for QEMU backend")?; + let tap_device = config + .tap_device + .as_deref() + .ok_or("tap_device is required for QEMU backend")?; + let guest_mac = config + .guest_mac + .as_deref() + .ok_or("guest_mac is required for QEMU backend")?; + let vsock_cid = config + .vsock_cid + .ok_or("vsock_cid is required for QEMU backend")?; + let _guest_ip = config + .guest_ip + .as_deref() + .ok_or("guest_ip is required for QEMU backend")?; + let host_ip = config + .host_ip + .as_deref() + .ok_or("host_ip is required for QEMU backend")?; + + if !config.rootfs.is_dir() { + return Err(format!( + "rootfs directory not found: {}", + config.rootfs.display() + )); + } + + if let Err(err) = procguard::die_with_parent_cleanup(procguard_kill_children) { + return Err(format!("procguard arm failed: {err}")); + } + + #[cfg(target_os = "linux")] + check_kvm_access()?; + + write_guest_env_file(&config.rootfs, &config.env)?; + + let rootfs_str = config.rootfs.to_str().ok_or("rootfs path not UTF-8")?; + let sandbox_dir = config.rootfs.parent().unwrap_or(&config.rootfs); + let sock_prefix = tap_device.trim_start_matches("vmtap-"); + let virtiofsd_sock_dir = PathBuf::from(format!("/tmp/ovm-qemu-{sock_prefix}")); + std::fs::create_dir_all(&virtiofsd_sock_dir) + .map_err(|e| format!("create virtiofsd sock dir: {e}"))?; + let virtiofsd_sock = virtiofsd_sock_dir.join("virtiofsd.sock"); + let shm_path = format!("/dev/shm/ovm-qemu-{sock_prefix}"); + + std::fs::create_dir_all(&shm_path).map_err(|e| format!("create shm dir: {e}"))?; + + let runtime_dir = configured_runtime_dir()?; + + setup_tap_networking(tap_device, host_ip)?; + let mut tap_guard = TapGuard::new(tap_device.to_string(), host_ip.to_string()); + + let virtiofsd_log = sandbox_dir.join("virtiofsd.log"); + let virtiofsd_log_file = + std::fs::File::create(&virtiofsd_log).map_err(|e| format!("create virtiofsd log: {e}"))?; + + let virtiofsd_bin = { + let runtime_virtiofsd = runtime_dir.join("virtiofsd"); + if runtime_virtiofsd.is_file() { + runtime_virtiofsd + } else { + PathBuf::from("virtiofsd") + } + }; + + let mut virtiofsd_cmd = StdCommand::new(&virtiofsd_bin); + virtiofsd_cmd + .arg("--socket-path") + .arg(&virtiofsd_sock) + .arg("--shared-dir") + .arg(rootfs_str) + .arg("--cache=auto") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(virtiofsd_log_file); + + #[cfg(target_os = "linux")] + { + use nix::sys::signal::Signal; + use std::os::unix::process::CommandExt as _; + unsafe { + virtiofsd_cmd.pre_exec(|| { + nix::sys::prctl::set_pdeathsig(Signal::SIGKILL) + .map_err(|err| std::io::Error::other(format!("pdeathsig: {err}"))) + }); + } + } + + let virtiofsd_child = virtiofsd_cmd + .spawn() + .map_err(|e| format!("failed to start virtiofsd: {e}"))?; + let virtiofsd_pid = virtiofsd_child.id() as i32; + GVPROXY_PID.store(virtiofsd_pid, Ordering::Relaxed); + let mut virtiofsd_guard = GvproxyGuard::new(virtiofsd_child); + + wait_for_path(&virtiofsd_sock, Duration::from_secs(5), "virtiofsd socket")?; + + let vmlinux = runtime_dir.join("vmlinux"); + if !vmlinux.is_file() { + return Err(format!("VM kernel not found: {}", vmlinux.display())); + } + + let kernel_cmdline = build_kernel_cmdline(config); + + let mut qemu_cmd = StdCommand::new("qemu-system-x86_64"); + qemu_cmd + .arg("-machine") + .arg("q35,accel=kvm") + .arg("-cpu") + .arg("host") + .arg("-smp") + .arg(config.vcpus.to_string()) + .arg("-m") + .arg(format!("{}M", config.mem_mib)) + .arg("-nographic") + .arg("-no-reboot") + .arg("-kernel") + .arg(&vmlinux) + .arg("-append") + .arg(&kernel_cmdline) + .arg("-chardev") + .arg(format!( + "socket,id=virtiofs,path={}", + virtiofsd_sock.display() + )) + .arg("-device") + .arg("vhost-user-fs-pci,chardev=virtiofs,tag=rootfs") + .arg("-object") + .arg(format!( + "memory-backend-memfd,id=mem,size={}M,share=on", + config.mem_mib + )) + .arg("-numa") + .arg("node,memdev=mem") + .arg("-netdev") + .arg(format!( + "tap,id=net0,ifname={tap_device},script=no,downscript=no" + )) + .arg("-device") + .arg(format!("virtio-net-pci,netdev=net0,mac={guest_mac}")) + .arg("-device") + .arg("pcie-root-port,id=vsock_root,slot=1") + .arg("-device") + .arg(format!( + "vhost-vsock-pci,guest-cid={vsock_cid},bus=vsock_root" + )) + .arg("-device") + .arg("pcie-root-port,id=gpu_root,slot=2") + .arg("-device") + .arg(format!("vfio-pci,host={gpu_bdf},bus=gpu_root")) + .arg("-serial") + .arg(format!("file:{}", config.console_output.display())); + + qemu_cmd.stdin(Stdio::null()); + qemu_cmd.stdout(Stdio::inherit()); + qemu_cmd.stderr(Stdio::inherit()); + + #[cfg(target_os = "linux")] + { + use nix::sys::signal::Signal; + use std::os::unix::process::CommandExt as _; + unsafe { + qemu_cmd.pre_exec(|| { + nix::sys::prctl::set_pdeathsig(Signal::SIGKILL) + .map_err(|err| std::io::Error::other(format!("pdeathsig: {err}"))) + }); + } + } + + let mut qemu_child = qemu_cmd + .spawn() + .map_err(|e| format!("failed to start QEMU: {e}"))?; + + let qemu_pid = qemu_child.id() as i32; + install_signal_forwarding(qemu_pid); + + let status = qemu_child + .wait() + .map_err(|e| format!("failed to wait for QEMU: {e}"))?; + + CHILD_PID.store(0, Ordering::Relaxed); + unsafe { + libc::kill(virtiofsd_pid, libc::SIGTERM); + } + virtiofsd_guard.disarm(); + GVPROXY_PID.store(0, Ordering::Relaxed); + teardown_tap_networking(tap_device, host_ip); + tap_guard.disarm(); + let _ = std::fs::remove_dir_all(&shm_path); + let _ = std::fs::remove_dir_all(&virtiofsd_sock_dir); + + if status.success() { + Ok(()) + } else { + Err(format!("QEMU exited with status {status}")) + } +} + +/// Write environment variables into the rootfs so the guest init script +/// can source them. virtiofs shares the host rootfs directory into the guest. +fn write_guest_env_file(rootfs: &Path, env_vars: &[String]) -> Result<(), String> { + let srv_dir = rootfs.join("srv"); + std::fs::create_dir_all(&srv_dir).map_err(|e| format!("create /srv in rootfs: {e}"))?; + let env_file = srv_dir.join("openshell-env.sh"); + let mut content = String::new(); + for var in env_vars { + if let Some((key, value)) = var.split_once('=') { + content.push_str(&format!("export {key}=\"{}\"\n", shell_escape(value))); + } + } + std::fs::write(&env_file, &content).map_err(|e| format!("write guest env file: {e}"))?; + Ok(()) +} + +/// Escape a string for use inside bash double quotes. +fn shell_escape(s: &str) -> String { + s.replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('$', "\\$") + .replace('`', "\\`") + .replace('\n', "\\n") + .replace('\r', "\\r") +} + +fn build_kernel_cmdline(config: &VmLaunchConfig) -> String { + let mut parts = vec![ + "console=ttyS0".to_string(), + "root=rootfs".to_string(), + "rootfstype=virtiofs".to_string(), + "rw".to_string(), + "panic=-1".to_string(), + format!("init={}", config.exec_path), + ]; + + if let Some(ip) = &config.guest_ip { + if let Some(host_ip) = &config.host_ip { + parts.push(format!("ip={ip}::{host_ip}:255.255.255.252:sandbox::off")); + parts.push(format!("VM_NET_IP={ip}")); + parts.push(format!("VM_NET_GW={host_ip}")); + } + } + + if let Some(dns) = host_dns_server() { + parts.push(format!("VM_NET_DNS={dns}")); + } + + if config.gpu_bdf.is_some() { + parts.push("GPU_ENABLED=true".to_string()); + parts.push("firmware_class.path=/lib/firmware".to_string()); + } + + parts.join(" ") +} + +fn host_dns_server() -> Option { + // Prefer systemd-resolved upstream config (skips the 127.0.0.53 + // stub listener which is unreachable from inside QEMU/TAP guests). + for path in &["/run/systemd/resolve/resolv.conf", "/etc/resolv.conf"] { + let Ok(resolv) = std::fs::read_to_string(path) else { + continue; + }; + for line in resolv.lines() { + let line = line.trim(); + if let Some(server) = line.strip_prefix("nameserver") { + let server = server.trim(); + if server == "127.0.0.53" || server.starts_with("127.") { + continue; + } + if !server.is_empty() { + return Some(server.to_string()); + } + } + } + } + None +} + +fn setup_tap_networking(tap_device: &str, host_ip: &str) -> Result<(), String> { + run_cmd("ip", &["tuntap", "add", "dev", tap_device, "mode", "tap"])?; + run_cmd( + "ip", + &["addr", "add", &format!("{host_ip}/30"), "dev", tap_device], + )?; + run_cmd("ip", &["link", "set", tap_device, "up"])?; + + enable_ip_forwarding()?; + + let subnet = tap_subnet_from_host_ip(host_ip); + let _ = run_cmd( + "iptables", + &[ + "-t", + "nat", + "-D", + "POSTROUTING", + "-s", + &subnet, + "-j", + "MASQUERADE", + ], + ); + run_cmd( + "iptables", + &[ + "-t", + "nat", + "-A", + "POSTROUTING", + "-s", + &subnet, + "-j", + "MASQUERADE", + ], + )?; + let _ = run_cmd( + "iptables", + &["-D", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], + ); + run_cmd( + "iptables", + &["-A", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], + )?; + let _ = run_cmd( + "iptables", + &[ + "-D", + "FORWARD", + "-o", + tap_device, + "-m", + "state", + "--state", + "RELATED,ESTABLISHED", + "-j", + "ACCEPT", + ], + ); + run_cmd( + "iptables", + &[ + "-A", + "FORWARD", + "-o", + tap_device, + "-m", + "state", + "--state", + "RELATED,ESTABLISHED", + "-j", + "ACCEPT", + ], + )?; + + Ok(()) +} + +fn teardown_tap_networking(tap_device: &str, host_ip: &str) { + let subnet = tap_subnet_from_host_ip(host_ip); + let _ = run_cmd( + "iptables", + &[ + "-D", + "FORWARD", + "-o", + tap_device, + "-m", + "state", + "--state", + "RELATED,ESTABLISHED", + "-j", + "ACCEPT", + ], + ); + let _ = run_cmd( + "iptables", + &["-D", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], + ); + let _ = run_cmd( + "iptables", + &[ + "-t", + "nat", + "-D", + "POSTROUTING", + "-s", + &subnet, + "-j", + "MASQUERADE", + ], + ); + let _ = run_cmd("ip", &["link", "set", tap_device, "down"]); + let _ = run_cmd("ip", &["tuntap", "del", "dev", tap_device, "mode", "tap"]); +} + +fn tap_subnet_from_host_ip(host_ip: &str) -> String { + if let Ok(ip) = host_ip.parse::() { + let base = u32::from(ip) & !3; + let base_ip = std::net::Ipv4Addr::from(base); + format!("{base_ip}/30") + } else { + format!("{host_ip}/30") + } +} + +fn enable_ip_forwarding() -> Result<(), String> { + std::fs::write("/proc/sys/net/ipv4/ip_forward", "1") + .map_err(|e| format!("enable ip_forward: {e}")) +} + +fn run_cmd(cmd: &str, args: &[&str]) -> Result<(), String> { + let output = StdCommand::new(cmd) + .args(args) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .map_err(|e| format!("failed to run {cmd}: {e}"))?; + if output.status.success() { + Ok(()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + Err(format!("{cmd} {} failed: {stderr}", args.join(" "))) + } +} + +/// RAII guard that tears down TAP networking on drop. +struct TapGuard { + tap_device: String, + host_ip: String, + disarmed: bool, +} + +impl TapGuard { + fn new(tap_device: String, host_ip: String) -> Self { + Self { + tap_device, + host_ip, + disarmed: false, + } + } + + fn disarm(&mut self) { + self.disarmed = true; + } +} + +impl Drop for TapGuard { + fn drop(&mut self) { + if !self.disarmed { + teardown_tap_networking(&self.tap_device, &self.host_ip); + } + } +} + +/// Shared procguard cleanup callback for both libkrun and QEMU paths. +/// Only async-signal-safe calls: atomic loads and `kill(2)`. +fn procguard_kill_children() { + let helper_pid = GVPROXY_PID.load(Ordering::Relaxed); + let child_pid = CHILD_PID.load(Ordering::Relaxed); + if helper_pid > 0 { + unsafe { + libc::kill(helper_pid, libc::SIGTERM); + } + } + if child_pid > 0 { + unsafe { + libc::kill(child_pid, libc::SIGTERM); + } + } + std::thread::sleep(Duration::from_millis(200)); + if helper_pid > 0 { + unsafe { + libc::kill(helper_pid, libc::SIGKILL); + } + } + if child_pid > 0 { + unsafe { + libc::kill(child_pid, libc::SIGKILL); + } + } +} + +fn run_libkrun_vm(config: &VmLaunchConfig) -> Result<(), String> { if !config.rootfs.is_dir() { return Err(format!( "rootfs directory not found: {}", @@ -50,37 +558,7 @@ pub fn run_vm(config: &VmLaunchConfig) -> Result<(), String> { // those slots get populated later in this function. Only ONE arm // per process: racing two watchers for the same NOTE_EXIT event // would cause whichever wins to skip the cleanup. - if let Err(err) = procguard::die_with_parent_cleanup(|| { - // Cleanup order: SIGTERM gvproxy and the libkrun fork first so - // they can drain cleanly, then SIGKILL after a brief grace - // window. We can't rely on Rust destructors here; when - // procguard's watcher thread returns we call `std::process::exit` - // and the process tears down. Only async-signal-safe calls here: - // atomic loads and `kill(2)` are both on the POSIX list. - let gv_pid = GVPROXY_PID.load(Ordering::Relaxed); - let child_pid = CHILD_PID.load(Ordering::Relaxed); - if gv_pid > 0 { - unsafe { - libc::kill(gv_pid, libc::SIGTERM); - } - } - if child_pid > 0 { - unsafe { - libc::kill(child_pid, libc::SIGTERM); - } - } - std::thread::sleep(Duration::from_millis(200)); - if gv_pid > 0 { - unsafe { - libc::kill(gv_pid, libc::SIGKILL); - } - } - if child_pid > 0 { - unsafe { - libc::kill(child_pid, libc::SIGKILL); - } - } - }) { + if let Err(err) = procguard::die_with_parent_cleanup(procguard_kill_children) { return Err(format!("procguard arm failed: {err}")); } diff --git a/crates/openshell-driver-vm/start.sh b/crates/openshell-driver-vm/start.sh index 0579e8aa0..675bb4c2e 100755 --- a/crates/openshell-driver-vm/start.sh +++ b/crates/openshell-driver-vm/start.sh @@ -28,6 +28,13 @@ DRIVER_DIR="${OPENSHELL_DRIVER_DIR:-${DRIVER_DIR_DEFAULT}}" export OPENSHELL_VM_RUNTIME_COMPRESSED_DIR="${OPENSHELL_VM_RUNTIME_COMPRESSED_DIR:-${COMPRESSED_DIR}}" +for arg in "$@"; do + if [ "${arg}" = "--gpu" ]; then + export OPENSHELL_VM_GPU=true + break + fi +done + mkdir -p "${STATE_DIR}" normalize_bool() { @@ -73,13 +80,19 @@ check_supervisor_cross_toolchain() { fi } -if [ ! -f "${COMPRESSED_DIR}/rootfs.tar.zst" ]; then +if [ ! -s "${COMPRESSED_DIR}/rootfs.tar.zst" ]; then check_supervisor_cross_toolchain echo "==> Building base VM rootfs tarball" mise run vm:rootfs -- --base fi -if [ ! -f "${COMPRESSED_DIR}/rootfs.tar.zst" ] || ! find "${COMPRESSED_DIR}" -maxdepth 1 -name 'libkrun*.zst' | grep -q .; then +if [ "${OPENSHELL_VM_GPU:-}" = "true" ] && [ ! -s "${COMPRESSED_DIR}/rootfs-gpu.tar.zst" ]; then + check_supervisor_cross_toolchain + echo "==> Building GPU VM rootfs tarball" + mise run vm:rootfs -- --gpu +fi + +if [ ! -s "${COMPRESSED_DIR}/rootfs.tar.zst" ] || ! find "${COMPRESSED_DIR}" -maxdepth 1 -name 'libkrun*.zst' | grep -q .; then echo "==> Preparing embedded VM runtime" mise run vm:setup fi @@ -106,12 +119,18 @@ export OPENSHELL_SSH_GATEWAY_PORT="${OPENSHELL_SSH_GATEWAY_PORT:-${SERVER_PORT}} export OPENSHELL_SSH_HANDSHAKE_SECRET="${OPENSHELL_SSH_HANDSHAKE_SECRET:-dev-vm-driver-secret}" export OPENSHELL_VM_DRIVER_STATE_DIR="${STATE_DIR}" -echo "==> Gateway registration" +echo "==> Registering gateway" echo " Name: ${GATEWAY_NAME}" echo " Endpoint: ${LOCAL_GATEWAY_ENDPOINT}" -echo " Register: ${CLI_BIN} gateway add --name ${GATEWAY_NAME} ${LOCAL_GATEWAY_ENDPOINT}" -echo " Select: ${CLI_BIN} gateway select ${GATEWAY_NAME}" -echo " Driver: ${OPENSHELL_DRIVER_DIR}/openshell-driver-vm" +echo " Driver: ${OPENSHELL_DRIVER_DIR}/openshell-driver-vm" + +if [ -n "${SUDO_USER:-}" ]; then + sudo -u "${SUDO_USER}" "${CLI_BIN}" gateway destroy --name "${GATEWAY_NAME}" 2>/dev/null || true + sudo -u "${SUDO_USER}" "${CLI_BIN}" gateway add --name "${GATEWAY_NAME}" "${LOCAL_GATEWAY_ENDPOINT}" +else + "${CLI_BIN}" gateway destroy --name "${GATEWAY_NAME}" 2>/dev/null || true + "${CLI_BIN}" gateway add --name "${GATEWAY_NAME}" "${LOCAL_GATEWAY_ENDPOINT}" +fi echo "==> Starting OpenShell server with VM compute driver" exec "${ROOT}/target/debug/openshell-gateway" diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 8f3152b33..7b56ad9d4 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1109,6 +1109,7 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .as_ref() .map(driver_sandbox_template_from_public), gpu: spec.gpu, + gpu_device: spec.gpu_device.clone(), } } @@ -1662,6 +1663,7 @@ mod tests { driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), supports_gpu: true, + gpu_count: 0, })) } diff --git a/crates/openshell-vfio/Cargo.toml b/crates/openshell-vfio/Cargo.toml new file mode 100644 index 000000000..b93500fdc --- /dev/null +++ b/crates/openshell-vfio/Cargo.toml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-vfio" +description = "VFIO GPU passthrough lifecycle for OpenShell VM sandboxes" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tempfile = "3" + +[lints] +workspace = true diff --git a/crates/openshell-vfio/src/lib.rs b/crates/openshell-vfio/src/lib.rs new file mode 100644 index 000000000..74a3ac38f --- /dev/null +++ b/crates/openshell-vfio/src/lib.rs @@ -0,0 +1,1110 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! VFIO GPU passthrough lifecycle management for `OpenShell` VM sandboxes. +//! +//! Provides discovery, binding, and crash-recovery for NVIDIA GPUs using +//! the VFIO subsystem. All sysfs access goes through [`SysfsRoot`] so the +//! entire stack is testable without root or real hardware. + +use std::fs; +use std::path::{Path, PathBuf}; + +const NVIDIA_VENDOR_ID: &str = "0x10de"; +const GPU_CLASS_DISPLAY_VGA: &str = "0x030000"; +const GPU_CLASS_DISPLAY_3D: u32 = 0x0302; + +// --------------------------------------------------------------------------- +// Errors +// --------------------------------------------------------------------------- + +#[derive(Debug, thiserror::Error)] +pub enum VfioError { + #[error("GPU {bdf} not found in sysfs")] + GpuNotFound { bdf: String }, + + #[error("GPU {bdf} is not an NVIDIA device (vendor={vendor})")] + NotNvidia { bdf: String, vendor: String }, + + #[error("GPU {bdf} has no IOMMU group — is IOMMU enabled?")] + NoIommuGroup { bdf: String }, + + #[error("GPU {bdf} IOMMU group {group} has other non-vfio-pci devices: {peers:?}")] + IommuGroupConflict { + bdf: String, + group: u32, + peers: Vec, + }, + + #[error("failed to bind GPU {bdf} to vfio-pci: {reason}")] + BindFailed { bdf: String, reason: String }, + + #[error("failed to unbind GPU {bdf} from vfio-pci: {reason}")] + UnbindFailed { bdf: String, reason: String }, + + #[error("sysfs I/O error for {path}: {source}")] + SysfsIo { + path: String, + #[source] + source: std::io::Error, + }, + + #[error("invalid PCI BDF address: {bdf}")] + InvalidBdf { bdf: String }, +} + +// --------------------------------------------------------------------------- +// SysfsRoot +// --------------------------------------------------------------------------- + +/// Abstraction over sysfs paths, enabling test mocks via a temporary directory. +#[derive(Debug, Clone)] +pub struct SysfsRoot { + base: PathBuf, +} + +impl SysfsRoot { + /// Production root pointing at the real `/sys` filesystem. + pub fn system() -> Self { + Self { + base: PathBuf::from("/sys"), + } + } + + /// Custom root for testing. + pub fn new(base: impl Into) -> Self { + Self { base: base.into() } + } + + pub fn pci_devices_dir(&self) -> PathBuf { + self.base.join("bus/pci/devices") + } + + pub fn pci_device(&self, bdf: &str) -> PathBuf { + self.pci_devices_dir().join(bdf) + } + + pub fn drivers_probe(&self) -> PathBuf { + self.base.join("bus/pci/drivers_probe") + } + + pub fn iommu_group(&self, bdf: &str) -> Result { + let link = self.pci_device(bdf).join("iommu_group"); + let target = fs::read_link(&link).map_err(|_| VfioError::NoIommuGroup { + bdf: bdf.to_string(), + })?; + let group_str = + target + .file_name() + .and_then(|n| n.to_str()) + .ok_or_else(|| VfioError::NoIommuGroup { + bdf: bdf.to_string(), + })?; + group_str + .parse::() + .map_err(|_| VfioError::NoIommuGroup { + bdf: bdf.to_string(), + }) + } + + /// Enumerate all PCI BDFs in the given IOMMU group. + pub fn iommu_group_devices(&self, group_id: u32) -> Result, VfioError> { + let group_dir = self + .base + .join(format!("kernel/iommu_groups/{group_id}/devices")); + let entries = fs::read_dir(&group_dir).map_err(|source| VfioError::SysfsIo { + path: group_dir.display().to_string(), + source, + })?; + let mut devices = Vec::new(); + for entry in entries.filter_map(Result::ok) { + devices.push(entry.file_name().to_string_lossy().into_owned()); + } + devices.sort(); + Ok(devices) + } +} + +// --------------------------------------------------------------------------- +// GpuInfo +// --------------------------------------------------------------------------- + +/// Information about a discovered NVIDIA GPU eligible for VFIO passthrough. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct GpuInfo { + pub bdf: String, + pub name: String, + pub vendor: String, + pub device: String, + pub iommu_group: u32, +} + +// --------------------------------------------------------------------------- +// GpuBindGuard +// --------------------------------------------------------------------------- + +/// RAII guard that restores a GPU to its host driver when dropped. +/// +/// Call [`disarm`](Self::disarm) to transfer ownership (e.g. the VM took over +/// the device successfully and we should not unbind it on cleanup). +pub struct GpuBindGuard { + bdf: String, + companion_bdfs: Vec, + sysfs: SysfsRoot, + disarmed: bool, +} + +impl GpuBindGuard { + pub fn bdf(&self) -> &str { + &self.bdf + } + + /// Prevent the guard from restoring the GPU on drop. + pub fn disarm(mut self) { + self.disarmed = true; + } +} + +impl Drop for GpuBindGuard { + fn drop(&mut self) { + if self.disarmed { + return; + } + for peer in &self.companion_bdfs { + if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, peer) { + tracing::error!(bdf = %peer, error = %err, "failed to restore companion device to host driver on drop"); + } + } + if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, &self.bdf) { + tracing::error!(bdf = %self.bdf, error = %err, "failed to restore GPU to host driver on drop"); + } + } +} + +// --------------------------------------------------------------------------- +// GpuBindState (crash-recovery persistence) +// --------------------------------------------------------------------------- + +/// Persisted record of GPUs currently bound to vfio-pci, for crash recovery. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct GpuBindState { + pub bindings: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct GpuBinding { + pub bdf: String, + pub sandbox_id: String, + pub bound_at_ms: i64, +} + +impl GpuBindState { + pub fn load(path: &Path) -> Result { + let data = fs::read_to_string(path)?; + serde_json::from_str(&data) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e)) + } + + pub fn save(&self, path: &Path) -> Result<(), std::io::Error> { + let data = serde_json::to_string_pretty(self) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let tmp = path.with_extension("tmp"); + fs::write(&tmp, &data)?; + fs::rename(&tmp, path) + } +} + +// --------------------------------------------------------------------------- +// Validation helpers +// --------------------------------------------------------------------------- + +/// Validate a PCI BDF address (format `DDDD:BB:DD.F`). +pub fn validate_bdf(bdf: &str) -> Result<(), VfioError> { + let bytes = bdf.as_bytes(); + if bytes.len() != 12 { + return Err(VfioError::InvalidBdf { + bdf: bdf.to_string(), + }); + } + + // Expected layout: [hex×4]:[hex×2]:[hex×2].[hex×1] + // 0123 4 56 7 89 A B + let ok = is_hex(bytes[0]) + && is_hex(bytes[1]) + && is_hex(bytes[2]) + && is_hex(bytes[3]) + && bytes[4] == b':' + && is_hex(bytes[5]) + && is_hex(bytes[6]) + && bytes[7] == b':' + && is_hex(bytes[8]) + && is_hex(bytes[9]) + && bytes[10] == b'.' + && is_hex(bytes[11]); + + if ok { + Ok(()) + } else { + Err(VfioError::InvalidBdf { + bdf: bdf.to_string(), + }) + } +} + +fn is_hex(b: u8) -> bool { + b.is_ascii_hexdigit() +} + +/// Returns `true` if `data` contains only safe characters for sysfs values +/// (alphanumeric plus `:`, `.`, `-`, `_`). +pub fn validate_sysfs_data(data: &str) -> bool { + !data.is_empty() + && data + .chars() + .all(|c| c.is_alphanumeric() || matches!(c, ':' | '.' | '-' | '_')) +} + +// --------------------------------------------------------------------------- +// Sysfs helpers +// --------------------------------------------------------------------------- + +fn read_sysfs_trimmed(path: &Path) -> Result { + fs::read_to_string(path) + .map(|s| s.trim().to_string()) + .map_err(|source| VfioError::SysfsIo { + path: path.display().to_string(), + source, + }) +} + +fn write_sysfs(path: &Path, value: &str) -> Result<(), VfioError> { + fs::write(path, value).map_err(|source| VfioError::SysfsIo { + path: path.display().to_string(), + source, + }) +} + +fn current_driver_name(sysfs: &SysfsRoot, bdf: &str) -> Option { + let driver_link = sysfs.pci_device(bdf).join("driver"); + fs::read_link(&driver_link) + .ok() + .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) +} + +fn is_gpu_class(class_str: &str) -> bool { + if class_str == GPU_CLASS_DISPLAY_VGA { + return true; + } + // 3D controller: 0x0302xx + if let Some(hex) = class_str.strip_prefix("0x") + && let Ok(val) = u32::from_str_radix(hex, 16) + { + return (val >> 8) == GPU_CLASS_DISPLAY_3D; + } + false +} + +// --------------------------------------------------------------------------- +// Discovery +// --------------------------------------------------------------------------- + +/// Scan sysfs for NVIDIA GPUs eligible for VFIO passthrough. +pub fn probe_host_nvidia_vfio_readiness(sysfs: &SysfsRoot) -> Vec { + let devices_dir = sysfs.pci_devices_dir(); + let entries = match fs::read_dir(&devices_dir) { + Ok(e) => e, + Err(err) => { + tracing::warn!(path = %devices_dir.display(), %err, "cannot read PCI devices directory"); + return Vec::new(); + } + }; + + let mut gpus = Vec::new(); + + for entry in entries.filter_map(Result::ok) { + let bdf = entry.file_name().to_string_lossy().into_owned(); + let dev_dir = sysfs.pci_device(&bdf); + + let Ok(vendor) = read_sysfs_trimmed(&dev_dir.join("vendor")) else { + continue; + }; + if vendor != NVIDIA_VENDOR_ID { + continue; + } + + let Ok(class) = read_sysfs_trimmed(&dev_dir.join("class")) else { + continue; + }; + if !is_gpu_class(&class) { + continue; + } + + let device = read_sysfs_trimmed(&dev_dir.join("device")).unwrap_or_default(); + + let name = read_sysfs_trimmed(&dev_dir.join("label")) + .unwrap_or_else(|_| format!("NVIDIA {device}")); + + let Ok(iommu_group) = sysfs.iommu_group(&bdf) else { + continue; + }; + + gpus.push(GpuInfo { + bdf, + name, + vendor, + device, + iommu_group, + }); + } + + gpus +} + +// --------------------------------------------------------------------------- +// Bind / unbind +// --------------------------------------------------------------------------- + +/// Bind a single PCI device to `vfio-pci`. Skips devices already bound. +fn bind_device_to_vfio(sysfs: &SysfsRoot, bdf: &str) -> Result { + if let Some(drv) = current_driver_name(sysfs, bdf) { + if drv == "vfio-pci" { + return Ok(false); + } + let unbind_path = sysfs.pci_device(bdf).join("driver/unbind"); + write_sysfs(&unbind_path, bdf).map_err(|e| VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!("unbind from {drv}: {e}"), + })?; + tracing::info!(bdf, driver = %drv, "unbound device from current driver"); + } + + let override_path = sysfs.pci_device(bdf).join("driver_override"); + write_sysfs(&override_path, "vfio-pci").map_err(|e| VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!("driver_override: {e}"), + })?; + + write_sysfs(&sysfs.drivers_probe(), bdf).map_err(|e| VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!("drivers_probe: {e}"), + })?; + + match current_driver_name(sysfs, bdf) { + Some(ref drv) if drv == "vfio-pci" => {} + other => { + return Err(VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!( + "after probe, driver is {:?} instead of vfio-pci", + other.as_deref().unwrap_or("") + ), + }); + } + } + + Ok(true) +} + +/// Bind a GPU to `vfio-pci`, returning an RAII guard that restores it on drop. +/// +/// Also binds all companion devices in the same IOMMU group (e.g. the +/// HD Audio function on consumer GPUs). All bound companions are tracked +/// and restored when the guard is dropped. +pub fn prepare_gpu_for_passthrough( + sysfs: &SysfsRoot, + bdf: &str, +) -> Result { + validate_bdf(bdf)?; + + let dev_dir = sysfs.pci_device(bdf); + if !dev_dir.exists() { + return Err(VfioError::GpuNotFound { + bdf: bdf.to_string(), + }); + } + + let vendor = read_sysfs_trimmed(&dev_dir.join("vendor"))?; + if vendor != NVIDIA_VENDOR_ID { + return Err(VfioError::NotNvidia { + bdf: bdf.to_string(), + vendor, + }); + } + + let iommu_group = sysfs.iommu_group(bdf)?; + let group_devices = sysfs.iommu_group_devices(iommu_group)?; + let peers: Vec = group_devices.into_iter().filter(|d| d != bdf).collect(); + + let mut bound_companions = Vec::new(); + for peer in &peers { + if !sysfs.pci_device(peer).exists() { + continue; + } + match bind_device_to_vfio(sysfs, peer) { + Ok(was_bound) => { + if was_bound { + tracing::info!(bdf = %peer, iommu_group, "bound IOMMU group companion to vfio-pci"); + bound_companions.push(peer.clone()); + } + } + Err(err) => { + for already_bound in bound_companions.iter().rev() { + if let Err(restore_err) = restore_gpu_to_host_driver(sysfs, already_bound) { + tracing::error!(bdf = %already_bound, error = %restore_err, "failed to restore companion during rollback"); + } + } + return Err(VfioError::BindFailed { + bdf: peer.clone(), + reason: format!("IOMMU group {iommu_group} companion bind failed: {err}"), + }); + } + } + } + + match bind_device_to_vfio(sysfs, bdf) { + Ok(was_bound) => { + if was_bound { + tracing::info!(bdf, "GPU bound to vfio-pci"); + } else { + tracing::info!(bdf, "GPU already bound to vfio-pci"); + } + } + Err(err) => { + for companion in bound_companions.iter().rev() { + if let Err(restore_err) = restore_gpu_to_host_driver(sysfs, companion) { + tracing::error!(bdf = %companion, error = %restore_err, "failed to restore companion during rollback"); + } + } + return Err(err); + } + } + + Ok(GpuBindGuard { + bdf: bdf.to_string(), + companion_bdfs: bound_companions, + sysfs: sysfs.clone(), + disarmed: false, + }) +} + +/// Restore a GPU from `vfio-pci` back to the host's default driver. +fn restore_gpu_to_host_driver(sysfs: &SysfsRoot, bdf: &str) -> Result<(), VfioError> { + let dev_dir = sysfs.pci_device(bdf); + + let unbind_path = dev_dir.join("driver/unbind"); + if unbind_path.exists() { + write_sysfs(&unbind_path, bdf).map_err(|e| VfioError::UnbindFailed { + bdf: bdf.to_string(), + reason: format!("unbind: {e}"), + })?; + } + + let override_path = dev_dir.join("driver_override"); + if override_path.exists() { + write_sysfs(&override_path, "\n").map_err(|e| VfioError::UnbindFailed { + bdf: bdf.to_string(), + reason: format!("clear driver_override: {e}"), + })?; + } + + let probe = sysfs.drivers_probe(); + if probe.exists() { + write_sysfs(&probe, bdf).map_err(|e| VfioError::UnbindFailed { + bdf: bdf.to_string(), + reason: format!("drivers_probe: {e}"), + })?; + } + + tracing::info!(bdf, "GPU restored to host driver"); + Ok(()) +} + +// --------------------------------------------------------------------------- +// Crash-recovery reconciliation +// --------------------------------------------------------------------------- + +/// Reconcile stale VFIO bindings left over from a previous crash. +/// +/// Loads persisted state, checks each GPU, and restores any that are still +/// bound to `vfio-pci`. Returns the list of BDFs that were restored. +/// Removes the state file after reconciliation. +pub fn reconcile_stale_bindings(sysfs: &SysfsRoot, state_path: &Path) -> Vec { + let state = match GpuBindState::load(state_path) { + Ok(s) => s, + Err(err) => { + tracing::debug!(%err, path = %state_path.display(), "no stale GPU bind state to reconcile"); + return Vec::new(); + } + }; + + let mut restored = Vec::new(); + + for binding in &state.bindings { + match current_driver_name(sysfs, &binding.bdf) { + Some(ref drv) if drv == "vfio-pci" => { + tracing::warn!( + bdf = %binding.bdf, + sandbox_id = %binding.sandbox_id, + "stale VFIO binding detected, restoring GPU to host driver" + ); + if let Err(err) = restore_gpu_to_host_driver(sysfs, &binding.bdf) { + tracing::error!(bdf = %binding.bdf, %err, "failed to restore stale GPU binding"); + continue; + } + restored.push(binding.bdf.clone()); + } + _ => { + let override_path = sysfs.pci_device(&binding.bdf).join("driver_override"); + if let Ok(val) = read_sysfs_trimmed(&override_path) + && val == "vfio-pci" + { + tracing::warn!( + bdf = %binding.bdf, + sandbox_id = %binding.sandbox_id, + "stale driver_override detected, clearing and re-probing" + ); + if let Err(err) = write_sysfs(&override_path, "\n") { + tracing::error!(bdf = %binding.bdf, %err, "failed to clear stale driver_override"); + continue; + } + let probe = sysfs.drivers_probe(); + if let Err(err) = write_sysfs(&probe, &binding.bdf) { + tracing::error!(bdf = %binding.bdf, %err, "failed to re-probe after clearing driver_override"); + } + restored.push(binding.bdf.clone()); + } else { + tracing::debug!(bdf = %binding.bdf, "GPU no longer bound to vfio-pci, skipping"); + } + } + } + } + + if let Err(err) = fs::remove_file(state_path) { + tracing::warn!(%err, path = %state_path.display(), "failed to remove stale bind state file"); + } + + restored +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::os::unix::fs::symlink; + use tempfile::TempDir; + + fn setup_mock_sysfs() -> (TempDir, SysfsRoot) { + let tmp = TempDir::new().unwrap(); + let sysfs = SysfsRoot::new(tmp.path()); + (tmp, sysfs) + } + + fn create_pci_device( + sysfs: &SysfsRoot, + tmp: &Path, + bdf: &str, + vendor: &str, + device: &str, + class: &str, + iommu_group: u32, + ) { + let dev = sysfs.pci_device(bdf); + fs::create_dir_all(&dev).unwrap(); + + fs::write(dev.join("vendor"), format!("{vendor}\n")).unwrap(); + fs::write(dev.join("device"), format!("{device}\n")).unwrap(); + fs::write(dev.join("class"), format!("{class}\n")).unwrap(); + + let group_dir = tmp.join(format!("kernel/iommu_groups/{iommu_group}")); + fs::create_dir_all(&group_dir).unwrap(); + symlink(&group_dir, dev.join("iommu_group")).unwrap(); + + let group_devices_dir = group_dir.join("devices"); + fs::create_dir_all(&group_devices_dir).unwrap(); + symlink(&dev, group_devices_dir.join(bdf)).unwrap(); + } + + // -- validate_bdf ------------------------------------------------------- + + #[test] + fn test_validate_bdf_valid() { + assert!(validate_bdf("0000:2d:00.0").is_ok()); + assert!(validate_bdf("0000:00:00.0").is_ok()); + assert!(validate_bdf("abcd:ef:01.a").is_ok()); + assert!(validate_bdf("ABCD:EF:01.A").is_ok()); + } + + #[test] + fn test_validate_bdf_invalid() { + assert!(validate_bdf("").is_err()); + assert!(validate_bdf("0000:2d:00").is_err()); // too short + assert!(validate_bdf("0000:2d:00.00").is_err()); // too long + assert!(validate_bdf("000g:2d:00.0").is_err()); // non-hex + assert!(validate_bdf("0000-2d-00.0").is_err()); // wrong separators + assert!(validate_bdf("0000:2d:00:0").is_err()); // colon instead of dot + } + + #[test] + fn test_validate_bdf_rejects_metacharacters() { + assert!(validate_bdf("$(rm -rf /)").is_err()); + assert!(validate_bdf("; echo pwned").is_err()); + assert!(validate_bdf("0000:2d;00.0").is_err()); + assert!(validate_bdf("0000:2d:0`.0").is_err()); + assert!(validate_bdf("../../../../").is_err()); + } + + // -- validate_sysfs_data ------------------------------------------------ + + #[test] + fn test_validate_sysfs_data() { + assert!(validate_sysfs_data("0x10de")); + assert!(validate_sysfs_data("vfio-pci")); + assert!(validate_sysfs_data("nvidia_gpu_0")); + assert!(validate_sysfs_data("0000:2d:00.0")); + + assert!(!validate_sysfs_data("")); + assert!(!validate_sysfs_data("$(echo)")); + assert!(!validate_sysfs_data("a b")); + assert!(!validate_sysfs_data("foo;bar")); + assert!(!validate_sysfs_data("a\nb")); + } + + // -- probe_host_nvidia_vfio_readiness ----------------------------------- + + #[test] + fn test_probe_discovers_nvidia_gpu() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + + let gpus = probe_host_nvidia_vfio_readiness(&sysfs); + assert_eq!(gpus.len(), 1); + assert_eq!(gpus[0].bdf, "0000:2d:00.0"); + assert_eq!(gpus[0].vendor, "0x10de"); + assert_eq!(gpus[0].device, "0x2684"); + assert_eq!(gpus[0].iommu_group, 42); + } + + #[test] + fn test_probe_skips_non_nvidia() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:01:00.0", + "0x8086", + "0x1234", + "0x030000", + 10, + ); + + let gpus = probe_host_nvidia_vfio_readiness(&sysfs); + assert!(gpus.is_empty()); + } + + #[test] + fn test_probe_skips_non_gpu_nvidia() { + let (tmp, sysfs) = setup_mock_sysfs(); + // Audio device class 0x040300 + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.1", + "0x10de", + "0x228b", + "0x040300", + 42, + ); + + let gpus = probe_host_nvidia_vfio_readiness(&sysfs); + assert!(gpus.is_empty()); + } + + // -- GpuBindState ------------------------------------------------------- + + #[test] + fn test_gpu_bind_state_roundtrip() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("gpu-state.json"); + + let state = GpuBindState { + bindings: vec![ + GpuBinding { + bdf: "0000:2d:00.0".to_string(), + sandbox_id: "sandbox-123".to_string(), + bound_at_ms: 1_700_000_000_000, + }, + GpuBinding { + bdf: "0000:3b:00.0".to_string(), + sandbox_id: "sandbox-456".to_string(), + bound_at_ms: 1_700_000_001_000, + }, + ], + }; + + state.save(&path).unwrap(); + let loaded = GpuBindState::load(&path).unwrap(); + + assert_eq!(loaded.bindings.len(), 2); + assert_eq!(loaded.bindings[0].bdf, "0000:2d:00.0"); + assert_eq!(loaded.bindings[0].sandbox_id, "sandbox-123"); + assert_eq!(loaded.bindings[1].bdf, "0000:3b:00.0"); + } + + // -- SysfsRoot ---------------------------------------------------------- + + #[test] + fn test_sysfs_root_paths() { + let sysfs = SysfsRoot::system(); + assert_eq!( + sysfs.pci_device("0000:2d:00.0"), + PathBuf::from("/sys/bus/pci/devices/0000:2d:00.0") + ); + assert_eq!( + sysfs.pci_devices_dir(), + PathBuf::from("/sys/bus/pci/devices") + ); + assert_eq!( + sysfs.drivers_probe(), + PathBuf::from("/sys/bus/pci/drivers_probe") + ); + + let custom = SysfsRoot::new("/tmp/test-sys"); + assert_eq!( + custom.pci_device("0000:01:00.0"), + PathBuf::from("/tmp/test-sys/bus/pci/devices/0000:01:00.0") + ); + } + + #[test] + fn test_sysfs_root_iommu_group() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + + assert_eq!(sysfs.iommu_group("0000:2d:00.0").unwrap(), 42); + assert!(sysfs.iommu_group("0000:ff:ff.f").is_err()); + } + + // -- is_gpu_class ------------------------------------------------------- + + #[test] + fn test_is_gpu_class() { + assert!(is_gpu_class("0x030000")); + assert!(is_gpu_class("0x030200")); + assert!(is_gpu_class("0x030201")); + assert!(!is_gpu_class("0x040300")); + assert!(!is_gpu_class("0x060000")); + assert!(!is_gpu_class("")); + } + + // -- iommu_group_devices ------------------------------------------------ + + #[test] + fn test_iommu_group_devices_lists_all_members() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.1", + "0x10de", + "0x228b", + "0x040300", + 42, + ); + + let devices = sysfs.iommu_group_devices(42).unwrap(); + assert_eq!(devices.len(), 2); + assert!(devices.contains(&"0000:2d:00.0".to_string())); + assert!(devices.contains(&"0000:2d:00.1".to_string())); + } + + #[test] + fn test_iommu_group_devices_single_device() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 99, + ); + + let devices = sysfs.iommu_group_devices(99).unwrap(); + assert_eq!(devices.len(), 1); + assert_eq!(devices[0], "0000:2d:00.0"); + } + + // -- companion binding -------------------------------------------------- + + /// Helper to create a fake driver symlink for a mock PCI device. + fn set_mock_driver(sysfs: &SysfsRoot, bdf: &str, driver_name: &str) { + let driver_dir = sysfs.base.join(format!("bus/pci/drivers/{driver_name}")); + fs::create_dir_all(&driver_dir).unwrap(); + let dev_driver_link = sysfs.pci_device(bdf).join("driver"); + let _ = fs::remove_file(&dev_driver_link); + symlink(&driver_dir, &dev_driver_link).unwrap(); + } + + #[test] + fn test_prepare_gpu_skips_already_bound_companions() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.1", + "0x10de", + "0x228b", + "0x040300", + 42, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + // Both devices already on vfio-pci + set_mock_driver(&sysfs, "0000:2d:00.0", "vfio-pci"); + set_mock_driver(&sysfs, "0000:2d:00.1", "vfio-pci"); + + let guard = prepare_gpu_for_passthrough(&sysfs, "0000:2d:00.0").unwrap(); + + // Both were already bound, no companions should be tracked for restore + assert!(guard.companion_bdfs.is_empty()); + assert_eq!(guard.bdf, "0000:2d:00.0"); + } + + #[test] + fn test_prepare_gpu_solo_iommu_group_no_companions() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 99, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + // GPU already on vfio-pci + set_mock_driver(&sysfs, "0000:2d:00.0", "vfio-pci"); + + let guard = prepare_gpu_for_passthrough(&sysfs, "0000:2d:00.0").unwrap(); + assert!(guard.companion_bdfs.is_empty()); + } + + #[test] + fn test_bind_device_to_vfio_already_bound() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + set_mock_driver(&sysfs, "0000:2d:00.0", "vfio-pci"); + + let was_bound = bind_device_to_vfio(&sysfs, "0000:2d:00.0").unwrap(); + assert!(!was_bound, "should report false when already on vfio-pci"); + } + + #[test] + fn test_guard_drop_restores_companions() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.1", + "0x10de", + "0x228b", + "0x040300", + 42, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + // Simulate bound state: driver link and driver_override both set + set_mock_driver(&sysfs, "0000:2d:00.0", "vfio-pci"); + set_mock_driver(&sysfs, "0000:2d:00.1", "vfio-pci"); + fs::write( + sysfs.pci_device("0000:2d:00.0").join("driver_override"), + "vfio-pci", + ) + .unwrap(); + fs::write( + sysfs.pci_device("0000:2d:00.1").join("driver_override"), + "vfio-pci", + ) + .unwrap(); + + { + let _guard = GpuBindGuard { + bdf: "0000:2d:00.0".to_string(), + companion_bdfs: vec!["0000:2d:00.1".to_string()], + sysfs: sysfs.clone(), + disarmed: false, + }; + // guard drops here — should attempt restore on both devices + } + + // After drop, driver_override should be cleared (written with "\n") + let gpu_override = + fs::read_to_string(sysfs.pci_device("0000:2d:00.0").join("driver_override")).unwrap(); + assert_eq!( + gpu_override.trim(), + "", + "GPU driver_override should be cleared after drop" + ); + + let companion_override = + fs::read_to_string(sysfs.pci_device("0000:2d:00.1").join("driver_override")).unwrap(); + assert_eq!( + companion_override.trim(), + "", + "companion driver_override should be cleared after drop" + ); + } + + #[test] + fn test_guard_disarm_skips_restore() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + + // Write a non-empty driver_override to verify it's NOT cleared + fs::write( + sysfs.pci_device("0000:2d:00.0").join("driver_override"), + "vfio-pci", + ) + .unwrap(); + + let guard = GpuBindGuard { + bdf: "0000:2d:00.0".to_string(), + companion_bdfs: vec![], + sysfs: sysfs.clone(), + disarmed: false, + }; + guard.disarm(); + + // driver_override should still be vfio-pci (not cleared by disarmed guard) + let override_val = + fs::read_to_string(sysfs.pci_device("0000:2d:00.0").join("driver_override")).unwrap(); + assert_eq!(override_val, "vfio-pci"); + } + + // -- reconcile_stale_bindings ------------------------------------------- + + #[test] + fn test_reconcile_clears_stale_driver_override_when_not_on_vfio() { + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x2684", + "0x030000", + 42, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + set_mock_driver(&sysfs, "0000:2d:00.0", "nvidia"); + fs::write( + sysfs.pci_device("0000:2d:00.0").join("driver_override"), + "vfio-pci", + ) + .unwrap(); + + let state_path = tmp.path().join("gpu-state.json"); + let state = GpuBindState { + bindings: vec![GpuBinding { + bdf: "0000:2d:00.0".to_string(), + sandbox_id: "sandbox-orphan".to_string(), + bound_at_ms: 0, + }], + }; + state.save(&state_path).unwrap(); + + let restored = reconcile_stale_bindings(&sysfs, &state_path); + assert!(restored.contains(&"0000:2d:00.0".to_string())); + + let override_val = + fs::read_to_string(sysfs.pci_device("0000:2d:00.0").join("driver_override")).unwrap(); + assert_eq!( + override_val.trim(), + "", + "driver_override should be cleared even when device is not on vfio-pci" + ); + } +} diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 68af695e5..3c4308f3f 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -53,6 +53,8 @@ message GetCapabilitiesResponse { string default_image = 3; // True when the driver can provision GPU-backed sandboxes. bool supports_gpu = 4; + // Number of GPUs available for sandbox assignment. + uint32 gpu_count = 5; } // Driver-owned sandbox model used for create requests and platform observations. @@ -84,6 +86,10 @@ message DriverSandboxSpec { DriverSandboxTemplate template = 6; // Request NVIDIA GPU resources for this sandbox. bool gpu = 9; + // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index + // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the + // first available GPU. + string gpu_device = 10; } // Driver-owned runtime template consumed by the compute platform. diff --git a/proto/openshell.proto b/proto/openshell.proto index 1d7eba218..75490f338 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -204,6 +204,10 @@ message SandboxSpec { repeated string providers = 8; // Request NVIDIA GPU resources for this sandbox. bool gpu = 9; + // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index + // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the + // first available GPU. + string gpu_device = 10; } // Public sandbox template mapped onto compute-driver template inputs. diff --git a/tasks/vm.toml b/tasks/vm.toml index 0a44b4ff7..2549f230f 100644 --- a/tasks/vm.toml +++ b/tasks/vm.toml @@ -17,7 +17,7 @@ # ═══════════════════════════════════════════════════════════════════════════ ["gateway:vm"] -description = "Build openshell-gateway + openshell-driver-vm and start the gateway with the VM driver" +description = "Build openshell-gateway + openshell-driver-vm and start the gateway with the VM driver (pass -- --gpu for GPU support)" run = "crates/openshell-driver-vm/start.sh" [vm] From 38f069e68a1789f0dae2e14b7f1956655f48fb3a Mon Sep 17 00:00:00 2001 From: Vincent Caux-Brisebois Date: Wed, 29 Apr 2026 00:10:20 +0000 Subject: [PATCH 2/2] Add GPU rootfs variant, harden VFIO binding, and fix networking and supervisor reliability issues discovered during GPU VM bring-up. Signed-off-by: Vincent Caux-Brisebois --- architecture/podman-rootless-networking.md | 16 +- crates/openshell-driver-docker/src/lib.rs | 1 + crates/openshell-driver-docker/src/tests.rs | 1 + crates/openshell-driver-vm/build.rs | 20 +- .../scripts/openshell-vm-sandbox-init.sh | 110 +++- crates/openshell-driver-vm/src/driver.rs | 87 ++- crates/openshell-driver-vm/src/lib.rs | 5 +- crates/openshell-driver-vm/src/main.rs | 6 + crates/openshell-driver-vm/src/rootfs.rs | 45 +- crates/openshell-driver-vm/src/runtime.rs | 124 ++++- crates/openshell-driver-vm/start.sh | 39 +- crates/openshell-sandbox/src/lib.rs | 70 ++- crates/openshell-server/src/compute/mod.rs | 1 + crates/openshell-vfio/src/lib.rs | 521 +++++++++++++++++- mise.lock | 112 ++++ tasks/scripts/vm/build-rootfs-tarball.sh | 319 ++++++++++- tasks/scripts/vm/compress-vm-runtime.sh | 4 +- 17 files changed, 1355 insertions(+), 126 deletions(-) diff --git a/architecture/podman-rootless-networking.md b/architecture/podman-rootless-networking.md index b267cfffa..d13b9ca84 100644 --- a/architecture/podman-rootless-networking.md +++ b/architecture/podman-rootless-networking.md @@ -35,7 +35,7 @@ For rootful bridge networking: 6. Netavark configures iptables/nftables rules -- masquerade for outbound, DNAT for port mappings 7. Netavark starts aardvark-dns if DNS is enabled, listening on the bridge gateway address -``` +```text Host Kernel | +-- Bridge interface (e.g., "podman0") <-- created by Netavark @@ -60,7 +60,7 @@ Unprivileged users cannot create network interfaces on the host. They cannot cre Pasta (part of the `passt` project -- same binary, different command name) operates entirely in userspace, translating between the container's L2 TAP interface and the host's L4 sockets. It requires no capabilities or privileges. -``` +```text Container Network Namespace | +-- TAP device (e.g., "eth0") @@ -131,7 +131,7 @@ Unlike bridge networking, pasta containers are isolated from each other by defau The Podman compute driver creates three layers of network isolation: -``` +```text Namespace 1: Host | pasta manages port forwarding (127.0.0.1:) @@ -164,7 +164,7 @@ client.ensure_network(&config.network_name).await?; This creates a bridge network named `"openshell"` (default from `DEFAULT_NETWORK_NAME` in `openshell-core/src/config.rs`) with `dns_enabled: true`. In rootless mode, this bridge exists inside a user namespace managed by pasta. The bridge IP range (e.g., `10.89.x.x`) is not routable from the host. -``` +```text Host (your machine) | 127.0.0.1: <--- pasta binds this on the host @@ -212,7 +212,7 @@ The bridge gateway IP does NOT work for this purpose in rootless mode because it Inside the container, the supervisor creates another network namespace (`netns.rs:53-178`, setup at lines 53-63, `ip netns add` at line 77) for the user workload: -``` +```text Container (10.89.1.2 on the Podman bridge) | [Supervisor process - runs in container's default netns] @@ -247,7 +247,7 @@ A tmpfs is mounted at `/run/netns` in the container spec (`container.rs:458-463` ### SSH Session: Client to Sandbox Shell -``` +```text Client (CLI on user's machine) | 1. gRPC: CreateSshSession -> gateway (returns token, connect_path) @@ -281,7 +281,7 @@ The SSH daemon listens on a Unix socket (not a TCP port) with 0600 permissions. ### Outbound HTTP Request from Sandbox Process -``` +```text User's code (inner netns, 10.200.0.2) | 1. curl https://api.example.com @@ -306,7 +306,7 @@ Supervisor proxy (10.200.0.1:3128 in container netns) ### Supervisor gRPC Callback to Gateway -``` +```text Supervisor (container netns, 10.89.x.2) | 1. gRPC connect to http://host.containers.internal:8080 diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index b4371ce7d..d9683d416 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -231,6 +231,7 @@ impl DockerComputeDriver { driver_version: self.config.daemon_version.clone(), default_image: self.config.default_image.clone(), supports_gpu: false, + gpu_count: 0, } } diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index b20fbf5ce..cf9632c40 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -30,6 +30,7 @@ fn test_sandbox() -> DriverSandbox { platform_config: None, }), gpu: false, + gpu_device: String::new(), }), status: None, } diff --git a/crates/openshell-driver-vm/build.rs b/crates/openshell-driver-vm/build.rs index 174a90fc8..981cf8ff8 100644 --- a/crates/openshell-driver-vm/build.rs +++ b/crates/openshell-driver-vm/build.rs @@ -22,6 +22,7 @@ fn main() { "libkrunfw.5.dylib.zst", "gvproxy.zst", "rootfs.tar.zst", + "rootfs-gpu.tar.zst", ] { println!("cargo:rerun-if-changed={dir}/{name}"); } @@ -36,7 +37,15 @@ fn main() { "linux" => ("libkrun.so", "libkrunfw.so.5"), _ => { println!("cargo:warning=VM runtime not available for {target_os}-{target_arch}"); - generate_stub_resources(&out_dir, &["libkrun", "libkrunfw", "rootfs.tar.zst"]); + generate_stub_resources( + &out_dir, + &[ + "libkrun", + "libkrunfw", + "rootfs.tar.zst", + "rootfs-gpu.tar.zst", + ], + ); return; } }; @@ -53,6 +62,7 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "rootfs.tar.zst", + "rootfs-gpu.tar.zst", ], ); return; @@ -71,6 +81,7 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "rootfs.tar.zst", + "rootfs-gpu.tar.zst", ], ); return; @@ -84,6 +95,10 @@ fn main() { ), ("gvproxy.zst".to_string(), "gvproxy.zst".to_string()), ("rootfs.tar.zst".to_string(), "rootfs.tar.zst".to_string()), + ( + "rootfs-gpu.tar.zst".to_string(), + "rootfs-gpu.tar.zst".to_string(), + ), ]; let mut all_found = true; @@ -124,12 +139,13 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "rootfs.tar.zst", + "rootfs-gpu.tar.zst", ], ); } } -fn generate_stub_resources(out_dir: &PathBuf, names: &[&str]) { +fn generate_stub_resources(out_dir: &std::path::Path, names: &[&str]) { for name in names { let path = out_dir.join(name); if !path.exists() { diff --git a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh index 5e3227d11..1c009a7f1 100644 --- a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh +++ b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh @@ -148,6 +148,53 @@ rewrite_openshell_endpoint_if_needed() { ts "WARNING: could not reach OpenShell endpoint ${host}:${port}" } +create_gpu_device_nodes_mknod() { + # Mode 666 is intentional: single-tenant microVM with the VM itself as the + # isolation boundary. The sandbox user is the only non-root user. + local nv_major + nv_major=$(awk '$2 == "nvidia" {print $1}' /proc/devices 2>/dev/null || true) + if [ -n "$nv_major" ]; then + mknod -m 666 /dev/nvidiactl c "$nv_major" 255 2>/dev/null || true + + local gpu_count=0 + if [ -d /proc/driver/nvidia/gpus ]; then + for gpu_dir in /proc/driver/nvidia/gpus/*/; do + [ -d "$gpu_dir" ] || continue + mknod -m 666 "/dev/nvidia${gpu_count}" c "$nv_major" "$gpu_count" 2>/dev/null || true + gpu_count=$((gpu_count + 1)) + done + fi + if [ "$gpu_count" -eq 0 ]; then + mknod -m 666 /dev/nvidia0 c "$nv_major" 0 2>/dev/null || true + fi + + local modeset_major + modeset_major=$(awk '$2 == "nvidia-modeset" {print $1}' /proc/devices 2>/dev/null || true) + if [ -n "$modeset_major" ]; then + mknod -m 666 /dev/nvidia-modeset c "$modeset_major" 254 2>/dev/null || true + fi + + local uvm_major + uvm_major=$(awk '$2 == "nvidia-uvm" {print $1}' /proc/devices 2>/dev/null || true) + if [ -n "$uvm_major" ]; then + mknod -m 666 /dev/nvidia-uvm c "$uvm_major" 0 2>/dev/null || true + mknod -m 666 /dev/nvidia-uvm-tools c "$uvm_major" 1 2>/dev/null || true + fi + + local caps_major + caps_major=$(awk '$2 == "nvidia-caps" {print $1}' /proc/devices 2>/dev/null || true) + if [ -n "$caps_major" ]; then + mkdir -p /dev/nvidia-caps 2>/dev/null || true + mknod -m 666 /dev/nvidia-caps/nvidia-cap1 c "$caps_major" 1 2>/dev/null || true + mknod -m 666 /dev/nvidia-caps/nvidia-cap2 c "$caps_major" 2 2>/dev/null || true + fi + + ts "GPU device nodes created via mknod (${gpu_count} GPU(s), major=${nv_major})" + else + ts "WARNING: 'nvidia' not in /proc/devices; device nodes unavailable" + fi +} + setup_gpu() { ts "GPU_ENABLED=true — initializing GPU passthrough" @@ -157,8 +204,6 @@ setup_gpu() { fi # Stage GSP firmware from virtiofs to tmpfs to avoid slow FUSE reads - # during module load. The kernel's firmware_class.path= cmdline param - # points here initially for early request_firmware calls. if [ -d /lib/firmware/nvidia ]; then ts "staging GPU firmware to tmpfs" mkdir -p /run/firmware/nvidia @@ -173,19 +218,22 @@ setup_gpu() { modprobe nvidia_uvm 2>/dev/null || true modprobe nvidia_modeset 2>/dev/null || true - # Free the tmpfs firmware copy now that modules are loaded rm -rf /run/firmware 2>/dev/null || true if command -v nvidia-smi >/dev/null 2>&1; then - ts "validating nvidia-smi" - if nvidia-smi; then + ts "running nvidia-smi to create device nodes and validate GPU" + local smi_rc=0 + nvidia-smi >/dev/null 2>&1 || smi_rc=$? + if [ "$smi_rc" -eq 0 ]; then + nvidia-smi -L 2>/dev/null | while read -r line; do ts " $line"; done ts "GPU initialization successful" else - ts "FATAL: nvidia-smi failed" - return 1 + ts "WARNING: nvidia-smi failed (exit ${smi_rc}); falling back to mknod" + create_gpu_device_nodes_mknod fi else - ts "WARNING: nvidia-smi not found in rootfs; skipping GPU validation" + ts "nvidia-smi not found; creating device nodes via mknod" + create_gpu_device_nodes_mknod fi } @@ -220,14 +268,28 @@ if [ -n "${VM_NET_IP}" ] && [ -n "${VM_NET_GW}" ]; then ts "configuring TAP networking (static ${VM_NET_IP} gw ${VM_NET_GW})" GATEWAY_IP="${VM_NET_GW}" - if ip link show eth0 >/dev/null 2>&1; then - ip link set eth0 up 2>/dev/null || true - ip addr add "${VM_NET_IP}/30" dev eth0 2>/dev/null || true - ip route add default via "${VM_NET_GW}" 2>/dev/null || true - elif ip link show ens3 >/dev/null 2>&1; then - ip link set ens3 up 2>/dev/null || true - ip addr add "${VM_NET_IP}/30" dev ens3 2>/dev/null || true + TAP_NIC="" + NIC_WAIT=0 + while [ -z "$TAP_NIC" ] && [ "$NIC_WAIT" -lt 10 ]; do + for candidate in eth0 ens3 enp0s2 $(ls /sys/class/net/ 2>/dev/null | grep -v '^lo$'); do + if ip link show "$candidate" >/dev/null 2>&1 && [ "$candidate" != "lo" ]; then + TAP_NIC="$candidate" + break + fi + done + if [ -z "$TAP_NIC" ]; then + sleep 1 + NIC_WAIT=$((NIC_WAIT + 1)) + fi + done + + if [ -n "$TAP_NIC" ]; then + ts "using NIC ${TAP_NIC} for TAP networking" + ip link set "$TAP_NIC" up 2>/dev/null || true + ip addr add "${VM_NET_IP}/30" dev "$TAP_NIC" 2>/dev/null || true ip route add default via "${VM_NET_GW}" 2>/dev/null || true + else + ts "WARNING: no network interface found for TAP networking" fi if [ -n "${VM_NET_DNS}" ]; then @@ -293,5 +355,23 @@ export USER=sandbox rewrite_openshell_endpoint_if_needed +# Log supervisor connectivity state for debugging stuck-in-Provisioning issues +if [ -n "${OPENSHELL_ENDPOINT:-}" ]; then + _ep_parsed="$(parse_endpoint "$OPENSHELL_ENDPOINT" 2>/dev/null || true)" + if [ -n "$_ep_parsed" ]; then + _ep_host="$(printf '%s\n' "$_ep_parsed" | sed -n '2p')" + _ep_port="$(printf '%s\n' "$_ep_parsed" | sed -n '3p')" + if tcp_probe "$_ep_host" "$_ep_port"; then + ts "gateway reachable at ${_ep_host}:${_ep_port}" + else + ts "WARNING: gateway NOT reachable at ${_ep_host}:${_ep_port} — supervisor may fail to connect" + fi + fi + ts "OPENSHELL_ENDPOINT=${OPENSHELL_ENDPOINT}" +fi +if [ -n "${OPENSHELL_SANDBOX_ID:-}" ]; then + ts "OPENSHELL_SANDBOX_ID=${OPENSHELL_SANDBOX_ID}" +fi + ts "starting openshell-sandbox supervisor" exec /opt/openshell/bin/openshell-sandbox --workdir /sandbox diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 1ec53f4ed..e5520b5a1 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -4,7 +4,9 @@ use crate::gpu::{ GpuInventory, SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name, }; -use crate::rootfs::{extract_sandbox_rootfs_to, sandbox_guest_init_path}; +use crate::rootfs::{ + extract_gpu_sandbox_rootfs_to, extract_sandbox_rootfs_to, sandbox_guest_init_path, +}; use futures::Stream; use nix::errno::Errno; use nix::sys::signal::{Signal, kill}; @@ -188,7 +190,6 @@ pub struct VmDriver { registry: Arc>>, events: broadcast::Sender, gpu_inventory: Option>>, - gpu_count: u32, subnet_allocator: Arc>, } @@ -203,6 +204,9 @@ impl VmDriver { #[cfg(target_os = "linux")] if config.gpu_enabled { check_gpu_privileges()?; + tokio::task::spawn_blocking(crate::cleanup_stale_tap_interfaces) + .await + .map_err(|e| format!("cleanup stale TAP interfaces panicked: {e}"))?; } let state_root = config.state_dir.join("sandboxes"); @@ -222,14 +226,16 @@ impl VmDriver { .map_err(|err| format!("failed to resolve vm driver executable: {err}"))? }; - let (gpu_inventory, gpu_count) = if config.gpu_enabled { + let gpu_inventory = if config.gpu_enabled { let sysfs = SysfsRoot::system(); let inventory = GpuInventory::new(sysfs, &config.state_dir); - let count = inventory.gpu_count(); - tracing::info!(gpu_count = count, "GPU inventory initialized"); - (Some(Arc::new(std::sync::Mutex::new(inventory))), count) + tracing::info!( + gpu_count = inventory.gpu_count(), + "GPU inventory initialized" + ); + Some(Arc::new(std::sync::Mutex::new(inventory))) } else { - (None, 0) + None }; let subnet_allocator = Arc::new(std::sync::Mutex::new(SubnetAllocator::new( @@ -244,19 +250,23 @@ impl VmDriver { registry: Arc::new(Mutex::new(HashMap::new())), events, gpu_inventory, - gpu_count, subnet_allocator, }) } #[must_use] pub fn capabilities(&self) -> GetCapabilitiesResponse { + let gpu_count = self + .gpu_inventory + .as_ref() + .and_then(|inv| inv.lock().ok()) + .map_or(0, |inv| inv.gpu_count()); GetCapabilitiesResponse { driver_name: DRIVER_NAME.to_string(), driver_version: openshell_core::VERSION.to_string(), default_image: String::new(), supports_gpu: self.gpu_inventory.is_some(), - gpu_count: self.gpu_count, + gpu_count, } } @@ -287,7 +297,12 @@ impl VmDriver { .tls_paths() .map_err(Status::failed_precondition)?; let rootfs_for_extract = rootfs.clone(); - tokio::task::spawn_blocking(move || extract_sandbox_rootfs_to(&rootfs_for_extract)) + let extract_fn = if is_gpu { + extract_gpu_sandbox_rootfs_to + } else { + extract_sandbox_rootfs_to + }; + tokio::task::spawn_blocking(move || extract_fn(&rootfs_for_extract)) .await .map_err(|err| Status::internal(format!("sandbox rootfs extraction panicked: {err}")))? .map_err(|err| Status::internal(format!("extract sandbox rootfs failed: {err}")))?; @@ -304,19 +319,28 @@ impl VmDriver { .gpu_inventory .as_ref() .ok_or_else(|| Status::internal("GPU inventory not initialized"))?; - let assignment = inventory + match inventory .lock() - .map_err(|e| Status::internal(format!("GPU inventory lock poisoned: {e}")))? - .assign(&sandbox.id, gpu_device) - .map_err(|e| Status::failed_precondition(e))?; - tracing::info!( - sandbox_id = %sandbox.id, - bdf = %assignment.bdf, - gpu_name = %assignment.name, - iommu_group = assignment.iommu_group, - "assigned GPU to sandbox" - ); - Some(assignment.bdf) + .map_err(|e| Status::internal(format!("GPU inventory lock poisoned: {e}"))) + .and_then(|mut inv| { + inv.assign(&sandbox.id, gpu_device) + .map_err(|e| Status::failed_precondition(e)) + }) { + Ok(assignment) => { + tracing::info!( + sandbox_id = %sandbox.id, + bdf = %assignment.bdf, + gpu_name = %assignment.name, + iommu_group = assignment.iommu_group, + "assigned GPU to sandbox" + ); + Some(assignment.bdf) + } + Err(err) => { + let _ = tokio::fs::remove_dir_all(&state_dir).await; + return Err(err); + } + } } else { None }; @@ -371,9 +395,6 @@ impl VmDriver { command .arg("--vm-mem-mib") .arg(self.config.gpu_mem_mib.to_string()); - command - .arg("--vm-krun-log-level") - .arg(self.config.krun_log_level.to_string()); command.arg("--vm-gpu-bdf").arg(gpu_bdf.as_ref().unwrap()); command.arg("--vm-tap-device").arg(&tap); command @@ -383,18 +404,23 @@ impl VmDriver { command.arg("--vm-vsock-cid").arg(vsock_cid.to_string()); command.arg("--vm-guest-mac").arg(&mac_str); + if let Some(port) = gateway_port_from_endpoint(&self.config.openshell_endpoint) { + command.arg("--vm-gateway-port").arg(port.to_string()); + } + Some(tap_endpoint) } else { command.arg("--vm-vcpus").arg(self.config.vcpus.to_string()); command .arg("--vm-mem-mib") .arg(self.config.mem_mib.to_string()); - command - .arg("--vm-krun-log-level") - .arg(self.config.krun_log_level.to_string()); None }; + command + .arg("--vm-krun-log-level") + .arg(self.config.krun_log_level.to_string()); + for env in build_guest_environment(sandbox, &self.config, endpoint_override.as_deref()) { command.arg("--vm-env").arg(env); } @@ -911,6 +937,10 @@ fn guest_visible_openshell_endpoint(endpoint: &str) -> String { endpoint.to_string() } +fn gateway_port_from_endpoint(endpoint: &str) -> Option { + Url::parse(endpoint).ok().and_then(|url| url.port()) +} + fn guest_visible_openshell_endpoint_for_tap(endpoint: &str, host_ip: &str) -> String { let Ok(mut url) = Url::parse(endpoint) else { return endpoint.to_string(); @@ -1362,7 +1392,6 @@ mod tests { registry: Arc::new(Mutex::new(HashMap::new())), events, gpu_inventory: None, - gpu_count: 0, subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( Ipv4Addr::new(10, 0, 128, 0), 17, diff --git a/crates/openshell-driver-vm/src/lib.rs b/crates/openshell-driver-vm/src/lib.rs index e4f9e1299..194dde43c 100644 --- a/crates/openshell-driver-vm/src/lib.rs +++ b/crates/openshell-driver-vm/src/lib.rs @@ -10,4 +10,7 @@ mod rootfs; mod runtime; pub use driver::{VmDriver, VmDriverConfig}; -pub use runtime::{VM_RUNTIME_DIR_ENV, VmBackend, VmLaunchConfig, configured_runtime_dir, run_vm}; +pub use runtime::{ + VM_RUNTIME_DIR_ENV, VmBackend, VmLaunchConfig, cleanup_stale_tap_interfaces, + configured_runtime_dir, run_vm, +}; diff --git a/crates/openshell-driver-vm/src/main.rs b/crates/openshell-driver-vm/src/main.rs index 38311f745..94169a61f 100644 --- a/crates/openshell-driver-vm/src/main.rs +++ b/crates/openshell-driver-vm/src/main.rs @@ -5,6 +5,8 @@ use clap::Parser; use miette::{IntoDiagnostic, Result}; use openshell_core::VERSION; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; +#[cfg(target_os = "macos")] +use openshell_driver_vm::{VM_RUNTIME_DIR_ENV, configured_runtime_dir}; use openshell_driver_vm::{VmBackend, VmDriver, VmDriverConfig, VmLaunchConfig, procguard, run_vm}; use std::net::SocketAddr; use std::path::PathBuf; @@ -120,6 +122,9 @@ struct Args { #[arg(long, hide = true)] vm_guest_mac: Option, + + #[arg(long, hide = true)] + vm_gateway_port: Option, } #[tokio::main] @@ -246,6 +251,7 @@ fn build_vm_launch_config(args: &Args) -> std::result::Result &'static str { } pub fn extract_sandbox_rootfs_to(dest: &Path) -> Result<(), String> { - if ROOTFS.is_empty() { - return Err( - "sandbox rootfs not embedded. Build openshell-driver-vm with OPENSHELL_VM_RUNTIME_COMPRESSED_DIR set or run `mise run vm:setup` first" - .to_string(), - ); + extract_variant( + ROOTFS, + "sandbox", + "sandbox rootfs not embedded. Build openshell-driver-vm with OPENSHELL_VM_RUNTIME_COMPRESSED_DIR set or run `mise run vm:setup` first", + dest, + ) +} + +pub fn extract_gpu_sandbox_rootfs_to(dest: &Path) -> Result<(), String> { + extract_variant( + ROOTFS_GPU, + "sandbox-gpu", + "GPU sandbox rootfs not embedded. Build with `mise run vm:rootfs -- --gpu` first", + dest, + ) +} + +fn extract_variant(blob: &[u8], variant: &str, empty_msg: &str, dest: &Path) -> Result<(), String> { + if blob.is_empty() { + return Err(empty_msg.to_string()); } - let expected_marker = format!("{}:sandbox", env!("CARGO_PKG_VERSION")); + let expected_marker = format!("{}:{variant}", env!("CARGO_PKG_VERSION")); let marker_path = dest.join(ROOTFS_VARIANT_MARKER); if dest.is_dir() @@ -37,22 +53,25 @@ pub fn extract_sandbox_rootfs_to(dest: &Path) -> Result<(), String> { .map_err(|e| format!("remove old rootfs {}: {e}", dest.display()))?; } - extract_rootfs_to(dest)?; + unpack_zstd_tar(blob, variant, dest)?; prepare_sandbox_rootfs(dest)?; fs::write(marker_path, format!("{expected_marker}\n")) .map_err(|e| format!("write rootfs variant marker: {e}"))?; Ok(()) } -fn extract_rootfs_to(dest: &Path) -> Result<(), String> { +fn unpack_zstd_tar(blob: &[u8], label: &str, dest: &Path) -> Result<(), String> { fs::create_dir_all(dest).map_err(|e| format!("create rootfs dir {}: {e}", dest.display()))?; - let decoder = - zstd::Decoder::new(Cursor::new(ROOTFS)).map_err(|e| format!("decompress rootfs: {e}"))?; + let decoder = zstd::Decoder::new(Cursor::new(blob)) + .map_err(|e| format!("decompress {label} rootfs: {e}"))?; let mut archive = tar::Archive::new(decoder); - archive - .unpack(dest) - .map_err(|e| format!("extract rootfs tarball into {}: {e}", dest.display())) + archive.unpack(dest).map_err(|e| { + format!( + "extract {label} rootfs tarball into {}: {e}", + dest.display() + ) + }) } fn prepare_sandbox_rootfs(rootfs: &Path) -> Result<(), String> { diff --git a/crates/openshell-driver-vm/src/runtime.rs b/crates/openshell-driver-vm/src/runtime.rs index 85054ab2a..62f2e314c 100644 --- a/crates/openshell-driver-vm/src/runtime.rs +++ b/crates/openshell-driver-vm/src/runtime.rs @@ -47,6 +47,7 @@ pub struct VmLaunchConfig { pub host_ip: Option, pub vsock_cid: Option, pub guest_mac: Option, + pub gateway_port: Option, } pub fn run_vm(config: &VmLaunchConfig) -> Result<(), String> { @@ -108,10 +109,11 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { std::fs::create_dir_all(&shm_path).map_err(|e| format!("create shm dir: {e}"))?; - let runtime_dir = configured_runtime_dir()?; + let runtime_dir = qemu_runtime_dir()?; - setup_tap_networking(tap_device, host_ip)?; - let mut tap_guard = TapGuard::new(tap_device.to_string(), host_ip.to_string()); + let gw_port = config.gateway_port.unwrap_or(0); + setup_tap_networking(tap_device, host_ip, gw_port)?; + let mut tap_guard = TapGuard::new(tap_device.to_string(), host_ip.to_string(), gw_port); let virtiofsd_log = sandbox_dir.join("virtiofsd.log"); let virtiofsd_log_file = @@ -200,7 +202,11 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { "tap,id=net0,ifname={tap_device},script=no,downscript=no" )) .arg("-device") - .arg(format!("virtio-net-pci,netdev=net0,mac={guest_mac}")) + .arg("pcie-root-port,id=net_root,slot=3") + .arg("-device") + .arg(format!( + "virtio-net-pci-non-transitional,netdev=net0,mac={guest_mac},bus=net_root" + )) .arg("-device") .arg("pcie-root-port,id=vsock_root,slot=1") .arg("-device") @@ -247,7 +253,7 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { } virtiofsd_guard.disarm(); GVPROXY_PID.store(0, Ordering::Relaxed); - teardown_tap_networking(tap_device, host_ip); + teardown_tap_networking(tap_device, host_ip, gw_port); tap_guard.disarm(); let _ = std::fs::remove_dir_all(&shm_path); let _ = std::fs::remove_dir_all(&virtiofsd_sock_dir); @@ -338,7 +344,58 @@ fn host_dns_server() -> Option { None } -fn setup_tap_networking(tap_device: &str, host_ip: &str) -> Result<(), String> { +/// Remove leftover `vmtap-*` interfaces from previous driver runs that +/// were not torn down (e.g. the launcher was SIGKILLed before teardown). +/// Called once at driver startup so stale interfaces cannot cause subnet +/// routing conflicts with newly allocated TAPs. +pub fn cleanup_stale_tap_interfaces() { + let Ok(entries) = std::fs::read_dir("/sys/class/net") else { + return; + }; + for entry in entries.flatten() { + let name = entry.file_name(); + let Some(name) = name.to_str() else { + continue; + }; + if !name.starts_with("vmtap-") { + continue; + } + // Read the IP address so we can clean up iptables rules too. + // Port 0 tells teardown we don't know the original gateway port; + // the blanket legacy rule is still cleaned up best-effort. + let ip = read_tap_host_ip(name); + if let Some(ref host_ip) = ip { + teardown_tap_networking(name, host_ip, 0); + } else { + let _ = run_cmd("ip", &["link", "set", name, "down"]); + let _ = run_cmd("ip", &["tuntap", "del", "dev", name, "mode", "tap"]); + } + tracing::warn!(interface = %name, "removed stale TAP interface from previous run"); + } +} + +/// Read the first IPv4 address assigned to a network interface. +fn read_tap_host_ip(device: &str) -> Option { + let output = StdCommand::new("ip") + .args(["-4", "-o", "addr", "show", "dev", device]) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()) + .output() + .ok()?; + let stdout = String::from_utf8_lossy(&output.stdout); + // Format: "28: vmtap-xxx inet 10.0.128.1/30 ..." + for token in stdout.split_whitespace() { + if let Some((ip, _prefix)) = token.split_once('/') { + if ip.parse::().is_ok() { + return Some(ip.to_string()); + } + } + } + None +} + +fn setup_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) -> Result<(), String> { run_cmd("ip", &["tuntap", "add", "dev", tap_device, "mode", "tap"])?; run_cmd( "ip", @@ -346,6 +403,13 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str) -> Result<(), String> { )?; run_cmd("ip", &["link", "set", tap_device, "up"])?; + // Deprioritize routes through down interfaces so a stale vmtap-* + // that somehow survives cleanup cannot shadow the active one. + let _ = std::fs::write( + format!("/proc/sys/net/ipv4/conf/{tap_device}/ignore_routes_with_linkdown"), + "1", + ); + enable_ip_forwarding()?; let subnet = tap_subnet_from_host_ip(host_ip); @@ -413,11 +477,28 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str) -> Result<(), String> { "ACCEPT", ], )?; + // Allow guest → host traffic only to the gateway gRPC port. + // Previous versions accepted ALL inbound traffic from the TAP + // interface; scope to the specific port so the guest cannot reach + // other host services. + let port_str = gateway_port.to_string(); + let _ = run_cmd( + "iptables", + &[ + "-D", "INPUT", "-i", tap_device, "-p", "tcp", "--dport", &port_str, "-j", "ACCEPT", + ], + ); + run_cmd( + "iptables", + &[ + "-A", "INPUT", "-i", tap_device, "-p", "tcp", "--dport", &port_str, "-j", "ACCEPT", + ], + )?; Ok(()) } -fn teardown_tap_networking(tap_device: &str, host_ip: &str) { +fn teardown_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) { let subnet = tap_subnet_from_host_ip(host_ip); let _ = run_cmd( "iptables", @@ -438,6 +519,21 @@ fn teardown_tap_networking(tap_device: &str, host_ip: &str) { "iptables", &["-D", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], ); + // Remove the port-scoped INPUT rule. Also try the legacy blanket + // rule so stale rules from older driver versions are cleaned up. + if gateway_port > 0 { + let port_str = gateway_port.to_string(); + let _ = run_cmd( + "iptables", + &[ + "-D", "INPUT", "-i", tap_device, "-p", "tcp", "--dport", &port_str, "-j", "ACCEPT", + ], + ); + } + let _ = run_cmd( + "iptables", + &["-D", "INPUT", "-i", tap_device, "-j", "ACCEPT"], + ); let _ = run_cmd( "iptables", &[ @@ -490,14 +586,16 @@ fn run_cmd(cmd: &str, args: &[&str]) -> Result<(), String> { struct TapGuard { tap_device: String, host_ip: String, + gateway_port: u16, disarmed: bool, } impl TapGuard { - fn new(tap_device: String, host_ip: String) -> Self { + fn new(tap_device: String, host_ip: String, gateway_port: u16) -> Self { Self { tap_device, host_ip, + gateway_port, disarmed: false, } } @@ -510,7 +608,7 @@ impl TapGuard { impl Drop for TapGuard { fn drop(&mut self) { if !self.disarmed { - teardown_tap_networking(&self.tap_device, &self.host_ip); + teardown_tap_networking(&self.tap_device, &self.host_ip, self.gateway_port); } } } @@ -788,6 +886,14 @@ pub fn configured_runtime_dir() -> Result { embedded_runtime::ensure_runtime_extracted() } +fn qemu_runtime_dir() -> Result { + configured_runtime_dir().map_err(|_| { + "QEMU backend requires OPENSHELL_VM_RUNTIME_DIR to be set (pointing to a directory \ + containing vmlinux). Set the env var or run `mise run vm:setup`." + .to_string() + }) +} + #[cfg(target_os = "macos")] fn configure_runtime_loader_env(runtime_dir: &Path) -> Result<(), String> { let existing = std::env::var_os("DYLD_FALLBACK_LIBRARY_PATH"); diff --git a/crates/openshell-driver-vm/start.sh b/crates/openshell-driver-vm/start.sh index 675bb4c2e..d98bb7b91 100755 --- a/crates/openshell-driver-vm/start.sh +++ b/crates/openshell-driver-vm/start.sh @@ -4,6 +4,16 @@ set -euo pipefail +# Under sudo, PATH is reset and user-local tools (mise, cargo) disappear. +# Restore the invoking user's tool directories so mise and its shims work. +if [ -n "${SUDO_USER:-}" ]; then + _sudo_home=$(getent passwd "${SUDO_USER}" | cut -d: -f6) + for _p in "${_sudo_home}/.local/bin" "${_sudo_home}/.local/share/mise/shims" "${_sudo_home}/.cargo/bin"; do + [ -d "${_p}" ] && PATH="${_p}:${PATH}" + done + export PATH +fi + ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" CLI_BIN="${ROOT}/scripts/bin/openshell" COMPRESSED_DIR="${ROOT}/target/vm-runtime-compressed" @@ -80,19 +90,19 @@ check_supervisor_cross_toolchain() { fi } -if [ ! -s "${COMPRESSED_DIR}/rootfs.tar.zst" ]; then +if [ ! -s "${OPENSHELL_VM_RUNTIME_COMPRESSED_DIR}/rootfs.tar.zst" ]; then check_supervisor_cross_toolchain echo "==> Building base VM rootfs tarball" mise run vm:rootfs -- --base fi -if [ "${OPENSHELL_VM_GPU:-}" = "true" ] && [ ! -s "${COMPRESSED_DIR}/rootfs-gpu.tar.zst" ]; then +if [ "${OPENSHELL_VM_GPU:-}" = "true" ] && [ ! -s "${OPENSHELL_VM_RUNTIME_COMPRESSED_DIR}/rootfs-gpu.tar.zst" ]; then check_supervisor_cross_toolchain echo "==> Building GPU VM rootfs tarball" mise run vm:rootfs -- --gpu fi -if [ ! -s "${COMPRESSED_DIR}/rootfs.tar.zst" ] || ! find "${COMPRESSED_DIR}" -maxdepth 1 -name 'libkrun*.zst' | grep -q .; then +if [ ! -s "${OPENSHELL_VM_RUNTIME_COMPRESSED_DIR}/rootfs.tar.zst" ] || ! find "${OPENSHELL_VM_RUNTIME_COMPRESSED_DIR}" -maxdepth 1 -name 'libkrun*.zst' | grep -q .; then echo "==> Preparing embedded VM runtime" mise run vm:setup fi @@ -119,17 +129,36 @@ export OPENSHELL_SSH_GATEWAY_PORT="${OPENSHELL_SSH_GATEWAY_PORT:-${SERVER_PORT}} export OPENSHELL_SSH_HANDSHAKE_SECRET="${OPENSHELL_SSH_HANDSHAKE_SECRET:-dev-vm-driver-secret}" export OPENSHELL_VM_DRIVER_STATE_DIR="${STATE_DIR}" +# Resolve the VM runtime directory (contains vmlinux, virtiofsd, etc.) +# so the child --internal-run-vm process can find it under sudo. +if [ -z "${OPENSHELL_VM_RUNTIME_DIR:-}" ]; then + _candidate="${HOME}/.local/share/openshell/vm-runtime/0.0.0" + if [ -n "${SUDO_USER:-}" ]; then + _sudo_home=$(getent passwd "${SUDO_USER}" | cut -d: -f6) + _candidate="${_sudo_home}/.local/share/openshell/vm-runtime/0.0.0" + fi + if [ -f "${_candidate}/vmlinux" ]; then + export OPENSHELL_VM_RUNTIME_DIR="${_candidate}" + fi +fi + echo "==> Registering gateway" echo " Name: ${GATEWAY_NAME}" echo " Endpoint: ${LOCAL_GATEWAY_ENDPOINT}" echo " Driver: ${OPENSHELL_DRIVER_DIR}/openshell-driver-vm" +# GPU passthrough requires root, but gateway config must be written to the +# real user's home directory — not /root/.config/openshell/. +# Unset XDG_CONFIG_HOME so the CLI falls back to $HOME/.config (sudo -u +# sets HOME correctly but may inherit XDG_CONFIG_HOME from the root env). if [ -n "${SUDO_USER:-}" ]; then - sudo -u "${SUDO_USER}" "${CLI_BIN}" gateway destroy --name "${GATEWAY_NAME}" 2>/dev/null || true - sudo -u "${SUDO_USER}" "${CLI_BIN}" gateway add --name "${GATEWAY_NAME}" "${LOCAL_GATEWAY_ENDPOINT}" + sudo -u "${SUDO_USER}" env -u XDG_CONFIG_HOME "PATH=${PATH}" "${CLI_BIN}" gateway destroy --name "${GATEWAY_NAME}" 2>/dev/null || true + sudo -u "${SUDO_USER}" env -u XDG_CONFIG_HOME "PATH=${PATH}" "${CLI_BIN}" gateway add --name "${GATEWAY_NAME}" "${LOCAL_GATEWAY_ENDPOINT}" + sudo -u "${SUDO_USER}" env -u XDG_CONFIG_HOME "PATH=${PATH}" "${CLI_BIN}" gateway select "${GATEWAY_NAME}" else "${CLI_BIN}" gateway destroy --name "${GATEWAY_NAME}" 2>/dev/null || true "${CLI_BIN}" gateway add --name "${GATEWAY_NAME}" "${LOCAL_GATEWAY_ENDPOINT}" + "${CLI_BIN}" gateway select "${GATEWAY_NAME}" fi echo "==> Starting OpenShell server with VM compute driver" diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 34ee80bb5..262859c33 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1567,6 +1567,68 @@ mod baseline_tests { } } +/// Returns `true` if the error is transient and worth retrying. +/// +/// Walks the `miette::Report` error chain looking for a `tonic::Status`. If +/// found, only the gRPC codes that represent transient failures are retryable. +/// If no `tonic::Status` is present (e.g. a raw connection error), assume the +/// failure is transient. +fn is_retryable_error(err: &miette::Report) -> bool { + let mut source: Option<&dyn std::error::Error> = Some(err.as_ref()); + while let Some(e) = source { + if let Some(status) = e.downcast_ref::() { + return matches!( + status.code(), + tonic::Code::Unavailable + | tonic::Code::DeadlineExceeded + | tonic::Code::ResourceExhausted + | tonic::Code::Aborted + | tonic::Code::Internal + | tonic::Code::Unknown + ); + } + source = e.source(); + } + true +} + +/// Retry a gRPC operation with exponential backoff (capped at 4 s). +/// +/// Non-transient gRPC errors (e.g. `NOT_FOUND`, `INVALID_ARGUMENT`, +/// `PERMISSION_DENIED`) are returned immediately without retrying. +async fn grpc_retry(op_name: &str, f: F) -> miette::Result +where + F: Fn() -> Fut, + Fut: std::future::Future>, +{ + let mut last_err = None; + for attempt in 1..=5u32 { + match f().await { + Ok(val) => return Ok(val), + Err(e) => { + if !is_retryable_error(&e) { + return Err(e); + } + if attempt < 5 { + warn!( + attempt, + max_attempts = 5, + error = %e, + "{op_name} failed, retrying" + ); + let backoff = Duration::from_secs((1u64 << (attempt - 1)).min(4)); + tokio::time::sleep(backoff).await; + } + last_err = Some(e); + } + } + } + Err(miette::miette!( + "{op_name} failed after 5 attempts: {}", + last_err.expect("loop executed at least once") + )) +} + /// Load sandbox policy from local files or gRPC. /// /// Priority: @@ -1627,7 +1689,8 @@ async fn load_policy( endpoint = %endpoint, "Fetching sandbox policy via gRPC" ); - let proto_policy = grpc_client::fetch_policy(endpoint, id).await?; + let proto_policy = + grpc_retry("Policy fetch", || grpc_client::fetch_policy(endpoint, id)).await?; let mut proto_policy = match proto_policy { Some(p) => p, @@ -1656,7 +1719,10 @@ async fn load_policy( // Sync and re-fetch over a single connection to avoid extra // TLS handshakes. - grpc_client::discover_and_sync_policy(endpoint, id, sandbox, &discovered).await? + grpc_retry("Policy discovery sync", || { + grpc_client::discover_and_sync_policy(endpoint, id, sandbox, &discovered) + }) + .await? } }; diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 7b56ad9d4..a218d18a8 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1517,6 +1517,7 @@ impl ComputeDriver for NoopTestDriver { driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), supports_gpu: false, + gpu_count: 0, }, )) } diff --git a/crates/openshell-vfio/src/lib.rs b/crates/openshell-vfio/src/lib.rs index 74a3ac38f..e7276b301 100644 --- a/crates/openshell-vfio/src/lib.rs +++ b/crates/openshell-vfio/src/lib.rs @@ -7,13 +7,26 @@ //! the VFIO subsystem. All sysfs access goes through [`SysfsRoot`] so the //! entire stack is testable without root or real hardware. +use std::collections::HashMap; use std::fs; use std::path::{Path, PathBuf}; +use std::sync::Mutex; +use std::time::Duration; const NVIDIA_VENDOR_ID: &str = "0x10de"; const GPU_CLASS_DISPLAY_VGA: &str = "0x030000"; const GPU_CLASS_DISPLAY_3D: u32 = 0x0302; +const VFIO_BIND_POLL_INTERVAL: Duration = Duration::from_millis(100); +const VFIO_BIND_MAX_POLL_ATTEMPTS: u32 = 20; + +/// Reference counter for vendor:device ID registrations in the vfio-pci +/// match table. Multiple GPUs may share the same vendor:device pair (e.g., +/// two A100s). We only write to the kernel's `new_id`/`remove_id` sysfs +/// files when the first GPU registers or the last GPU deregisters an ID. +static VFIO_ID_REFCOUNTS: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| Mutex::new(HashMap::new())); + // --------------------------------------------------------------------------- // Errors // --------------------------------------------------------------------------- @@ -88,6 +101,14 @@ impl SysfsRoot { self.base.join("bus/pci/drivers_probe") } + pub fn vfio_pci_new_id(&self) -> PathBuf { + self.base.join("bus/pci/drivers/vfio-pci/new_id") + } + + pub fn vfio_pci_remove_id(&self) -> PathBuf { + self.base.join("bus/pci/drivers/vfio-pci/remove_id") + } + pub fn iommu_group(&self, bdf: &str) -> Result { let link = self.pci_device(bdf).join("iommu_group"); let target = fs::read_link(&link).map_err(|_| VfioError::NoIommuGroup { @@ -152,6 +173,10 @@ pub struct GpuBindGuard { companion_bdfs: Vec, sysfs: SysfsRoot, disarmed: bool, + /// Cached "VVVV DDDD" string captured at bind time so that + /// deregistration from vfio-pci's match table succeeds even if the + /// device's sysfs entries have disappeared (e.g. physical removal). + vfio_id: Option, } impl GpuBindGuard { @@ -165,18 +190,47 @@ impl GpuBindGuard { } } +impl GpuBindGuard { + /// Deregister the cached vfio-pci match-table entry, then restore the + /// device to its host driver. + /// + /// Using the cached ID avoids re-reading vendor/device from sysfs, + /// which would fail if the GPU has been physically removed. + fn restore_with_cached_id(&self) { + if let Some(ref id_str) = self.vfio_id { + deregister_vfio_id_by_value(&self.sysfs, id_str); + } + + for peer in &self.companion_bdfs { + if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, peer) { + tracing::error!(bdf = %peer, error = %err, "failed to restore companion device to host driver"); + } + } + + if let Err(err) = + restore_gpu_to_host_driver_ex(&self.sysfs, &self.bdf, self.vfio_id.is_some()) + { + tracing::error!(bdf = %self.bdf, error = %err, "failed to restore GPU to host driver"); + } + } +} + impl Drop for GpuBindGuard { fn drop(&mut self) { if self.disarmed { return; } - for peer in &self.companion_bdfs { - if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, peer) { - tracing::error!(bdf = %peer, error = %err, "failed to restore companion device to host driver on drop"); + if self.vfio_id.is_some() { + self.restore_with_cached_id(); + } else { + for peer in &self.companion_bdfs { + if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, peer) { + tracing::error!(bdf = %peer, error = %err, "failed to restore companion device to host driver on drop"); + } + } + if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, &self.bdf) { + tracing::error!(bdf = %self.bdf, error = %err, "failed to restore GPU to host driver on drop"); } - } - if let Err(err) = restore_gpu_to_host_driver(&self.sysfs, &self.bdf) { - tracing::error!(bdf = %self.bdf, error = %err, "failed to restore GPU to host driver on drop"); } } } @@ -364,6 +418,127 @@ pub fn probe_host_nvidia_vfio_readiness(sysfs: &SysfsRoot) -> Vec { // Bind / unbind // --------------------------------------------------------------------------- +/// Read vendor and device IDs from sysfs and format as `"VVVV DDDD"` (no `0x` prefix). +fn vfio_id_string(sysfs: &SysfsRoot, bdf: &str) -> Option { + let dev_dir = sysfs.pci_device(bdf); + let vendor = read_sysfs_trimmed(&dev_dir.join("vendor")).ok()?; + let device = read_sysfs_trimmed(&dev_dir.join("device")).ok()?; + let vendor_hex = vendor.strip_prefix("0x").unwrap_or(&vendor); + let device_hex = device.strip_prefix("0x").unwrap_or(&device); + Some(format!("{vendor_hex} {device_hex}")) +} + +/// Best-effort registration of a device's vendor:device ID with `vfio-pci`. +/// +/// Some kernel configurations require the ID to be pre-registered in +/// `/sys/bus/pci/drivers/vfio-pci/new_id` before `drivers_probe` will +/// bind the device, even when `driver_override` is set. Writing an +/// already-registered ID returns `EEXIST`, which we silently ignore. +fn register_vfio_new_id(sysfs: &SysfsRoot, bdf: &str) { + let Some(id_str) = vfio_id_string(sysfs, bdf) else { + return; + }; + + let should_write = { + let mut map = VFIO_ID_REFCOUNTS.lock().unwrap(); + let count = map.entry(id_str.clone()).or_insert(0); + *count += 1; + *count == 1 + }; + + if !should_write { + tracing::debug!( + bdf, id = %id_str, + "vfio-pci new_id already registered by another GPU, refcount incremented" + ); + return; + } + + let new_id_path = sysfs.vfio_pci_new_id(); + match write_sysfs(&new_id_path, &id_str) { + Ok(()) => { + tracing::debug!(bdf, id = %id_str, "registered vfio-pci new_id"); + } + Err(_) => { + tracing::debug!( + bdf, id = %id_str, + "vfio-pci new_id write skipped (already registered or driver not loaded)" + ); + } + } +} + +/// Best-effort deregistration of a device's vendor:device ID from `vfio-pci`. +/// +/// Reverses the effect of [`register_vfio_new_id`] by writing to +/// `/sys/bus/pci/drivers/vfio-pci/remove_id`. This prevents vfio-pci +/// from winning the probe race against the host driver when +/// `drivers_probe` runs during restore. +/// +/// `ENODEV` is silently ignored (the ID may never have been registered +/// or was already removed). +fn deregister_vfio_new_id(sysfs: &SysfsRoot, bdf: &str) { + let Some(id_str) = vfio_id_string(sysfs, bdf) else { + return; + }; + + let should_write = { + let mut map = VFIO_ID_REFCOUNTS.lock().unwrap(); + match map.get_mut(&id_str) { + Some(count) if *count > 1 => { + *count -= 1; + false + } + Some(_) => { + map.remove(&id_str); + true + } + None => true, + } + }; + + if !should_write { + tracing::debug!( + bdf, id = %id_str, + "vfio-pci remove_id skipped (other GPUs still using this ID)" + ); + return; + } + + let remove_id_path = sysfs.vfio_pci_remove_id(); + match write_sysfs(&remove_id_path, &id_str) { + Ok(()) => { + tracing::debug!(bdf, id = %id_str, "deregistered vfio-pci new_id"); + } + Err(_) => { + tracing::debug!( + bdf, id = %id_str, + "vfio-pci remove_id write skipped (not registered or already removed)" + ); + } + } +} + +/// Best-effort deregistration using a pre-captured ID string. +/// +/// Unlike [`deregister_vfio_new_id`], this does not read vendor/device +/// from sysfs at call time, making it reliable even when the device has +/// been physically removed or sysfs is otherwise inaccessible. +fn deregister_vfio_id_by_value(sysfs: &SysfsRoot, id_str: &str) { + let remove_id_path = sysfs.vfio_pci_remove_id(); + match write_sysfs(&remove_id_path, id_str) { + Ok(()) => { + tracing::debug!(id = %id_str, "deregistered vfio-pci new_id (by cached value)"); + } + Err(_) => { + tracing::debug!( + id = %id_str, + "vfio-pci remove_id write skipped (not registered or already removed)" + ); + } + } +} + /// Bind a single PCI device to `vfio-pci`. Skips devices already bound. fn bind_device_to_vfio(sysfs: &SysfsRoot, bdf: &str) -> Result { if let Some(drv) = current_driver_name(sysfs, bdf) { @@ -378,31 +553,50 @@ fn bind_device_to_vfio(sysfs: &SysfsRoot, bdf: &str) -> Result tracing::info!(bdf, driver = %drv, "unbound device from current driver"); } + register_vfio_new_id(sysfs, bdf); + let override_path = sysfs.pci_device(bdf).join("driver_override"); - write_sysfs(&override_path, "vfio-pci").map_err(|e| VfioError::BindFailed { - bdf: bdf.to_string(), - reason: format!("driver_override: {e}"), - })?; + if let Err(e) = write_sysfs(&override_path, "vfio-pci") { + deregister_vfio_new_id(sysfs, bdf); + return Err(VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!("driver_override: {e}"), + }); + } - write_sysfs(&sysfs.drivers_probe(), bdf).map_err(|e| VfioError::BindFailed { - bdf: bdf.to_string(), - reason: format!("drivers_probe: {e}"), - })?; + if let Err(e) = write_sysfs(&sysfs.drivers_probe(), bdf) { + deregister_vfio_new_id(sysfs, bdf); + return Err(VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!("drivers_probe: {e}"), + }); + } - match current_driver_name(sysfs, bdf) { - Some(ref drv) if drv == "vfio-pci" => {} - other => { - return Err(VfioError::BindFailed { - bdf: bdf.to_string(), - reason: format!( - "after probe, driver is {:?} instead of vfio-pci", - other.as_deref().unwrap_or("") - ), - }); + if matches!(current_driver_name(sysfs, bdf).as_deref(), Some("vfio-pci")) { + return Ok(true); + } + + // The kernel processes drivers_probe asynchronously on some systems; + // poll briefly to let the driver attach before declaring failure. + for _ in 0..VFIO_BIND_MAX_POLL_ATTEMPTS { + std::thread::sleep(VFIO_BIND_POLL_INTERVAL); + if matches!(current_driver_name(sysfs, bdf).as_deref(), Some("vfio-pci")) { + tracing::debug!(bdf, "vfio-pci binding confirmed after polling"); + return Ok(true); } } - Ok(true) + deregister_vfio_new_id(sysfs, bdf); + Err(VfioError::BindFailed { + bdf: bdf.to_string(), + reason: format!( + "after drivers_probe with {}ms polling, driver is {:?} instead of vfio-pci", + VFIO_BIND_MAX_POLL_ATTEMPTS as u64 * VFIO_BIND_POLL_INTERVAL.as_millis() as u64, + current_driver_name(sysfs, bdf) + .as_deref() + .unwrap_or("") + ), + }) } /// Bind a GPU to `vfio-pci`, returning an RAII guard that restores it on drop. @@ -479,18 +673,41 @@ pub fn prepare_gpu_for_passthrough( } } + let vfio_id = vfio_id_string(sysfs, bdf); + Ok(GpuBindGuard { bdf: bdf.to_string(), companion_bdfs: bound_companions, sysfs: sysfs.clone(), disarmed: false, + vfio_id, }) } /// Restore a GPU from `vfio-pci` back to the host's default driver. fn restore_gpu_to_host_driver(sysfs: &SysfsRoot, bdf: &str) -> Result<(), VfioError> { + restore_gpu_to_host_driver_ex(sysfs, bdf, false) +} + +/// Inner restore implementation. +/// +/// When `skip_deregister` is `true` the caller has already removed the +/// device's vendor:device ID from vfio-pci's match table (e.g. via a +/// cached value), so we skip the sysfs-based deregistration. +fn restore_gpu_to_host_driver_ex( + sysfs: &SysfsRoot, + bdf: &str, + skip_deregister: bool, +) -> Result<(), VfioError> { let dev_dir = sysfs.pci_device(bdf); + if !skip_deregister { + // Deregister the device ID from vfio-pci's match table before + // unbind+reprobe. Without this, drivers_probe re-binds to vfio-pci + // via the still-registered new_id entry. + deregister_vfio_new_id(sysfs, bdf); + } + let unbind_path = dev_dir.join("driver/unbind"); if unbind_path.exists() { write_sysfs(&unbind_path, bdf).map_err(|e| VfioError::UnbindFailed { @@ -527,7 +744,9 @@ fn restore_gpu_to_host_driver(sysfs: &SysfsRoot, bdf: &str) -> Result<(), VfioEr /// /// Loads persisted state, checks each GPU, and restores any that are still /// bound to `vfio-pci`. Returns the list of BDFs that were restored. -/// Removes the state file after reconciliation. +/// Removes the state file only when all bindings are resolved; rewrites it +/// with the remaining entries when some restorations fail so they can be +/// retried on the next process start. pub fn reconcile_stale_bindings(sysfs: &SysfsRoot, state_path: &Path) -> Vec { let state = match GpuBindState::load(state_path) { Ok(s) => s, @@ -537,7 +756,14 @@ pub fn reconcile_stale_bindings(sysfs: &SysfsRoot, state_path: &Path) -> Vec Vec Vec Vec { + if let Err(err) = fs::write(state_path, json) { + tracing::error!(%err, path = %state_path.display(), "failed to persist remaining stale bindings"); + } else { + tracing::warn!( + count = remaining.bindings.len(), + "some GPU bindings could not be restored; state preserved for retry" + ); + } + } + Err(err) => { + tracing::error!(%err, "failed to serialize remaining stale bindings"); + } + } } restored @@ -627,6 +877,13 @@ mod tests { symlink(&dev, group_devices_dir.join(bdf)).unwrap(); } + /// Remove a specific vendor:device key from the global refcount map. + /// Used by tests to clean up their own entries without disturbing + /// parallel tests that hold refcounts for different device IDs. + fn clear_vfio_id_refcount(id: &str) { + VFIO_ID_REFCOUNTS.lock().unwrap().remove(id); + } + // -- validate_bdf ------------------------------------------------------- #[test] @@ -778,6 +1035,14 @@ mod tests { sysfs.drivers_probe(), PathBuf::from("/sys/bus/pci/drivers_probe") ); + assert_eq!( + sysfs.vfio_pci_new_id(), + PathBuf::from("/sys/bus/pci/drivers/vfio-pci/new_id") + ); + assert_eq!( + sysfs.vfio_pci_remove_id(), + PathBuf::from("/sys/bus/pci/drivers/vfio-pci/remove_id") + ); let custom = SysfsRoot::new("/tmp/test-sys"); assert_eq!( @@ -863,6 +1128,162 @@ mod tests { assert_eq!(devices[0], "0000:2d:00.0"); } + // -- register_vfio_new_id ----------------------------------------------- + + #[test] + fn test_register_vfio_new_id_writes_vendor_device() { + clear_vfio_id_refcount("10de 26b3"); + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b3", + "0x030000", + 42, + ); + + let new_id_path = sysfs.vfio_pci_new_id(); + fs::create_dir_all(new_id_path.parent().unwrap()).unwrap(); + fs::write(&new_id_path, "").unwrap(); + + register_vfio_new_id(&sysfs, "0000:2d:00.0"); + + let written = fs::read_to_string(&new_id_path).unwrap(); + assert_eq!(written, "10de 26b3"); + } + + #[test] + fn test_register_vfio_new_id_ignores_missing_new_id_file() { + clear_vfio_id_refcount("10de 26b4"); + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b4", + "0x030000", + 42, + ); + + // Don't create the new_id file — should not panic or error + register_vfio_new_id(&sysfs, "0000:2d:00.0"); + } + + // -- deregister_vfio_new_id --------------------------------------------- + + #[test] + fn test_deregister_vfio_new_id_writes_vendor_device() { + clear_vfio_id_refcount("10de 26b5"); + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b5", + "0x030000", + 42, + ); + + let remove_id_path = sysfs.vfio_pci_remove_id(); + fs::create_dir_all(remove_id_path.parent().unwrap()).unwrap(); + fs::write(&remove_id_path, "").unwrap(); + + deregister_vfio_new_id(&sysfs, "0000:2d:00.0"); + + let written = fs::read_to_string(&remove_id_path).unwrap(); + assert_eq!(written, "10de 26b5"); + } + + #[test] + fn test_deregister_vfio_new_id_ignores_missing_remove_id_file() { + clear_vfio_id_refcount("10de 26b6"); + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b6", + "0x030000", + 42, + ); + + deregister_vfio_new_id(&sysfs, "0000:2d:00.0"); + } + + // -- refcount safety ---------------------------------------------------- + + #[test] + fn test_register_deregister_refcount() { + clear_vfio_id_refcount("10de 26b8"); + let (tmp, sysfs) = setup_mock_sysfs(); + + // Two GPUs with the same vendor:device ID (e.g., two A100s). + // Uses 0x26b8 — unique to this test to avoid parallel interference. + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b8", + "0x030000", + 42, + ); + create_pci_device( + &sysfs, + tmp.path(), + "0000:3b:00.0", + "0x10de", + "0x26b8", + "0x030200", + 43, + ); + + let new_id_path = sysfs.vfio_pci_new_id(); + fs::create_dir_all(new_id_path.parent().unwrap()).unwrap(); + fs::write(&new_id_path, "").unwrap(); + + let remove_id_path = sysfs.vfio_pci_remove_id(); + fs::write(&remove_id_path, "").unwrap(); + + // Register the same vendor:device for two different BDFs + register_vfio_new_id(&sysfs, "0000:2d:00.0"); + let written = fs::read_to_string(&new_id_path).unwrap(); + assert_eq!( + written, "10de 26b8", + "first register should write to new_id" + ); + + // Clear the file to detect whether the second register writes + fs::write(&new_id_path, "").unwrap(); + register_vfio_new_id(&sysfs, "0000:3b:00.0"); + let written = fs::read_to_string(&new_id_path).unwrap(); + assert_eq!( + written, "", + "second register should NOT write to new_id (refcount > 1)" + ); + + // Deregister once — should NOT write to remove_id (one GPU still using it) + deregister_vfio_new_id(&sysfs, "0000:2d:00.0"); + let written = fs::read_to_string(&remove_id_path).unwrap(); + assert_eq!( + written, "", + "first deregister should NOT write to remove_id" + ); + + // Deregister again — should write to remove_id (last user) + deregister_vfio_new_id(&sysfs, "0000:3b:00.0"); + let written = fs::read_to_string(&remove_id_path).unwrap(); + assert_eq!( + written, "10de 26b8", + "second deregister SHOULD write to remove_id" + ); + } + // -- companion binding -------------------------------------------------- /// Helper to create a fake driver symlink for a mock PCI device. @@ -960,6 +1381,7 @@ mod tests { #[test] fn test_guard_drop_restores_companions() { + clear_vfio_id_refcount("10de 2684"); let (tmp, sysfs) = setup_mock_sysfs(); create_pci_device( &sysfs, @@ -1004,6 +1426,7 @@ mod tests { companion_bdfs: vec!["0000:2d:00.1".to_string()], sysfs: sysfs.clone(), disarmed: false, + vfio_id: None, }; // guard drops here — should attempt restore on both devices } @@ -1051,6 +1474,7 @@ mod tests { companion_bdfs: vec![], sysfs: sysfs.clone(), disarmed: false, + vfio_id: None, }; guard.disarm(); @@ -1060,10 +1484,51 @@ mod tests { assert_eq!(override_val, "vfio-pci"); } + // -- restore writes remove_id ------------------------------------------- + + #[test] + fn test_restore_gpu_deregisters_new_id_before_reprobe() { + clear_vfio_id_refcount("10de 26b7"); + let (tmp, sysfs) = setup_mock_sysfs(); + create_pci_device( + &sysfs, + tmp.path(), + "0000:2d:00.0", + "0x10de", + "0x26b7", + "0x030000", + 42, + ); + + let probe = sysfs.drivers_probe(); + fs::create_dir_all(probe.parent().unwrap()).unwrap(); + fs::write(&probe, "").unwrap(); + + let remove_id_path = sysfs.vfio_pci_remove_id(); + fs::create_dir_all(remove_id_path.parent().unwrap()).unwrap(); + fs::write(&remove_id_path, "").unwrap(); + + set_mock_driver(&sysfs, "0000:2d:00.0", "vfio-pci"); + fs::write( + sysfs.pci_device("0000:2d:00.0").join("driver_override"), + "vfio-pci", + ) + .unwrap(); + + restore_gpu_to_host_driver(&sysfs, "0000:2d:00.0").unwrap(); + + let written = fs::read_to_string(&remove_id_path).unwrap(); + assert_eq!( + written, "10de 26b7", + "remove_id should be written during restore" + ); + } + // -- reconcile_stale_bindings ------------------------------------------- #[test] fn test_reconcile_clears_stale_driver_override_when_not_on_vfio() { + clear_vfio_id_refcount("10de 2684"); let (tmp, sysfs) = setup_mock_sysfs(); create_pci_device( &sysfs, diff --git a/mise.lock b/mise.lock index d5d110bcc..2a6b7e02a 100644 --- a/mise.lock +++ b/mise.lock @@ -35,16 +35,36 @@ checksum = "sha256:afe92510c467f952a009b994f2d998ff8f9dd266dc26eca55d14a0dd46fec url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_linux_arm64.tar.gz" url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658323" +[tools."github:anchore/syft"."platforms.linux-arm64-musl"] +checksum = "sha256:afe92510c467f952a009b994f2d998ff8f9dd266dc26eca55d14a0dd46fec7f2" +url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_linux_arm64.tar.gz" +url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658323" + [tools."github:anchore/syft"."platforms.linux-x64"] checksum = "sha256:7b98251d2d08926bb5d4639b56b1f0996a58ef6667c5830e3fe3cd3ad5f4214a" url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_linux_amd64.tar.gz" url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658325" +[tools."github:anchore/syft"."platforms.linux-x64-musl"] +checksum = "sha256:7b98251d2d08926bb5d4639b56b1f0996a58ef6667c5830e3fe3cd3ad5f4214a" +url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_linux_amd64.tar.gz" +url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658325" + [tools."github:anchore/syft"."platforms.macos-arm64"] checksum = "sha256:3640e2181c8be7a56377f3c96e520d5380c924dbafd115ee3c8d45fcbc89cac2" url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_darwin_arm64.tar.gz" url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658324" +[tools."github:anchore/syft"."platforms.macos-x64"] +checksum = "sha256:08fd18f55037f999f50b2c2256a9285f0146978a0b16cdc58662ecdc85d0e3c0" +url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_darwin_amd64.tar.gz" +url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658329" + +[tools."github:anchore/syft"."platforms.windows-x64"] +checksum = "sha256:c51695d171c61460369dabdd5c71b8f350ef8618466818356a30808d7105c710" +url = "https://github.com/anchore/syft/releases/download/v1.43.0/syft_1.43.0_windows_amd64.zip" +url_api = "https://api.github.com/repos/anchore/syft/releases/assets/402658321" + [[tools."github:mozilla/sccache"]] version = "0.14.0" backend = "github:mozilla/sccache" @@ -54,16 +74,36 @@ checksum = "sha256:62a6c942c47c93333bc0174704800cef7edfa0416d08e1356c1d3e39f0b46 url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-aarch64-unknown-linux-musl.tar.gz" url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136010" +[tools."github:mozilla/sccache"."platforms.linux-arm64-musl"] +checksum = "sha256:62a6c942c47c93333bc0174704800cef7edfa0416d08e1356c1d3e39f0b462f2" +url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-aarch64-unknown-linux-musl.tar.gz" +url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136010" + [tools."github:mozilla/sccache"."platforms.linux-x64"] checksum = "sha256:8424b38cda4ecce616a1557d81328f3d7c96503a171eab79942fad618b42af44" url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-x86_64-unknown-linux-musl.tar.gz" url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136108" +[tools."github:mozilla/sccache"."platforms.linux-x64-musl"] +checksum = "sha256:8424b38cda4ecce616a1557d81328f3d7c96503a171eab79942fad618b42af44" +url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-x86_64-unknown-linux-musl.tar.gz" +url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136108" + [tools."github:mozilla/sccache"."platforms.macos-arm64"] checksum = "sha256:a781e8018260ab128e7690d8497736fa231b6ca895d57131d5b5b966ca987594" url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-aarch64-apple-darwin.tar.gz" url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353135984" +[tools."github:mozilla/sccache"."platforms.macos-x64"] +checksum = "sha256:f86c5ecf9b9a1aee53022601725c5cea0e1d9318d80a8233017101063936ab62" +url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-x86_64-apple-darwin.tar.gz" +url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136084" + +[tools."github:mozilla/sccache"."platforms.windows-x64"] +checksum = "sha256:74a3ffd4207e8e0e62af7747bd03b42deab0f6dabc7ef0a8cdd950f83f1037c8" +url = "https://github.com/mozilla/sccache/releases/download/v0.14.0/sccache-v0.14.0-x86_64-pc-windows-msvc.zip" +url_api = "https://api.github.com/repos/mozilla/sccache/releases/assets/353136140" + [[tools.helm]] version = "4.1.4" backend = "aqua:helm/helm" @@ -88,14 +128,30 @@ backend = "aqua:kubernetes/kubernetes/kubectl" checksum = "sha256:6a5a4cc4e396d7626a7a693a3044b51c75520f81db30fe6816c2554e53be336f" url = "https://dl.k8s.io/v1.35.4/bin/linux/arm64/kubectl" +[tools.kubectl."platforms.linux-arm64-musl"] +checksum = "sha256:6a5a4cc4e396d7626a7a693a3044b51c75520f81db30fe6816c2554e53be336f" +url = "https://dl.k8s.io/v1.35.4/bin/linux/arm64/kubectl" + [tools.kubectl."platforms.linux-x64"] checksum = "sha256:b529430df69a688fd61b64ad2299edb5fd71cb58be2a4779dba624c7d3510efd" url = "https://dl.k8s.io/v1.35.4/bin/linux/amd64/kubectl" +[tools.kubectl."platforms.linux-x64-musl"] +checksum = "sha256:b529430df69a688fd61b64ad2299edb5fd71cb58be2a4779dba624c7d3510efd" +url = "https://dl.k8s.io/v1.35.4/bin/linux/amd64/kubectl" + [tools.kubectl."platforms.macos-arm64"] checksum = "sha256:ec644a2473b64b486987f695dfb1867963ce6d42d267b86e944585a546f92b5d" url = "https://dl.k8s.io/v1.35.4/bin/darwin/arm64/kubectl" +[tools.kubectl."platforms.macos-x64"] +checksum = "sha256:dddb01bddb96f78e48e33105ccfa2feedff585a8b2e3b812f5d0f64c7403710a" +url = "https://dl.k8s.io/v1.35.4/bin/darwin/amd64/kubectl" + +[tools.kubectl."platforms.windows-x64"] +checksum = "sha256:d77d03309bd80de56dafe8ca59ff6f2076e2ed4ee61c6a94657a4b6e945210e6" +url = "https://dl.k8s.io/v1.35.4/bin/windows/amd64/kubectl.exe" + [[tools.node]] version = "24.15.0" backend = "core:node" @@ -104,14 +160,30 @@ backend = "core:node" checksum = "sha256:73afc234d558c24919875f51c2d1ea002a2ada4ea6f83601a383869fefa64eed" url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-linux-arm64.tar.gz" +[tools.node."platforms.linux-arm64-musl"] +checksum = "sha256:73afc234d558c24919875f51c2d1ea002a2ada4ea6f83601a383869fefa64eed" +url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-linux-arm64.tar.gz" + [tools.node."platforms.linux-x64"] checksum = "sha256:44836872d9aec49f1e6b52a9a922872db9a2b02d235a616a5681b6a85fec8d89" url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-linux-x64.tar.gz" +[tools.node."platforms.linux-x64-musl"] +checksum = "sha256:44836872d9aec49f1e6b52a9a922872db9a2b02d235a616a5681b6a85fec8d89" +url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-linux-x64.tar.gz" + [tools.node."platforms.macos-arm64"] checksum = "sha256:372331b969779ab5d15b949884fc6eaf88d5afe87bde8ba881d6400b9100ffc4" url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-darwin-arm64.tar.gz" +[tools.node."platforms.macos-x64"] +checksum = "sha256:ffd5ee293467927f3ee731a553eb88fd1f48cf74eebc2d74a6babe4af228673b" +url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-darwin-x64.tar.gz" + +[tools.node."platforms.windows-x64"] +checksum = "sha256:cc5149eabd53779ce1e7bdc5401643622d0c7e6800ade18928a767e940bb0e62" +url = "https://nodejs.org/dist/v24.15.0/node-v24.15.0-win-x64.zip" + [[tools."npm:markdownlint-cli2"]] version = "0.22.0" backend = "npm:markdownlint-cli2" @@ -144,16 +216,36 @@ checksum = "sha256:0556f1260a9a1fc83210dcecf9d4cbacf17eb4a684541c84798ffc8b4d618 url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz" provenance = "github-attestations" +[tools.python."platforms.linux-arm64-musl"] +checksum = "sha256:0556f1260a9a1fc83210dcecf9d4cbacf17eb4a684541c84798ffc8b4d618c35" +url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz" +provenance = "github-attestations" + [tools.python."platforms.linux-x64"] checksum = "sha256:13d3b6d15f4c3c1dd1955a3c81e06bdc5aef4cb5cb65076878374948be3b0412" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz" provenance = "github-attestations" +[tools.python."platforms.linux-x64-musl"] +checksum = "sha256:13d3b6d15f4c3c1dd1955a3c81e06bdc5aef4cb5cb65076878374948be3b0412" +url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz" +provenance = "github-attestations" + [tools.python."platforms.macos-arm64"] checksum = "sha256:874f9931ad40dcce38caf6f408aa7e10ec3d0dfce2184ba7af62a965a66b9cd9" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-aarch64-apple-darwin-install_only_stripped.tar.gz" provenance = "github-attestations" +[tools.python."platforms.macos-x64"] +checksum = "sha256:d34198cd856fa80ebf3aa821fe329a25fab66eeda44f72ac9576591282e31bb7" +url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-x86_64-apple-darwin-install_only_stripped.tar.gz" +provenance = "github-attestations" + +[tools.python."platforms.windows-x64"] +checksum = "sha256:b84dce293464cfd366ee792a3d5b42abe5174fc9cce733ba895b3ef467cb3161" +url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260414/cpython-3.13.13+20260414-x86_64-pc-windows-msvc-install_only_stripped.tar.gz" +provenance = "github-attestations" + [[tools.rust]] version = "stable" backend = "core:rust" @@ -167,16 +259,36 @@ checksum = "sha256:55bd1c1c10ec8b95a8c184f5e18b566703c6ab105f0fc118aaa4d748aabf2 url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-aarch64-unknown-linux-musl.tar.gz" provenance = "github-attestations" +[tools.uv."platforms.linux-arm64-musl"] +checksum = "sha256:55bd1c1c10ec8b95a8c184f5e18b566703c6ab105f0fc118aaa4d748aabf28e4" +url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-aarch64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" + [tools.uv."platforms.linux-x64"] checksum = "sha256:adccf40b5d1939a5e0093081ec2307ea24235adf7c2d96b122c561fa37711c46" url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-x86_64-unknown-linux-musl.tar.gz" provenance = "github-attestations" +[tools.uv."platforms.linux-x64-musl"] +checksum = "sha256:adccf40b5d1939a5e0093081ec2307ea24235adf7c2d96b122c561fa37711c46" +url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-x86_64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" + [tools.uv."platforms.macos-arm64"] checksum = "sha256:ae738b5661a900579ec621d3918c0ef17bdec0da2a8a6d8b161137cd15f25414" url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-aarch64-apple-darwin.tar.gz" provenance = "github-attestations" +[tools.uv."platforms.macos-x64"] +checksum = "sha256:17443e293f2ae407bb2d8d34b875ebfe0ae01cf1296de5647e69e7b2e2b428f0" +url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-x86_64-apple-darwin.tar.gz" +provenance = "github-attestations" + +[tools.uv."platforms.windows-x64"] +checksum = "sha256:4c1d55501869b3330d4aabf45ad6024ce2367e0f3af83344395702d272c22e88" +url = "https://github.com/astral-sh/uv/releases/download/0.10.12/uv-x86_64-pc-windows-msvc.zip" +provenance = "github-attestations" + [[tools.zig]] version = "0.14.1" backend = "core:zig" diff --git a/tasks/scripts/vm/build-rootfs-tarball.sh b/tasks/scripts/vm/build-rootfs-tarball.sh index 57b215aad..87abca27e 100755 --- a/tasks/scripts/vm/build-rootfs-tarball.sh +++ b/tasks/scripts/vm/build-rootfs-tarball.sh @@ -9,15 +9,20 @@ # 2. Compresses it to a zstd tarball for embedding # # Usage: -# ./build-rootfs-tarball.sh [--base] +# ./build-rootfs-tarball.sh [--base|--gpu|--gpu-cuda] # # Options: # --base Build a base rootfs (~200-300MB) without pre-loaded images. # First boot will be slower but binary size is much smaller. # Default: full rootfs with pre-loaded images (~2GB+). +# --gpu Build a GPU-augmented rootfs that layers kmod, nvidia kernel +# modules, and nvidia firmware on top of the base rootfs. +# Output: target/vm-runtime-compressed/rootfs-gpu.tar.zst +# --gpu-cuda Like --gpu but also includes CUDA driver libraries +# (libcuda.so, libnvidia-ptxjitcompiler.so) for CUDA workloads. # # The resulting tarball is placed at target/vm-runtime-compressed/rootfs.tar.zst -# for inclusion in the embedded binary build. +# (or rootfs-gpu.tar.zst for --gpu) for inclusion in the embedded binary build. set -euo pipefail @@ -28,19 +33,36 @@ ROOTFS_BUILD_DIR="${ROOT}/target/rootfs-build" OUTPUT_DIR="${ROOT}/target/vm-runtime-compressed" OUTPUT="${OUTPUT_DIR}/rootfs.tar.zst" +KERNEL_VERSION="6.12.76" +NVIDIA_MODULES_DIR="${ROOT}/target/libkrun-build/nvidia-modules" +NVIDIA_USERSPACE_DIR="${ROOT}/target/libkrun-build/nvidia-userspace" + # Parse arguments BASE_ONLY=false +GPU_BUILD=false +GPU_CUDA=false for arg in "$@"; do case "$arg" in --base) BASE_ONLY=true ;; + --gpu) + GPU_BUILD=true + ;; + --gpu-cuda) + GPU_CUDA=true + GPU_BUILD=true + ;; --help|-h) - echo "Usage: $0 [--base]" + echo "Usage: $0 [--base|--gpu|--gpu-cuda]" echo "" echo "Options:" - echo " --base Build base rootfs (~200-300MB) without pre-loaded images" - echo " First boot will be slower but binary size is much smaller" + echo " --base Build base rootfs (~200-300MB) without pre-loaded images" + echo " First boot will be slower but binary size is much smaller" + echo " --gpu Build GPU rootfs with kmod, nvidia modules, and firmware" + echo " Layers on top of base rootfs, output: rootfs-gpu.tar.zst" + echo " --gpu-cuda Like --gpu but also includes CUDA driver libraries" + echo " (libcuda.so, libnvidia-ptxjitcompiler.so)" exit 0 ;; *) @@ -51,6 +73,286 @@ for arg in "$@"; do esac done +if [ "$GPU_BUILD" = true ]; then + GPU_OUTPUT="${OUTPUT_DIR}/rootfs-gpu.tar.zst" + GPU_ROOTFS_DIR="${ROOT}/target/rootfs-gpu-build" + trap 'echo "ERROR: GPU rootfs build failed; cleaning up ${GPU_ROOTFS_DIR}" >&2; rm -rf "${GPU_ROOTFS_DIR}"' ERR + + echo "==> Building GPU rootfs for embedding" + echo " Build dir: ${GPU_ROOTFS_DIR}" + echo " Output: ${GPU_OUTPUT}" + echo "" + + # Build base rootfs first if it doesn't exist + if [ ! -d "${ROOTFS_BUILD_DIR}" ]; then + echo "==> Step 1/3: Base rootfs not found, building it first..." + "${ROOT}/crates/openshell-vm/scripts/build-rootfs.sh" --base "${ROOTFS_BUILD_DIR}" + echo "" + fi + + echo "==> Step 2/3: Layering GPU tools onto base rootfs..." + + rm -rf "${GPU_ROOTFS_DIR}" + cp -a "${ROOTFS_BUILD_DIR}" "${GPU_ROOTFS_DIR}" + + # --- kmod --- + KMOD_BIN="$(command -v kmod 2>/dev/null || true)" + if [ -z "${KMOD_BIN}" ]; then + echo "WARNING: kmod not found on host; skipping kmod installation" + else + echo " Installing kmod from ${KMOD_BIN}" + mkdir -p "${GPU_ROOTFS_DIR}/bin" + cp "${KMOD_BIN}" "${GPU_ROOTFS_DIR}/bin/kmod" + chmod 755 "${GPU_ROOTFS_DIR}/bin/kmod" + + # Copy shared libraries required by kmod (host and guest must share compatible glibc) + if command -v ldd &>/dev/null; then + mkdir -p "${GPU_ROOTFS_DIR}/lib" "${GPU_ROOTFS_DIR}/lib64" + ldd "${KMOD_BIN}" 2>/dev/null | while read -r line; do + lib_path="$(echo "${line}" | sed -n 's/.* => \(\/[^ ]*\).*/\1/p')" + if [ -n "${lib_path}" ] && [ -f "${lib_path}" ]; then + # Skip core system libraries that already exist in the base rootfs. + # The host glibc may be older and overwriting breaks rootfs binaries. + lib_basename="$(basename "${lib_path}")" + case "${lib_basename}" in + libc.so*|libm.so*|libpthread.so*|libdl.so*|librt.so*|ld-linux*) continue ;; + esac + lib_dir="$(dirname "${lib_path}")" + mkdir -p "${GPU_ROOTFS_DIR}${lib_dir}" + cp -Lf "${lib_path}" "${GPU_ROOTFS_DIR}${lib_path}" 2>/dev/null || true + fi + done + fi + + # Fix broken .so symlinks left by Docker export (e.g. libzstd.so.1.5.5 -> itself). + # These cause ELOOP when the dynamic linker resolves the SONAME chain. + # Use -xtype l to find symlinks whose targets are missing or circular. + find "${GPU_ROOTFS_DIR}" -xtype l -name '*.so*' 2>/dev/null | while read -r broken; do + sobase="$(basename "$broken" | sed 's/\.so.*/\.so/')" + host_real="$(find /usr/lib /lib -name "${sobase}*" -type f 2>/dev/null | head -1)" + if [ -n "$host_real" ]; then + rm -f "$broken" + cp -L "$host_real" "$broken" 2>/dev/null || true + fi + done || true + + mkdir -p "${GPU_ROOTFS_DIR}/usr/sbin" + for tool in modprobe insmod rmmod lsmod depmod; do + ln -sf ../../bin/kmod "${GPU_ROOTFS_DIR}/usr/sbin/${tool}" + done + echo " Created symlinks: modprobe insmod rmmod lsmod depmod -> ../../bin/kmod" + fi + + # --- nvidia kernel modules --- + MODULES_DST="${GPU_ROOTFS_DIR}/lib/modules/${KERNEL_VERSION}/kernel/drivers/video" + if [ -d "${NVIDIA_MODULES_DIR}" ]; then + ko_files=("${NVIDIA_MODULES_DIR}"/*.ko) + if [ -e "${ko_files[0]}" ]; then + mkdir -p "${MODULES_DST}" + cp "${NVIDIA_MODULES_DIR}"/*.ko "${MODULES_DST}/" + echo " Installed nvidia kernel modules into lib/modules/${KERNEL_VERSION}/kernel/drivers/video/" + ls -1 "${MODULES_DST}"/*.ko | xargs -I{} basename {} | sed 's/^/ /' + if command -v depmod &>/dev/null; then + depmod -b "${GPU_ROOTFS_DIR}" "${KERNEL_VERSION}" 2>/dev/null || true + echo " Generated modules.dep" + fi + else + echo "WARNING: ${NVIDIA_MODULES_DIR} exists but contains no .ko files" + fi + else + echo "WARNING: nvidia kernel modules not found at ${NVIDIA_MODULES_DIR}" + echo " GPU rootfs will not contain nvidia drivers" + fi + + # Determine the kernel module driver version so we can match firmware + userspace. + NV_DRIVER_VERSION="" + if command -v modinfo &>/dev/null && [ -f "${NVIDIA_MODULES_DIR}/nvidia.ko" ]; then + NV_DRIVER_VERSION="$(modinfo -F version "${NVIDIA_MODULES_DIR}/nvidia.ko" 2>/dev/null || true)" + fi + if [ -n "${NV_DRIVER_VERSION}" ]; then + echo " Kernel module driver version: ${NV_DRIVER_VERSION}" + fi + + # --- nvidia firmware (GSP) --- + # Prefer version-matched firmware from nvidia-firmware/ directory. + # Fall back to host /lib/firmware/nvidia if version-matched is unavailable. + rm -rf "${GPU_ROOTFS_DIR}/lib/firmware/nvidia" 2>/dev/null || true + NVIDIA_FW_MATCHED_DIR="${ROOT}/target/libkrun-build/nvidia-firmware/${NV_DRIVER_VERSION}" + FW_DST="${GPU_ROOTFS_DIR}/lib/firmware/nvidia/${NV_DRIVER_VERSION}" + if [ -n "${NV_DRIVER_VERSION}" ] && [ -d "${NVIDIA_FW_MATCHED_DIR}" ]; then + mkdir -p "${FW_DST}" + cp "${NVIDIA_FW_MATCHED_DIR}"/*.bin "${FW_DST}/" 2>/dev/null || true + echo " Installed nvidia firmware from ${NVIDIA_FW_MATCHED_DIR} (version-matched)" + else + HOST_FW_DIR="" + for candidate in /lib/firmware/nvidia /usr/lib/firmware/nvidia; do + if [ -d "${candidate}" ]; then + HOST_FW_DIR="${candidate}" + break + fi + done + if [ -n "${HOST_FW_DIR}" ]; then + mkdir -p "${GPU_ROOTFS_DIR}/lib/firmware/nvidia" + cp -r "${HOST_FW_DIR}"/* "${GPU_ROOTFS_DIR}/lib/firmware/nvidia/" 2>/dev/null || true + echo " Installed nvidia firmware from ${HOST_FW_DIR}" + if [ -n "${NV_DRIVER_VERSION}" ]; then + echo " WARNING: host firmware version may not match kernel module version ${NV_DRIVER_VERSION}" + fi + else + echo "WARNING: nvidia firmware not found" + echo " GPU guests may fail to initialize the GPU without GSP firmware" + fi + fi + + # --- nvidia userspace (nvidia-smi + NVML) --- + + # Remove any pre-existing nvidia userspace from the base rootfs to avoid + # version conflicts. The base image may ship nvidia-smi and libs from a + # different driver version than the kernel modules we're installing. + for search_dir in "${GPU_ROOTFS_DIR}/usr/lib/x86_64-linux-gnu" \ + "${GPU_ROOTFS_DIR}/usr/lib64" \ + "${GPU_ROOTFS_DIR}/usr/lib"; do + rm -f "${search_dir}"/libnvidia-ml.so* 2>/dev/null || true + rm -f "${search_dir}"/libcuda.so* 2>/dev/null || true + rm -f "${search_dir}"/libnvidia-ptxjitcompiler.so* 2>/dev/null || true + done + rm -f "${GPU_ROOTFS_DIR}/usr/bin/nvidia-smi" 2>/dev/null || true + echo " Cleaned pre-existing nvidia userspace from base rootfs" + + # Prefer pre-extracted version-matched userspace from nvidia-userspace/. + # Fall back to host binaries only if the pre-extracted ones don't exist. + if [ -f "${NVIDIA_USERSPACE_DIR}/nvidia-smi" ]; then + mkdir -p "${GPU_ROOTFS_DIR}/usr/bin" + cp "${NVIDIA_USERSPACE_DIR}/nvidia-smi" "${GPU_ROOTFS_DIR}/usr/bin/nvidia-smi" + chmod 755 "${GPU_ROOTFS_DIR}/usr/bin/nvidia-smi" + echo " Installed nvidia-smi from ${NVIDIA_USERSPACE_DIR} (version-matched)" + else + NV_SMI="$(command -v nvidia-smi 2>/dev/null || true)" + if [ -n "${NV_SMI}" ]; then + mkdir -p "${GPU_ROOTFS_DIR}/usr/bin" + cp "${NV_SMI}" "${GPU_ROOTFS_DIR}/usr/bin/nvidia-smi" + chmod 755 "${GPU_ROOTFS_DIR}/usr/bin/nvidia-smi" + echo " Installed nvidia-smi from host: ${NV_SMI}" + echo " WARNING: host nvidia-smi version may not match kernel module version ${NV_DRIVER_VERSION}" + else + echo "WARNING: nvidia-smi not found; GPU rootfs will use mknod fallback" + fi + fi + + # libnvidia-ml.so — required by nvidia-smi (dlopen'd at runtime) + if [ -f "${NVIDIA_USERSPACE_DIR}/libnvidia-ml.so.${NV_DRIVER_VERSION}" ]; then + NV_ML_REAL="${NVIDIA_USERSPACE_DIR}/libnvidia-ml.so.${NV_DRIVER_VERSION}" + NV_LIB_DEST="${GPU_ROOTFS_DIR}/usr/lib/x86_64-linux-gnu" + mkdir -p "${NV_LIB_DEST}" + cp "${NV_ML_REAL}" "${NV_LIB_DEST}/libnvidia-ml.so.${NV_DRIVER_VERSION}" + ln -sf "libnvidia-ml.so.${NV_DRIVER_VERSION}" "${NV_LIB_DEST}/libnvidia-ml.so.1" + ln -sf libnvidia-ml.so.1 "${NV_LIB_DEST}/libnvidia-ml.so" + echo " Installed libnvidia-ml.so.${NV_DRIVER_VERSION} (version-matched)" + else + NV_ML_REAL="" + for search_dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib; do + NV_ML_REAL="$(find "${search_dir}" -maxdepth 1 -name 'libnvidia-ml.so.*.*.*' -type f 2>/dev/null | head -1)" + [ -n "${NV_ML_REAL}" ] && break + done + if [ -n "${NV_ML_REAL}" ]; then + NV_LIB_DIR="$(dirname "${NV_ML_REAL}")" + mkdir -p "${GPU_ROOTFS_DIR}${NV_LIB_DIR}" + cp "${NV_ML_REAL}" "${GPU_ROOTFS_DIR}${NV_ML_REAL}" + ln -sf "$(basename "${NV_ML_REAL}")" "${GPU_ROOTFS_DIR}${NV_LIB_DIR}/libnvidia-ml.so.1" + ln -sf libnvidia-ml.so.1 "${GPU_ROOTFS_DIR}${NV_LIB_DIR}/libnvidia-ml.so" + echo " Installed libnvidia-ml.so from host: ${NV_ML_REAL}" + echo " WARNING: host library version may not match kernel module version ${NV_DRIVER_VERSION}" + else + echo "WARNING: libnvidia-ml.so not found; nvidia-smi may not work at runtime" + fi + fi + + # --- CUDA driver libraries (optional, via --gpu-cuda) --- + if [ "${GPU_CUDA}" = true ]; then + echo " Installing CUDA driver libraries..." + + # libcuda.so + if [ -f "${NVIDIA_USERSPACE_DIR}/libcuda.so.${NV_DRIVER_VERSION}" ]; then + NV_LIB_DEST="${GPU_ROOTFS_DIR}/usr/lib/x86_64-linux-gnu" + mkdir -p "${NV_LIB_DEST}" + cp "${NVIDIA_USERSPACE_DIR}/libcuda.so.${NV_DRIVER_VERSION}" "${NV_LIB_DEST}/" + ln -sf "libcuda.so.${NV_DRIVER_VERSION}" "${NV_LIB_DEST}/libcuda.so.1" + ln -sf libcuda.so.1 "${NV_LIB_DEST}/libcuda.so" + echo " Installed libcuda.so.${NV_DRIVER_VERSION} (version-matched)" + else + CUDA_REAL="" + for search_dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib; do + CUDA_REAL="$(find "${search_dir}" -maxdepth 1 -name 'libcuda.so.*.*.*' -type f 2>/dev/null | head -1)" + [ -n "${CUDA_REAL}" ] && break + done + if [ -n "${CUDA_REAL}" ]; then + CUDA_LIB_DIR="$(dirname "${CUDA_REAL}")" + mkdir -p "${GPU_ROOTFS_DIR}${CUDA_LIB_DIR}" + cp "${CUDA_REAL}" "${GPU_ROOTFS_DIR}${CUDA_REAL}" + ln -sf "$(basename "${CUDA_REAL}")" "${GPU_ROOTFS_DIR}${CUDA_LIB_DIR}/libcuda.so.1" + ln -sf libcuda.so.1 "${GPU_ROOTFS_DIR}${CUDA_LIB_DIR}/libcuda.so" + echo " Installed libcuda.so from host: ${CUDA_REAL}" + echo " WARNING: host library version may not match kernel module version ${NV_DRIVER_VERSION}" + else + echo "WARNING: libcuda.so not found; CUDA workloads will not work" + fi + fi + + # libnvidia-ptxjitcompiler.so + if [ -f "${NVIDIA_USERSPACE_DIR}/libnvidia-ptxjitcompiler.so.${NV_DRIVER_VERSION}" ]; then + NV_LIB_DEST="${GPU_ROOTFS_DIR}/usr/lib/x86_64-linux-gnu" + mkdir -p "${NV_LIB_DEST}" + cp "${NVIDIA_USERSPACE_DIR}/libnvidia-ptxjitcompiler.so.${NV_DRIVER_VERSION}" "${NV_LIB_DEST}/" + ln -sf "libnvidia-ptxjitcompiler.so.${NV_DRIVER_VERSION}" "${NV_LIB_DEST}/libnvidia-ptxjitcompiler.so.1" + ln -sf libnvidia-ptxjitcompiler.so.1 "${NV_LIB_DEST}/libnvidia-ptxjitcompiler.so" + echo " Installed libnvidia-ptxjitcompiler.so.${NV_DRIVER_VERSION} (version-matched)" + else + PTX_REAL="" + for search_dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib; do + PTX_REAL="$(find "${search_dir}" -maxdepth 1 -name 'libnvidia-ptxjitcompiler.so.*.*.*' -type f 2>/dev/null | head -1)" + [ -n "${PTX_REAL}" ] && break + done + if [ -n "${PTX_REAL}" ]; then + PTX_LIB_DIR="$(dirname "${PTX_REAL}")" + mkdir -p "${GPU_ROOTFS_DIR}${PTX_LIB_DIR}" + cp "${PTX_REAL}" "${GPU_ROOTFS_DIR}${PTX_REAL}" + ln -sf "$(basename "${PTX_REAL}")" "${GPU_ROOTFS_DIR}${PTX_LIB_DIR}/libnvidia-ptxjitcompiler.so.1" + ln -sf libnvidia-ptxjitcompiler.so.1 "${GPU_ROOTFS_DIR}${PTX_LIB_DIR}/libnvidia-ptxjitcompiler.so" + echo " Installed libnvidia-ptxjitcompiler.so from host: ${PTX_REAL}" + echo " WARNING: host library version may not match kernel module version ${NV_DRIVER_VERSION}" + else + echo "WARNING: libnvidia-ptxjitcompiler.so not found; PTX JIT may not work" + fi + fi + fi + + # Ensure nvidia library path is in ld.so.conf for dlopen resolution + mkdir -p "${GPU_ROOTFS_DIR}/etc/ld.so.conf.d" + echo "/usr/lib/x86_64-linux-gnu" > "${GPU_ROOTFS_DIR}/etc/ld.so.conf.d/nvidia.conf" + if command -v ldconfig &>/dev/null; then + ldconfig -r "${GPU_ROOTFS_DIR}" 2>/dev/null || true + fi + + echo "" + echo "==> Step 3/3: Compressing GPU rootfs to tarball..." + mkdir -p "${OUTPUT_DIR}" + rm -f "${GPU_OUTPUT}" + + echo " Uncompressed size: $(du -sh "${GPU_ROOTFS_DIR}" | cut -f1)" + echo " Compressing with zstd (level 3)..." + tar -C "${GPU_ROOTFS_DIR}" -cf - . | zstd -3 -T0 -o "${GPU_OUTPUT}" + + echo "" + echo "==> GPU rootfs tarball created successfully!" + echo " Output: ${GPU_OUTPUT}" + echo " Compressed: $(du -sh "${GPU_OUTPUT}" | cut -f1)" + echo " Type: gpu (kmod + nvidia modules + firmware)" + echo "" + echo "Next step: mise run vm:build" + trap - ERR + exit 0 +fi + # Check if container engine is running if ! ce info &>/dev/null; then echo "Error: container engine is not running" >&2 @@ -65,7 +367,6 @@ if [ "$BASE_ONLY" = true ]; then echo " Mode: base (no pre-loaded images, ~200-300MB)" echo "" - # Build base rootfs echo "==> Step 1/2: Building base rootfs..." "${ROOT}/crates/openshell-vm/scripts/build-rootfs.sh" --base "${ROOTFS_BUILD_DIR}" else @@ -75,29 +376,23 @@ else echo " Mode: full (pre-loaded images, pre-initialized, ~2GB+)" echo "" - # Build full rootfs echo "==> Step 1/2: Building full rootfs (this may take 10-15 minutes)..." "${ROOT}/crates/openshell-vm/scripts/build-rootfs.sh" "${ROOTFS_BUILD_DIR}" fi -# Compress to tarball echo "" echo "==> Step 2/2: Compressing rootfs to tarball..." mkdir -p "${OUTPUT_DIR}" -# Remove existing tarball if present rm -f "${OUTPUT}" -# Get uncompressed size for display echo " Uncompressed size: $(du -sh "${ROOTFS_BUILD_DIR}" | cut -f1)" -# Create tarball with zstd compression # -19 = high compression (slower but smaller) # -T0 = use all available threads echo " Compressing with zstd (level 19, this may take a few minutes)..." tar -C "${ROOTFS_BUILD_DIR}" -cf - . | zstd -19 -T0 -o "${OUTPUT}" -# Report results echo "" echo "==> Rootfs tarball created successfully!" echo " Output: ${OUTPUT}" diff --git a/tasks/scripts/vm/compress-vm-runtime.sh b/tasks/scripts/vm/compress-vm-runtime.sh index efada8a2e..db5fbbd5b 100755 --- a/tasks/scripts/vm/compress-vm-runtime.sh +++ b/tasks/scripts/vm/compress-vm-runtime.sh @@ -90,9 +90,9 @@ if [ -z "${VM_RUNTIME_TARBALL:-}" ] && _check_compressed_artifacts "$OUTPUT_DIR" mkdir -p "$WORK_DIR" for f in "${OUTPUT_DIR}"/*.zst; do [ -f "$f" ] || continue + [ -s "$f" ] || continue name="$(basename "${f%.zst}")" - # Skip rootfs tarball — bundle-vm-runtime.sh doesn't need it - [[ "$name" == rootfs.tar ]] && continue + [[ "$name" == rootfs*.tar ]] && continue zstd -d "$f" -o "${WORK_DIR}/${name}" -f -q chmod 0755 "${WORK_DIR}/${name}" done