diff --git a/cmd/ateapi/internal/controlapi/gpu.go b/cmd/ateapi/internal/controlapi/gpu.go new file mode 100644 index 0000000..485970d --- /dev/null +++ b/cmd/ateapi/internal/controlapi/gpu.go @@ -0,0 +1,27 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +package controlapi + +import ( + "github.com/agent-substrate/substrate/internal/proto/ateletpb" + "github.com/agent-substrate/substrate/pkg/api/v1alpha1" +) + +func toAteletGpuSpec(r *v1alpha1.ContainerResources) *ateletpb.GpuSpec { + if r == nil || r.GPU == nil { + return nil + } + g := r.GPU + return &ateletpb.GpuSpec{ + Count: g.Count, + Device: g.Device, + DriverCapabilities: g.DriverCapabilities, + DriverVersion: g.DriverVersion, + } +} diff --git a/cmd/ateapi/internal/controlapi/workflow_resume.go b/cmd/ateapi/internal/controlapi/workflow_resume.go index 1401461..b68fa5f 100644 --- a/cmd/ateapi/internal/controlapi/workflow_resume.go +++ b/cmd/ateapi/internal/controlapi/workflow_resume.go @@ -179,6 +179,7 @@ func (s *CallAteletRestoreStep) Execute(ctx context.Context, input *ResumeInput, Name: ctr.Name, Image: ctr.Image, Command: ctr.Command, + Gpu: toAteletGpuSpec(ctr.Resources), } for _, env := range ctr.Env { ateletEnv := &ateletpb.EnvEntry{ diff --git a/cmd/ateapi/internal/controlapi/workflow_suspend.go b/cmd/ateapi/internal/controlapi/workflow_suspend.go index 9c55652..95517ac 100644 --- a/cmd/ateapi/internal/controlapi/workflow_suspend.go +++ b/cmd/ateapi/internal/controlapi/workflow_suspend.go @@ -152,6 +152,7 @@ func (s *CallAteletSuspendStep) Execute(ctx context.Context, input *SuspendInput Name: ctr.Name, Image: ctr.Image, Command: ctr.Command, + Gpu: toAteletGpuSpec(ctr.Resources), } for _, env := range ctr.Env { ateletEnv := &ateletpb.EnvEntry{ diff --git a/cmd/atelet/main.go 
b/cmd/atelet/main.go index bd6611d..33271cc 100644 --- a/cmd/atelet/main.go +++ b/cmd/atelet/main.go @@ -395,6 +395,7 @@ func (s *AteomHerder) Run(ctx context.Context, req *ateletpb.RunRequest) (*atele "io.kubernetes.cri.container-name": "pause", }, netnsPath, + firstGpuSpec(req.GetSpec().GetContainers()), ); err != nil { return fmt.Errorf("while creating pause OCI bundle: %w", err) } @@ -426,6 +427,7 @@ func (s *AteomHerder) Run(ctx context.Context, req *ateletpb.RunRequest) (*atele "io.kubernetes.cri.container-name": ctr.GetName(), }, netnsPath, + ctr.GetGpu(), ); err != nil { return fmt.Errorf("while creating %q OCI bundle: %w", ctr.GetName(), err) } @@ -456,6 +458,7 @@ func (s *AteomHerder) Run(ctx context.Context, req *ateletpb.RunRequest) (*atele for _, ctr := range req.GetSpec().GetContainers() { ateomCtr := &ateompb.Container{ Name: ctr.GetName(), + Gpu: toAteomGpuSpec(ctr.GetGpu()), } ateomReq.GetSpec().Containers = append(ateomReq.GetSpec().Containers, ateomCtr) } @@ -511,6 +514,7 @@ func (s *AteomHerder) Checkpoint(ctx context.Context, req *ateletpb.CheckpointRe for _, ctr := range req.GetSpec().GetContainers() { ateomCtr := &ateompb.Container{ Name: ctr.GetName(), + Gpu: toAteomGpuSpec(ctr.GetGpu()), } ateomReq.GetSpec().Containers = append(ateomReq.GetSpec().Containers, ateomCtr) } @@ -644,6 +648,7 @@ func (s *AteomHerder) Restore(ctx context.Context, req *ateletpb.RestoreRequest) "io.kubernetes.cri.container-name": "pause", }, netnsPath, + firstGpuSpec(req.GetSpec().GetContainers()), ); err != nil { return fmt.Errorf("while creating pause OCI bundle: %w", err) } @@ -675,6 +680,7 @@ func (s *AteomHerder) Restore(ctx context.Context, req *ateletpb.RestoreRequest) "io.kubernetes.cri.container-name": ctr.GetName(), }, netnsPath, + ctr.GetGpu(), ); err != nil { return fmt.Errorf("while creating %q OCI bundle: %w", ctr.GetName(), err) } @@ -705,6 +711,7 @@ func (s *AteomHerder) Restore(ctx context.Context, req *ateletpb.RestoreRequest) for _, ctr := range 
req.GetSpec().GetContainers() { ateomCtr := &ateompb.Container{ Name: ctr.GetName(), + Gpu: toAteomGpuSpec(ctr.GetGpu()), } ateomReq.GetSpec().Containers = append(ateomReq.GetSpec().Containers, ateomCtr) } @@ -716,6 +723,27 @@ func (s *AteomHerder) Restore(ctx context.Context, req *ateletpb.RestoreRequest) return &ateletpb.RestoreResponse{}, nil } +func toAteomGpuSpec(g *ateletpb.GpuSpec) *ateompb.GpuSpec { + if g == nil { + return nil + } + return &ateompb.GpuSpec{ + Count: g.GetCount(), + Device: g.GetDevice(), + DriverCapabilities: g.GetDriverCapabilities(), + DriverVersion: g.GetDriverVersion(), + } +} + +func firstGpuSpec(containers []*ateletpb.Container) *ateletpb.GpuSpec { + for _, c := range containers { + if g := c.GetGpu(); g != nil { + return g + } + } + return nil +} + type AteomDialer struct { conns *lru.Cache } diff --git a/cmd/atelet/oci.go b/cmd/atelet/oci.go index a2ae14c..634a9ea 100644 --- a/cmd/atelet/oci.go +++ b/cmd/atelet/oci.go @@ -28,12 +28,14 @@ import ( "github.com/agent-substrate/substrate/internal/ateompath" "github.com/agent-substrate/substrate/internal/memorypullcache" + "github.com/agent-substrate/substrate/internal/proto/ateletpb" "github.com/opencontainers/runtime-spec/specs-go" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" + "golang.org/x/sys/unix" ) -func prepareOCIDirectory(ctx context.Context, pullCache *memorypullcache.MemoryPullCache, actorTemplateNamespace, actorTemplateName, actorID, containerName, ref string, args []string, env []string, annotations map[string]string, netns string) error { +func prepareOCIDirectory(ctx context.Context, pullCache *memorypullcache.MemoryPullCache, actorTemplateNamespace, actorTemplateName, actorID, containerName, ref string, args []string, env []string, annotations map[string]string, netns string, gpu *ateletpb.GpuSpec) error { tracer := otel.Tracer("prepareOCIDirectory") ctx, span := tracer.Start(ctx, "prepareOCIDirectory") @@ -162,6 +164,15 @@ func prepareOCIDirectory(ctx 
context.Context, pullCache *memorypullcache.MemoryP }, Annotations: annotations, } + + if gpu != nil { + if err := addGPUToOCISpec(ociSpec); err != nil { + return fmt.Errorf("while adding GPU passthrough to OCI spec: %w", err) + } + if err := injectNVIDIAAssetsIntoRootfs(ctx, rootPath); err != nil { + return fmt.Errorf("while injecting NVIDIA driver assets into rootfs: %w", err) + } + } ociSpecBytes, err := json.MarshalIndent(ociSpec, "", " ") if err != nil { return fmt.Errorf("while marshaling OCI spec: %w", err) @@ -279,3 +290,168 @@ func untar(ctx context.Context, tarData io.Reader, rootPath string) error { return nil } + +var gpuDevicePaths = []string{ + "/dev/nvidiactl", + "/dev/nvidia-uvm", + "/dev/nvidia-uvm-tools", + "/dev/nvidia-modeset", +} + +func addGPUToOCISpec(spec *specs.Spec) error { + paths := append([]string{}, gpuDevicePaths...) + entries, err := os.ReadDir("/dev") + if err != nil { + return fmt.Errorf("reading /dev: %w", err) + } + for _, e := range entries { + n := e.Name() + if len(n) > 6 && n[:6] == "nvidia" && n[6] >= '0' && n[6] <= '9' { + paths = append(paths, "/dev/"+n) + } + } + for _, p := range paths { + var st unix.Stat_t + if err := unix.Stat(p, &st); err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return fmt.Errorf("stat %s: %w", p, err) + } + major := int64(unix.Major(uint64(st.Rdev))) //nolint:gosec + minor := int64(unix.Minor(uint64(st.Rdev))) //nolint:gosec + mode := os.FileMode(st.Mode & 0o777) //nolint:gosec + uid := st.Uid + gid := st.Gid + spec.Linux.Devices = append(spec.Linux.Devices, specs.LinuxDevice{ + Path: p, Type: "c", Major: major, Minor: minor, + FileMode: &mode, UID: &uid, GID: &gid, + }) + if spec.Linux.Resources == nil { + spec.Linux.Resources = &specs.LinuxResources{} + } + allow := true + access := "rwm" + spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, + specs.LinuxDeviceCgroup{Allow: allow, Type: "c", Major: &major, Minor: &minor, Access: access}, + ) + } + for _, name 
:= range []string{"cuda-checkpoint", "cuda-checkpoint-wrapper.sh"} { + dest := "/usr/local/bin/" + name + // Source paths: prefer the shared /run/ateom-gvisor/static-files + // drop (visible inside both atelet and kind-node), fall back to + // /usr/local/bin if the operator installed it system-wide. + candidates := []string{"/run/ateom-gvisor/static-files/" + name, dest} + for _, src := range candidates { + if _, err := os.Stat(src); err != nil { + continue + } + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: dest, Type: "bind", Source: src, + Options: []string{"ro", "bind"}, + }) + break + } + } + return nil +} + +// nvidiaLibsStagingDir is where setup-host.sh stages the host's NVIDIA +// driver libs (libcuda.so., libnvidia-ml.so., …) plus their +// SONAME / dev symlinks. atelet copies them into each actor's rootfs at +// sandbox-create time so the workload image doesn't have to bake them in. +// +// This is the substrate-side equivalent of what +// `nvidia-container-cli configure --compute --utility --device=all` does +// in the standard docker+nvidia-container-runtime flow. We replicate the +// effect in Go rather than exec'ing nvidia-container-cli because atelet +// runs on `distroless/static-debian13` and has no dynamic linker for the +// `nvidia-container-cli` binary's libnvidia-container.so.1 dep. +const nvidiaLibsStagingDir = "/run/ateom-gvisor/static-files/nvidia-libs" + +// rootfsNVIDIALibDest is where libcuda.so. et al. need to land inside +// the sandbox rootfs. /etc/ld.so.cache on every glibc distro searches +// this path, so dlopen("libcuda.so.1") just works. +const rootfsNVIDIALibDest = "/usr/lib/x86_64-linux-gnu" + +// injectNVIDIAAssetsIntoRootfs walks nvidiaLibsStagingDir and mirrors +// every entry into /usr/lib/x86_64-linux-gnu — preserving +// symlinks as symlinks and copying real files byte-for-byte. 
// copyRegularFile copies the contents of the regular file at src into dst,
// creating dst if needed and truncating it otherwise.
//
// On success dst carries exactly the permission bits in mode: the mode is
// re-applied with Chmod after the copy because the mode passed to OpenFile
// is filtered through the process umask and is ignored entirely when dst
// already exists.
//
// On any failure after dst has been opened, the partially written dst is
// removed (best effort) so a truncated file is never left behind. The
// first error encountered is returned; a Close error on dst is reported
// when nothing failed earlier.
func copyRegularFile(src, dst string, mode os.FileMode) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := out.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			// Best effort: the original error is what the caller needs.
			_ = os.Remove(dst)
		}
	}()

	if _, err = io.Copy(out, in); err != nil {
		return err
	}
	return out.Chmod(mode)
}
+ if err := rcmd.cmdDrainCUDA(ctx); err != nil { + return nil, fmt.Errorf("while draining CUDA: %w", err) + } + // Checkpoint pause container (root of the sandbox) if err := rcmd.cmdCheckpoint(ctx, "pause", checkpointPath); err != nil { return nil, fmt.Errorf("while checkpointing pause: %w", err) @@ -469,6 +487,7 @@ func (s *AteomService) RestoreWorkload(ctx context.Context, req *ateompb.Restore actorTemplateNamespace: req.GetActorTemplateNamespace(), actorTemplateName: req.GetActorTemplateName(), actorID: req.GetActorId(), + gpu: firstGPUSpec(req.GetSpec().GetContainers()), } checkpointDir := ateompath.CheckpointDir(req.GetActorTemplateNamespace(), req.GetActorTemplateName(), req.GetActorId()) @@ -497,6 +516,12 @@ func (s *AteomService) RestoreWorkload(ctx context.Context, req *ateompb.Restore } } + // CUDA was drained by cmdDrainCUDA before checkpoint; toggle it + // back to running now that the supervisor is restored. + if err := rcmd.cmdUntoggleCUDA(ctx); err != nil { + return nil, fmt.Errorf("while untoggling CUDA: %w", err) + } + s.actorLogger.EmitLifecycleLog("Actor restored", req.GetActorId(), req.GetActorTemplateName(), req.GetActorTemplateNamespace()) return &ateompb.RestoreWorkloadResponse{}, nil diff --git a/cmd/ateom-gvisor/runsc.go b/cmd/ateom-gvisor/runsc.go index ea5de7a..36250c2 100644 --- a/cmd/ateom-gvisor/runsc.go +++ b/cmd/ateom-gvisor/runsc.go @@ -21,15 +21,99 @@ import ( "log/slog" "os" "os/exec" + "strings" "github.com/agent-substrate/substrate/internal/ateompath" + "github.com/agent-substrate/substrate/internal/proto/ateompb" ) +const cudaCheckpointWrapperPath = "/usr/local/bin/cuda-checkpoint-wrapper.sh" +const saveRestoreExecTimeout = "30s" // runsc wants a Go duration string, not ms. 
+ type runsc struct { path string actorTemplateNamespace string actorTemplateName string actorID string + gpu *ateompb.GpuSpec +} + +func (r *runsc) gpuGlobalFlags() []string { + if r.gpu == nil { + return nil + } + flags := []string{"--nvproxy"} + if v := r.gpu.GetDriverVersion(); v != "" { + flags = append(flags, "--nvproxy-driver-version="+v) + } + if caps := r.gpu.GetDriverCapabilities(); len(caps) > 0 { + flags = append(flags, "--nvproxy-allowed-driver-capabilities="+strings.Join(caps, ",")) + } + return flags +} + +// gpuSaveRestoreFlags is intentionally nil. gVisor's runsc +// --save-restore-exec-argv runs the exec in the container being +// checkpointed (pause for substrate's root sandbox). pause is the +// k8s pause image, distroless, no /bin/sh. So a wrapper script +// can't execute there. We drain CUDA externally via cmdDrainCUDA +// (runsc exec supervisor cuda-checkpoint --toggle --pid 1) just +// before cmdCheckpoint. +func (r *runsc) gpuSaveRestoreFlags() []string { + return nil +} + +// cmdDrainCUDA runs cuda-checkpoint inside the supervisor sub-container +// to drain CUDA state out of all live nvproxy clients. Without this, +// `runsc checkpoint` returns "can't save with live nvproxy clients". 
+func (r *runsc) cmdDrainCUDA(ctx context.Context) error { + if r.gpu == nil { + return nil + } + slog.InfoContext(ctx, "About to drain CUDA via runsc exec supervisor cuda-checkpoint") + cmd := exec.CommandContext( + ctx, + r.path, + "-log-format", "json", + "--alsologtostderr", + "-root", ateompath.RunSCStateDir(r.actorTemplateNamespace, r.actorTemplateName, r.actorID), + "exec", + "--", // marker for argv passthrough + "supervisor", + "/usr/local/bin/cuda-checkpoint", "--toggle", "--pid", "1", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("while running cuda-checkpoint drain: %w", err) + } + return nil +} + +// cmdUntoggleCUDA reverses cmdDrainCUDA after restore: cuda-checkpoint +// --toggle flips the locked CUDA state back to running. +func (r *runsc) cmdUntoggleCUDA(ctx context.Context) error { + if r.gpu == nil { + return nil + } + slog.InfoContext(ctx, "About to untoggle CUDA via runsc exec supervisor cuda-checkpoint") + cmd := exec.CommandContext( + ctx, + r.path, + "-log-format", "json", + "--alsologtostderr", + "-root", ateompath.RunSCStateDir(r.actorTemplateNamespace, r.actorTemplateName, r.actorID), + "exec", + "--", + "supervisor", + "/usr/local/bin/cuda-checkpoint", "--toggle", "--pid", "1", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("while running cuda-checkpoint untoggle: %w", err) + } + return nil } func (r *runsc) cmdCreate(ctx context.Context, out io.Writer, containerName string) error { @@ -38,9 +122,7 @@ func (r *runsc) cmdCreate(ctx context.Context, out io.Writer, containerName stri slog.InfoContext(ctx, "About to run runsc create", slog.String("container", containerName)) - cmd := exec.CommandContext( - ctx, - r.path, + args := []string{ "-log-format", "json", "--alsologtostderr", // "-debug", @@ -49,11 +131,15 @@ func (r *runsc) cmdCreate(ctx context.Context, out io.Writer, containerName stri // "-log-packets", // 
"-strace", "-root", ateompath.RunSCStateDir(r.actorTemplateNamespace, r.actorTemplateName, r.actorID), + } + args = append(args, r.gpuGlobalFlags()...) + args = append(args, "create", "-bundle", ateompath.OCIBundlePath(r.actorTemplateNamespace, r.actorTemplateName, r.actorID, containerName), "-pid-file", ateompath.PIDFilePath(r.actorTemplateNamespace, r.actorTemplateName, r.actorID, containerName), containerName, // Name of the container ) + cmd := exec.CommandContext(ctx, r.path, args...) cmd.Stdout = out cmd.Stderr = out @@ -103,9 +189,7 @@ func (r *runsc) cmdCheckpoint(ctx context.Context, containerName, checkpointPath slog.InfoContext(ctx, "About to run runsc checkpoint", slog.String("container", containerName)) - cmd := exec.CommandContext( - ctx, - r.path, + args := []string{ "-log-format", "json", "--alsologtostderr", // "-debug", @@ -114,10 +198,12 @@ func (r *runsc) cmdCheckpoint(ctx context.Context, containerName, checkpointPath // "-log-packets", // "-strace", "-root", ateompath.RunSCStateDir(r.actorTemplateNamespace, r.actorTemplateName, r.actorID), - "checkpoint", - "-image-path", checkpointPath, - containerName, // Name of the container - ) + } + args = append(args, r.gpuGlobalFlags()...) + args = append(args, "checkpoint", "-image-path", checkpointPath) + args = append(args, r.gpuSaveRestoreFlags()...) + args = append(args, containerName) + cmd := exec.CommandContext(ctx, r.path, args...) 
cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() @@ -135,9 +221,7 @@ func (r *runsc) cmdRestore(ctx context.Context, out io.Writer, containerName, ch slog.InfoContext(ctx, "About to run runsc restore", slog.String("container", containerName)) - cmd := exec.CommandContext( - ctx, - r.path, + args := []string{ "-log-format", "json", "--alsologtostderr", // "-debug", @@ -146,15 +230,28 @@ func (r *runsc) cmdRestore(ctx context.Context, out io.Writer, containerName, ch // "-log-packets", // "-strace", "-root", ateompath.RunSCStateDir(r.actorTemplateNamespace, r.actorTemplateName, r.actorID), + } + args = append(args, r.gpuGlobalFlags()...) + args = append(args, "restore", "-bundle", ateompath.OCIBundlePath(r.actorTemplateNamespace, r.actorTemplateName, r.actorID, containerName), "-image-path", checkpointPath, "-pid-file", ateompath.PIDFilePath(r.actorTemplateNamespace, r.actorTemplateName, r.actorID, containerName), + ) + if containerName == "pause" { + // --save-restore-exec-argv runs the wrapper once per sandbox, on + // the root container's restore. Sub-container restores must not + // re-invoke it -- the sandbox is already up and CUDA state has + // already been re-toggled. + args = append(args, r.gpuSaveRestoreFlags()...) + } + args = append(args, //"-background", //"-direct", // TODO(ateom): Reenable direct "-detach", containerName, ) + cmd := exec.CommandContext(ctx, r.path, args...) cmd.Stdout = out cmd.Stderr = out if err := cmd.Run(); err != nil { diff --git a/cmd/ateom-gvisor/runsc_test.go b/cmd/ateom-gvisor/runsc_test.go new file mode 100644 index 0000000..3b3b090 --- /dev/null +++ b/cmd/ateom-gvisor/runsc_test.go @@ -0,0 +1,138 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package main + +import ( + "strings" + "testing" + + "github.com/agent-substrate/substrate/internal/proto/ateompb" +) + +func TestGpuGlobalFlags_NilReturnsNil(t *testing.T) { + r := &runsc{} + if got := r.gpuGlobalFlags(); got != nil { + t.Fatalf("expected nil for nil GPU spec, got %v", got) + } +} + +func TestGpuGlobalFlags_BareSpecEmitsNvproxyOnly(t *testing.T) { + r := &runsc{gpu: &ateompb.GpuSpec{Count: 1}} + got := r.gpuGlobalFlags() + if len(got) != 1 || got[0] != "--nvproxy" { + t.Fatalf("expected [--nvproxy], got %v", got) + } +} + +func TestGpuGlobalFlags_WithDriverVersion(t *testing.T) { + r := &runsc{gpu: &ateompb.GpuSpec{Count: 1, DriverVersion: "580.126.09"}} + got := strings.Join(r.gpuGlobalFlags(), " ") + want := "--nvproxy --nvproxy-driver-version=580.126.09" + if got != want { + t.Fatalf("flags: got %q want %q", got, want) + } +} + +func TestGpuGlobalFlags_WithCapabilities(t *testing.T) { + r := &runsc{gpu: &ateompb.GpuSpec{ + Count: 1, + DriverCapabilities: []string{"compute", "utility"}, + }} + got := strings.Join(r.gpuGlobalFlags(), " ") + want := "--nvproxy --nvproxy-allowed-driver-capabilities=compute,utility" + if got != want { + t.Fatalf("flags: got %q want %q", got, want) + } +} + +func TestGpuGlobalFlags_FullSpec(t *testing.T) { + r := &runsc{gpu: &ateompb.GpuSpec{ + Count: 1, + DriverVersion: "580.126.09", + DriverCapabilities: []string{"compute", "utility", "video"}, + }} + got := strings.Join(r.gpuGlobalFlags(), " ") + want := "--nvproxy --nvproxy-driver-version=580.126.09 
--nvproxy-allowed-driver-capabilities=compute,utility,video" + if got != want { + t.Fatalf("flags: got %q want %q", got, want) + } +} + +func TestGpuSaveRestoreFlags_NilReturnsNil(t *testing.T) { + r := &runsc{} + if got := r.gpuSaveRestoreFlags(); got != nil { + t.Fatalf("expected nil for nil GPU spec, got %v", got) + } +} + +func TestGpuSaveRestoreFlags_DurationIsString(t *testing.T) { + // The runsc --save-restore-exec-timeout flag wants a Go duration + // string (e.g. "30s"), not a millisecond integer. Regression test + // for the parse error captured in the 2026-05-27 brev validation. + r := &runsc{gpu: &ateompb.GpuSpec{Count: 1}} + got := r.gpuSaveRestoreFlags() + if len(got) != 2 { + t.Fatalf("expected 2 flags, got %d: %v", len(got), got) + } + if got[0] != "--save-restore-exec-argv=/usr/local/bin/cuda-checkpoint-wrapper.sh" { + t.Errorf("argv flag wrong: %q", got[0]) + } + if got[1] != "--save-restore-exec-timeout=30s" { + t.Errorf("timeout flag wrong (must be a duration string, not ms): %q", got[1]) + } +} + +func TestFirstGPUSpec_None(t *testing.T) { + containers := []*ateompb.Container{ + {Name: "pause"}, + {Name: "app"}, + } + if g := firstGPUSpec(containers); g != nil { + t.Fatalf("expected nil when no container has GPU, got %v", g) + } +} + +func TestGpuSaveRestoreFlags_OnlyOnRootContainer(t *testing.T) { + // cmdRestore must only emit --save-restore-exec-argv on the pause + // (root) container; sub-container restores must not re-invoke the + // wrapper. Verified empirically on the L40S E2E run (2026-05-27): + // without this gate, the sub-container restore fails with + // "inconsistent private memory files on restore". + // + // gpuSaveRestoreFlags() itself doesn't know about container names; + // the gating lives in cmdRestore. This is a behavioural cross-check + // rather than a unit test of the helper. 
+ r := &runsc{gpu: &ateompb.GpuSpec{Count: 1}} + if got := r.gpuSaveRestoreFlags(); len(got) == 0 { + t.Fatalf("expected non-empty flags for the root container") + } +} + +func TestFirstGPUSpec_FindsFirst(t *testing.T) { + containers := []*ateompb.Container{ + {Name: "pause"}, + {Name: "app", Gpu: &ateompb.GpuSpec{Count: 1, DriverVersion: "580.126.09"}}, + {Name: "sidecar"}, + } + g := firstGPUSpec(containers) + if g == nil { + t.Fatalf("expected non-nil GPU spec") + } + if g.GetDriverVersion() != "580.126.09" { + t.Fatalf("expected driver version from app container, got %q", g.GetDriverVersion()) + } +} diff --git a/hack/cuda-checkpoint-wrapper.sh b/hack/cuda-checkpoint-wrapper.sh new file mode 100755 index 0000000..0bc7e68 --- /dev/null +++ b/hack/cuda-checkpoint-wrapper.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# cuda-checkpoint wrapper invoked by runsc via --save-restore-exec-argv. +# Toggles every CUDA-touching PID inside the sandbox. Idempotent, so runs +# fine on both pre-save and post-restore invocations. 
+set -e + +CB=/usr/local/bin/cuda-checkpoint +[ -x "$CB" ] || { echo "wrapper: $CB missing" >&2; exit 1; } + +pids="" +for d in /proc/[0-9]*; do + pid=${d#/proc/} + [ "$pid" = "$$" ] && continue + [ -r "$d/maps" ] && grep -qE '(/dev/nvidia|libcuda\.so|libcudart\.so|libnvidia-ml\.so)' "$d/maps" 2>/dev/null && pids="$pids $pid" +done +[ -z "$pids" ] && { echo "wrapper: no CUDA pids" >&2; exit 0; } +for p in $pids; do + echo "wrapper: --toggle pid=$p" >&2 + "$CB" --toggle --pid "$p" +done diff --git a/internal/proto/ateletpb/atelet.pb.go b/internal/proto/ateletpb/atelet.pb.go index 404aa0d..75c6bf7 100644 --- a/internal/proto/ateletpb/atelet.pb.go +++ b/internal/proto/ateletpb/atelet.pb.go @@ -390,6 +390,7 @@ type Container struct { Image string `protobuf:"bytes,2,opt,name=image,proto3" json:"image,omitempty"` Command []string `protobuf:"bytes,3,rep,name=command,proto3" json:"command,omitempty"` Env []*EnvEntry `protobuf:"bytes,4,rep,name=env,proto3" json:"env,omitempty"` + Gpu *GpuSpec `protobuf:"bytes,5,opt,name=gpu,proto3" json:"gpu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -452,6 +453,82 @@ func (x *Container) GetEnv() []*EnvEntry { return nil } +func (x *Container) GetGpu() *GpuSpec { + if x != nil { + return x.Gpu + } + return nil +} + +// GpuSpec mirrors v1alpha1.GPUResource on the ActorTemplate CRD. 
+type GpuSpec struct { + state protoimpl.MessageState `protogen:"open.v1"` + Count int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` + Device string `protobuf:"bytes,2,opt,name=device,proto3" json:"device,omitempty"` + DriverCapabilities []string `protobuf:"bytes,3,rep,name=driver_capabilities,json=driverCapabilities,proto3" json:"driver_capabilities,omitempty"` + DriverVersion string `protobuf:"bytes,4,opt,name=driver_version,json=driverVersion,proto3" json:"driver_version,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GpuSpec) Reset() { + *x = GpuSpec{} + mi := &file_atelet_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GpuSpec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GpuSpec) ProtoMessage() {} + +func (x *GpuSpec) ProtoReflect() protoreflect.Message { + mi := &file_atelet_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GpuSpec.ProtoReflect.Descriptor instead. 
+func (*GpuSpec) Descriptor() ([]byte, []int) { + return file_atelet_proto_rawDescGZIP(), []int{7} +} + +func (x *GpuSpec) GetCount() int32 { + if x != nil { + return x.Count + } + return 0 +} + +func (x *GpuSpec) GetDevice() string { + if x != nil { + return x.Device + } + return "" +} + +func (x *GpuSpec) GetDriverCapabilities() []string { + if x != nil { + return x.DriverCapabilities + } + return nil +} + +func (x *GpuSpec) GetDriverVersion() string { + if x != nil { + return x.DriverVersion + } + return "" +} + type EnvEntry struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -462,7 +539,7 @@ type EnvEntry struct { func (x *EnvEntry) Reset() { *x = EnvEntry{} - mi := &file_atelet_proto_msgTypes[7] + mi := &file_atelet_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -474,7 +551,7 @@ func (x *EnvEntry) String() string { func (*EnvEntry) ProtoMessage() {} func (x *EnvEntry) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[7] + mi := &file_atelet_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -487,7 +564,7 @@ func (x *EnvEntry) ProtoReflect() protoreflect.Message { // Deprecated: Use EnvEntry.ProtoReflect.Descriptor instead. 
func (*EnvEntry) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{7} + return file_atelet_proto_rawDescGZIP(), []int{8} } func (x *EnvEntry) GetName() string { @@ -512,7 +589,7 @@ type RunResponse struct { func (x *RunResponse) Reset() { *x = RunResponse{} - mi := &file_atelet_proto_msgTypes[8] + mi := &file_atelet_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -524,7 +601,7 @@ func (x *RunResponse) String() string { func (*RunResponse) ProtoMessage() {} func (x *RunResponse) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[8] + mi := &file_atelet_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -537,7 +614,7 @@ func (x *RunResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RunResponse.ProtoReflect.Descriptor instead. func (*RunResponse) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{8} + return file_atelet_proto_rawDescGZIP(), []int{9} } type CheckpointRequest struct { @@ -566,7 +643,7 @@ type CheckpointRequest struct { func (x *CheckpointRequest) Reset() { *x = CheckpointRequest{} - mi := &file_atelet_proto_msgTypes[9] + mi := &file_atelet_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -578,7 +655,7 @@ func (x *CheckpointRequest) String() string { func (*CheckpointRequest) ProtoMessage() {} func (x *CheckpointRequest) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[9] + mi := &file_atelet_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -591,7 +668,7 @@ func (x *CheckpointRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CheckpointRequest.ProtoReflect.Descriptor instead. 
func (*CheckpointRequest) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{9} + return file_atelet_proto_rawDescGZIP(), []int{10} } func (x *CheckpointRequest) GetTargetAteomNamespace() string { @@ -658,7 +735,7 @@ type CheckpointResponse struct { func (x *CheckpointResponse) Reset() { *x = CheckpointResponse{} - mi := &file_atelet_proto_msgTypes[10] + mi := &file_atelet_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -670,7 +747,7 @@ func (x *CheckpointResponse) String() string { func (*CheckpointResponse) ProtoMessage() {} func (x *CheckpointResponse) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[10] + mi := &file_atelet_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -683,7 +760,7 @@ func (x *CheckpointResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CheckpointResponse.ProtoReflect.Descriptor instead. func (*CheckpointResponse) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{10} + return file_atelet_proto_rawDescGZIP(), []int{11} } type RestoreRequest struct { @@ -703,7 +780,7 @@ type RestoreRequest struct { func (x *RestoreRequest) Reset() { *x = RestoreRequest{} - mi := &file_atelet_proto_msgTypes[11] + mi := &file_atelet_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -715,7 +792,7 @@ func (x *RestoreRequest) String() string { func (*RestoreRequest) ProtoMessage() {} func (x *RestoreRequest) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[11] + mi := &file_atelet_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -728,7 +805,7 @@ func (x *RestoreRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RestoreRequest.ProtoReflect.Descriptor instead. 
func (*RestoreRequest) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{11} + return file_atelet_proto_rawDescGZIP(), []int{12} } func (x *RestoreRequest) GetTargetAteomNamespace() string { @@ -795,7 +872,7 @@ type RestoreResponse struct { func (x *RestoreResponse) Reset() { *x = RestoreResponse{} - mi := &file_atelet_proto_msgTypes[12] + mi := &file_atelet_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -807,7 +884,7 @@ func (x *RestoreResponse) String() string { func (*RestoreResponse) ProtoMessage() {} func (x *RestoreResponse) ProtoReflect() protoreflect.Message { - mi := &file_atelet_proto_msgTypes[12] + mi := &file_atelet_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -820,7 +897,7 @@ func (x *RestoreResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RestoreResponse.ProtoReflect.Descriptor instead. func (*RestoreResponse) Descriptor() ([]byte, []int) { - return file_atelet_proto_rawDescGZIP(), []int{12} + return file_atelet_proto_rawDescGZIP(), []int{13} } var File_atelet_proto protoreflect.FileDescriptor @@ -854,12 +931,18 @@ const file_atelet_proto_rawDesc = "" + "containers\x18\x01 \x03(\v2\x11.atelet.ContainerR\n" + "containers\x12\x1f\n" + "\vpause_image\x18\x02 \x01(\tR\n" + - "pauseImage\"s\n" + + "pauseImage\"\x96\x01\n" + "\tContainer\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + "\x05image\x18\x02 \x01(\tR\x05image\x12\x18\n" + "\acommand\x18\x03 \x03(\tR\acommand\x12\"\n" + - "\x03env\x18\x04 \x03(\v2\x10.atelet.EnvEntryR\x03env\"4\n" + + "\x03env\x18\x04 \x03(\v2\x10.atelet.EnvEntryR\x03env\x12!\n" + + "\x03gpu\x18\x05 \x01(\v2\x0f.atelet.GpuSpecR\x03gpu\"\x8f\x01\n" + + "\aGpuSpec\x12\x14\n" + + "\x05count\x18\x01 \x01(\x05R\x05count\x12\x16\n" + + "\x06device\x18\x02 \x01(\tR\x06device\x12/\n" + + "\x13driver_capabilities\x18\x03 
\x03(\tR\x12driverCapabilities\x12%\n" + + "\x0edriver_version\x18\x04 \x01(\tR\rdriverVersion\"4\n" + "\bEnvEntry\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value\"\r\n" + @@ -902,7 +985,7 @@ func file_atelet_proto_rawDescGZIP() []byte { return file_atelet_proto_rawDescData } -var file_atelet_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_atelet_proto_msgTypes = make([]protoimpl.MessageInfo, 14) var file_atelet_proto_goTypes = []any{ (*RunRequest)(nil), // 0: atelet.RunRequest (*GCPAuthenticationConfig)(nil), // 1: atelet.GCPAuthenticationConfig @@ -911,12 +994,13 @@ var file_atelet_proto_goTypes = []any{ (*RunscConfig)(nil), // 4: atelet.RunscConfig (*WorkloadSpec)(nil), // 5: atelet.WorkloadSpec (*Container)(nil), // 6: atelet.Container - (*EnvEntry)(nil), // 7: atelet.EnvEntry - (*RunResponse)(nil), // 8: atelet.RunResponse - (*CheckpointRequest)(nil), // 9: atelet.CheckpointRequest - (*CheckpointResponse)(nil), // 10: atelet.CheckpointResponse - (*RestoreRequest)(nil), // 11: atelet.RestoreRequest - (*RestoreResponse)(nil), // 12: atelet.RestoreResponse + (*GpuSpec)(nil), // 7: atelet.GpuSpec + (*EnvEntry)(nil), // 8: atelet.EnvEntry + (*RunResponse)(nil), // 9: atelet.RunResponse + (*CheckpointRequest)(nil), // 10: atelet.CheckpointRequest + (*CheckpointResponse)(nil), // 11: atelet.CheckpointResponse + (*RestoreRequest)(nil), // 12: atelet.RestoreRequest + (*RestoreResponse)(nil), // 13: atelet.RestoreResponse } var file_atelet_proto_depIdxs = []int32{ 4, // 0: atelet.RunRequest.runsc:type_name -> atelet.RunscConfig @@ -926,22 +1010,23 @@ var file_atelet_proto_depIdxs = []int32{ 3, // 4: atelet.RunscConfig.arm64:type_name -> atelet.RunscPlatformConfig 2, // 5: atelet.RunscConfig.authentication:type_name -> atelet.AuthenticationConfig 6, // 6: atelet.WorkloadSpec.containers:type_name -> atelet.Container - 7, // 7: atelet.Container.env:type_name -> atelet.EnvEntry - 4, // 8: 
atelet.CheckpointRequest.runsc:type_name -> atelet.RunscConfig - 5, // 9: atelet.CheckpointRequest.spec:type_name -> atelet.WorkloadSpec - 4, // 10: atelet.RestoreRequest.runsc:type_name -> atelet.RunscConfig - 5, // 11: atelet.RestoreRequest.spec:type_name -> atelet.WorkloadSpec - 0, // 12: atelet.AteomHerder.Run:input_type -> atelet.RunRequest - 9, // 13: atelet.AteomHerder.Checkpoint:input_type -> atelet.CheckpointRequest - 11, // 14: atelet.AteomHerder.Restore:input_type -> atelet.RestoreRequest - 8, // 15: atelet.AteomHerder.Run:output_type -> atelet.RunResponse - 10, // 16: atelet.AteomHerder.Checkpoint:output_type -> atelet.CheckpointResponse - 12, // 17: atelet.AteomHerder.Restore:output_type -> atelet.RestoreResponse - 15, // [15:18] is the sub-list for method output_type - 12, // [12:15] is the sub-list for method input_type - 12, // [12:12] is the sub-list for extension type_name - 12, // [12:12] is the sub-list for extension extendee - 0, // [0:12] is the sub-list for field type_name + 8, // 7: atelet.Container.env:type_name -> atelet.EnvEntry + 7, // 8: atelet.Container.gpu:type_name -> atelet.GpuSpec + 4, // 9: atelet.CheckpointRequest.runsc:type_name -> atelet.RunscConfig + 5, // 10: atelet.CheckpointRequest.spec:type_name -> atelet.WorkloadSpec + 4, // 11: atelet.RestoreRequest.runsc:type_name -> atelet.RunscConfig + 5, // 12: atelet.RestoreRequest.spec:type_name -> atelet.WorkloadSpec + 0, // 13: atelet.AteomHerder.Run:input_type -> atelet.RunRequest + 10, // 14: atelet.AteomHerder.Checkpoint:input_type -> atelet.CheckpointRequest + 12, // 15: atelet.AteomHerder.Restore:input_type -> atelet.RestoreRequest + 9, // 16: atelet.AteomHerder.Run:output_type -> atelet.RunResponse + 11, // 17: atelet.AteomHerder.Checkpoint:output_type -> atelet.CheckpointResponse + 13, // 18: atelet.AteomHerder.Restore:output_type -> atelet.RestoreResponse + 16, // [16:19] is the sub-list for method output_type + 13, // [13:16] is the sub-list for method input_type + 13, 
// [13:13] is the sub-list for extension type_name + 13, // [13:13] is the sub-list for extension extendee + 0, // [0:13] is the sub-list for field type_name } func init() { file_atelet_proto_init() } @@ -955,7 +1040,7 @@ func file_atelet_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_atelet_proto_rawDesc), len(file_atelet_proto_rawDesc)), NumEnums: 0, - NumMessages: 13, + NumMessages: 14, NumExtensions: 0, NumServices: 1, }, diff --git a/internal/proto/ateletpb/atelet.proto b/internal/proto/ateletpb/atelet.proto index 8a356db..78ac563 100644 --- a/internal/proto/ateletpb/atelet.proto +++ b/internal/proto/ateletpb/atelet.proto @@ -81,6 +81,15 @@ message Container { string image = 2; repeated string command = 3; repeated EnvEntry env = 4; + GpuSpec gpu = 5; +} + +// GpuSpec mirrors v1alpha1.GPUResource on the ActorTemplate CRD. +message GpuSpec { + int32 count = 1; + string device = 2; + repeated string driver_capabilities = 3; + string driver_version = 4; } message EnvEntry { diff --git a/internal/proto/ateompb/ateom.pb.go b/internal/proto/ateompb/ateom.pb.go index 2faced0..dea05cd 100644 --- a/internal/proto/ateompb/ateom.pb.go +++ b/internal/proto/ateompb/ateom.pb.go @@ -159,6 +159,7 @@ func (x *WorkloadSpec) GetContainers() []*Container { type Container struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Gpu *GpuSpec `protobuf:"bytes,2,opt,name=gpu,proto3" json:"gpu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -200,6 +201,82 @@ func (x *Container) GetName() string { return "" } +func (x *Container) GetGpu() *GpuSpec { + if x != nil { + return x.Gpu + } + return nil +} + +// GpuSpec mirrors ateletpb.GpuSpec. 
+type GpuSpec struct { + state protoimpl.MessageState `protogen:"open.v1"` + Count int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` + Device string `protobuf:"bytes,2,opt,name=device,proto3" json:"device,omitempty"` + DriverCapabilities []string `protobuf:"bytes,3,rep,name=driver_capabilities,json=driverCapabilities,proto3" json:"driver_capabilities,omitempty"` + DriverVersion string `protobuf:"bytes,4,opt,name=driver_version,json=driverVersion,proto3" json:"driver_version,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GpuSpec) Reset() { + *x = GpuSpec{} + mi := &file_ateom_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GpuSpec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GpuSpec) ProtoMessage() {} + +func (x *GpuSpec) ProtoReflect() protoreflect.Message { + mi := &file_ateom_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GpuSpec.ProtoReflect.Descriptor instead. 
+func (*GpuSpec) Descriptor() ([]byte, []int) { + return file_ateom_proto_rawDescGZIP(), []int{3} +} + +func (x *GpuSpec) GetCount() int32 { + if x != nil { + return x.Count + } + return 0 +} + +func (x *GpuSpec) GetDevice() string { + if x != nil { + return x.Device + } + return "" +} + +func (x *GpuSpec) GetDriverCapabilities() []string { + if x != nil { + return x.DriverCapabilities + } + return nil +} + +func (x *GpuSpec) GetDriverVersion() string { + if x != nil { + return x.DriverVersion + } + return "" +} + type RunWorkloadResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -208,7 +285,7 @@ type RunWorkloadResponse struct { func (x *RunWorkloadResponse) Reset() { *x = RunWorkloadResponse{} - mi := &file_ateom_proto_msgTypes[3] + mi := &file_ateom_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -220,7 +297,7 @@ func (x *RunWorkloadResponse) String() string { func (*RunWorkloadResponse) ProtoMessage() {} func (x *RunWorkloadResponse) ProtoReflect() protoreflect.Message { - mi := &file_ateom_proto_msgTypes[3] + mi := &file_ateom_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -233,7 +310,7 @@ func (x *RunWorkloadResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RunWorkloadResponse.ProtoReflect.Descriptor instead. 
func (*RunWorkloadResponse) Descriptor() ([]byte, []int) { - return file_ateom_proto_rawDescGZIP(), []int{3} + return file_ateom_proto_rawDescGZIP(), []int{4} } type CheckpointWorkloadRequest struct { @@ -258,7 +335,7 @@ type CheckpointWorkloadRequest struct { func (x *CheckpointWorkloadRequest) Reset() { *x = CheckpointWorkloadRequest{} - mi := &file_ateom_proto_msgTypes[4] + mi := &file_ateom_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -270,7 +347,7 @@ func (x *CheckpointWorkloadRequest) String() string { func (*CheckpointWorkloadRequest) ProtoMessage() {} func (x *CheckpointWorkloadRequest) ProtoReflect() protoreflect.Message { - mi := &file_ateom_proto_msgTypes[4] + mi := &file_ateom_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -283,7 +360,7 @@ func (x *CheckpointWorkloadRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CheckpointWorkloadRequest.ProtoReflect.Descriptor instead. 
func (*CheckpointWorkloadRequest) Descriptor() ([]byte, []int) { - return file_ateom_proto_rawDescGZIP(), []int{4} + return file_ateom_proto_rawDescGZIP(), []int{5} } func (x *CheckpointWorkloadRequest) GetActorTemplateNamespace() string { @@ -336,7 +413,7 @@ type CheckpointWorkloadResponse struct { func (x *CheckpointWorkloadResponse) Reset() { *x = CheckpointWorkloadResponse{} - mi := &file_ateom_proto_msgTypes[5] + mi := &file_ateom_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -348,7 +425,7 @@ func (x *CheckpointWorkloadResponse) String() string { func (*CheckpointWorkloadResponse) ProtoMessage() {} func (x *CheckpointWorkloadResponse) ProtoReflect() protoreflect.Message { - mi := &file_ateom_proto_msgTypes[5] + mi := &file_ateom_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -361,7 +438,7 @@ func (x *CheckpointWorkloadResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CheckpointWorkloadResponse.ProtoReflect.Descriptor instead. 
func (*CheckpointWorkloadResponse) Descriptor() ([]byte, []int) { - return file_ateom_proto_rawDescGZIP(), []int{5} + return file_ateom_proto_rawDescGZIP(), []int{6} } type RestoreWorkloadRequest struct { @@ -379,7 +456,7 @@ type RestoreWorkloadRequest struct { func (x *RestoreWorkloadRequest) Reset() { *x = RestoreWorkloadRequest{} - mi := &file_ateom_proto_msgTypes[6] + mi := &file_ateom_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -391,7 +468,7 @@ func (x *RestoreWorkloadRequest) String() string { func (*RestoreWorkloadRequest) ProtoMessage() {} func (x *RestoreWorkloadRequest) ProtoReflect() protoreflect.Message { - mi := &file_ateom_proto_msgTypes[6] + mi := &file_ateom_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -404,7 +481,7 @@ func (x *RestoreWorkloadRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RestoreWorkloadRequest.ProtoReflect.Descriptor instead. 
func (*RestoreWorkloadRequest) Descriptor() ([]byte, []int) { - return file_ateom_proto_rawDescGZIP(), []int{6} + return file_ateom_proto_rawDescGZIP(), []int{7} } func (x *RestoreWorkloadRequest) GetActorTemplateNamespace() string { @@ -457,7 +534,7 @@ type RestoreWorkloadResponse struct { func (x *RestoreWorkloadResponse) Reset() { *x = RestoreWorkloadResponse{} - mi := &file_ateom_proto_msgTypes[7] + mi := &file_ateom_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -469,7 +546,7 @@ func (x *RestoreWorkloadResponse) String() string { func (*RestoreWorkloadResponse) ProtoMessage() {} func (x *RestoreWorkloadResponse) ProtoReflect() protoreflect.Message { - mi := &file_ateom_proto_msgTypes[7] + mi := &file_ateom_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -482,7 +559,7 @@ func (x *RestoreWorkloadResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RestoreWorkloadResponse.ProtoReflect.Descriptor instead. 
func (*RestoreWorkloadResponse) Descriptor() ([]byte, []int) { - return file_ateom_proto_rawDescGZIP(), []int{7} + return file_ateom_proto_rawDescGZIP(), []int{8} } var File_ateom_proto protoreflect.FileDescriptor @@ -500,9 +577,15 @@ const file_ateom_proto_rawDesc = "" + "\fWorkloadSpec\x120\n" + "\n" + "containers\x18\x01 \x03(\v2\x10.ateom.ContainerR\n" + - "containers\"\x1f\n" + + "containers\"A\n" + "\tContainer\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\"\x15\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12 \n" + + "\x03gpu\x18\x02 \x01(\v2\x0e.ateom.GpuSpecR\x03gpu\"\x8f\x01\n" + + "\aGpuSpec\x12\x14\n" + + "\x05count\x18\x01 \x01(\x05R\x05count\x12\x16\n" + + "\x06device\x18\x02 \x01(\tR\x06device\x12/\n" + + "\x13driver_capabilities\x18\x03 \x03(\tR\x12driverCapabilities\x12%\n" + + "\x0edriver_version\x18\x04 \x01(\tR\rdriverVersion\"\x15\n" + "\x13RunWorkloadResponse\"\x98\x02\n" + "\x19CheckpointWorkloadRequest\x128\n" + "\x18actor_template_namespace\x18\x01 \x01(\tR\x16actorTemplateNamespace\x12.\n" + @@ -539,33 +622,35 @@ func file_ateom_proto_rawDescGZIP() []byte { return file_ateom_proto_rawDescData } -var file_ateom_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_ateom_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_ateom_proto_goTypes = []any{ (*RunWorkloadRequest)(nil), // 0: ateom.RunWorkloadRequest (*WorkloadSpec)(nil), // 1: ateom.WorkloadSpec (*Container)(nil), // 2: ateom.Container - (*RunWorkloadResponse)(nil), // 3: ateom.RunWorkloadResponse - (*CheckpointWorkloadRequest)(nil), // 4: ateom.CheckpointWorkloadRequest - (*CheckpointWorkloadResponse)(nil), // 5: ateom.CheckpointWorkloadResponse - (*RestoreWorkloadRequest)(nil), // 6: ateom.RestoreWorkloadRequest - (*RestoreWorkloadResponse)(nil), // 7: ateom.RestoreWorkloadResponse + (*GpuSpec)(nil), // 3: ateom.GpuSpec + (*RunWorkloadResponse)(nil), // 4: ateom.RunWorkloadResponse + (*CheckpointWorkloadRequest)(nil), // 5: ateom.CheckpointWorkloadRequest + 
(*CheckpointWorkloadResponse)(nil), // 6: ateom.CheckpointWorkloadResponse + (*RestoreWorkloadRequest)(nil), // 7: ateom.RestoreWorkloadRequest + (*RestoreWorkloadResponse)(nil), // 8: ateom.RestoreWorkloadResponse } var file_ateom_proto_depIdxs = []int32{ 1, // 0: ateom.RunWorkloadRequest.spec:type_name -> ateom.WorkloadSpec 2, // 1: ateom.WorkloadSpec.containers:type_name -> ateom.Container - 1, // 2: ateom.CheckpointWorkloadRequest.spec:type_name -> ateom.WorkloadSpec - 1, // 3: ateom.RestoreWorkloadRequest.spec:type_name -> ateom.WorkloadSpec - 0, // 4: ateom.Ateom.RunWorkload:input_type -> ateom.RunWorkloadRequest - 4, // 5: ateom.Ateom.CheckpointWorkload:input_type -> ateom.CheckpointWorkloadRequest - 6, // 6: ateom.Ateom.RestoreWorkload:input_type -> ateom.RestoreWorkloadRequest - 3, // 7: ateom.Ateom.RunWorkload:output_type -> ateom.RunWorkloadResponse - 5, // 8: ateom.Ateom.CheckpointWorkload:output_type -> ateom.CheckpointWorkloadResponse - 7, // 9: ateom.Ateom.RestoreWorkload:output_type -> ateom.RestoreWorkloadResponse - 7, // [7:10] is the sub-list for method output_type - 4, // [4:7] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 4, // [4:4] is the sub-list for extension extendee - 0, // [0:4] is the sub-list for field type_name + 3, // 2: ateom.Container.gpu:type_name -> ateom.GpuSpec + 1, // 3: ateom.CheckpointWorkloadRequest.spec:type_name -> ateom.WorkloadSpec + 1, // 4: ateom.RestoreWorkloadRequest.spec:type_name -> ateom.WorkloadSpec + 0, // 5: ateom.Ateom.RunWorkload:input_type -> ateom.RunWorkloadRequest + 5, // 6: ateom.Ateom.CheckpointWorkload:input_type -> ateom.CheckpointWorkloadRequest + 7, // 7: ateom.Ateom.RestoreWorkload:input_type -> ateom.RestoreWorkloadRequest + 4, // 8: ateom.Ateom.RunWorkload:output_type -> ateom.RunWorkloadResponse + 6, // 9: ateom.Ateom.CheckpointWorkload:output_type -> ateom.CheckpointWorkloadResponse + 8, // 10: ateom.Ateom.RestoreWorkload:output_type -> 
ateom.RestoreWorkloadResponse + 8, // [8:11] is the sub-list for method output_type + 5, // [5:8] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name } func init() { file_ateom_proto_init() } @@ -579,7 +664,7 @@ func file_ateom_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_ateom_proto_rawDesc), len(file_ateom_proto_rawDesc)), NumEnums: 0, - NumMessages: 8, + NumMessages: 9, NumExtensions: 0, NumServices: 1, }, diff --git a/internal/proto/ateompb/ateom.proto b/internal/proto/ateompb/ateom.proto index 629c4ca..c078d53 100644 --- a/internal/proto/ateompb/ateom.proto +++ b/internal/proto/ateompb/ateom.proto @@ -63,7 +63,16 @@ message WorkloadSpec { } message Container { - string name = 1; + string name = 1; + GpuSpec gpu = 2; +} + +// GpuSpec mirrors ateletpb.GpuSpec. +message GpuSpec { + int32 count = 1; + string device = 2; + repeated string driver_capabilities = 3; + string driver_version = 4; } message RunWorkloadResponse { diff --git a/manifests/ate-install/generated/ate.dev_actortemplates.yaml b/manifests/ate-install/generated/ate.dev_actortemplates.yaml index c3fd0ae..1c508ff 100644 --- a/manifests/ate-install/generated/ate.dev_actortemplates.yaml +++ b/manifests/ate-install/generated/ate.dev_actortemplates.yaml @@ -269,6 +269,37 @@ spec: - containerPort type: object type: array + resources: + description: ContainerResources is a per-container resource + request block. + properties: + gpu: + description: GPUResource is the NVIDIA GPU passthrough request. + properties: + count: + default: 1 + format: int32 + maximum: 1 + minimum: 1 + type: integer + device: + description: Device is a CDI device id (e.g. "nvidia.com/gpu=0") + or "all". + type: string + driverCapabilities: + description: |- + DriverCapabilities for runsc --nvproxy-allowed-driver-capabilities. 
+ Defaults to ["compute","utility"]. + items: + type: string + type: array + x-kubernetes-list-type: atomic + driverVersion: + description: DriverVersion pins runsc's nvproxy driver + ABI. Empty uses host's. + type: string + type: object + type: object required: - name type: object diff --git a/pkg/api/v1alpha1/actortemplate_types.go b/pkg/api/v1alpha1/actortemplate_types.go index 5106fb5..4de766e 100644 --- a/pkg/api/v1alpha1/actortemplate_types.go +++ b/pkg/api/v1alpha1/actortemplate_types.go @@ -49,6 +49,38 @@ type Container struct { // Environment variables to set in the worker replicas. Env []corev1.EnvVar `json:"env,omitempty"` + + // +optional + Resources *ContainerResources `json:"resources,omitempty"` +} + +// ContainerResources is a per-container resource request block. +type ContainerResources struct { + // +optional + GPU *GPUResource `json:"gpu,omitempty"` +} + +// GPUResource is the NVIDIA GPU passthrough request. +type GPUResource struct { + // +optional + // +kubebuilder:default=1 + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=1 + Count int32 `json:"count,omitempty"` + + // Device is a CDI device id (e.g. "nvidia.com/gpu=0") or "all". + // +optional + Device string `json:"device,omitempty"` + + // DriverCapabilities for runsc --nvproxy-allowed-driver-capabilities. + // Defaults to ["compute","utility"]. + // +optional + // +listType=atomic + DriverCapabilities []string `json:"driverCapabilities,omitempty"` + + // DriverVersion pins runsc's nvproxy driver ABI. Empty uses host's. 
+ // +optional + DriverVersion string `json:"driverVersion,omitempty"` } type SnapshotsConfig struct { diff --git a/pkg/api/v1alpha1/zz_generated.deepcopy.go b/pkg/api/v1alpha1/zz_generated.deepcopy.go index a66bb6a..329647a 100644 --- a/pkg/api/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/api/v1alpha1/zz_generated.deepcopy.go @@ -171,6 +171,11 @@ func (in *Container) DeepCopyInto(out *Container) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.Resources != nil { + in, out := &in.Resources, &out.Resources + *out = new(ContainerResources) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Container. @@ -183,6 +188,26 @@ func (in *Container) DeepCopy() *Container { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ContainerResources) DeepCopyInto(out *ContainerResources) { + *out = *in + if in.GPU != nil { + in, out := &in.GPU, &out.GPU + *out = new(GPUResource) + (*in).DeepCopyInto(*out) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ContainerResources. +func (in *ContainerResources) DeepCopy() *ContainerResources { + if in == nil { + return nil + } + out := new(ContainerResources) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *GCPAuthenticationConfig) DeepCopyInto(out *GCPAuthenticationConfig) { *out = *in @@ -198,6 +223,26 @@ func (in *GCPAuthenticationConfig) DeepCopy() *GCPAuthenticationConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *GPUResource) DeepCopyInto(out *GPUResource) { + *out = *in + if in.DriverCapabilities != nil { + in, out := &in.DriverCapabilities, &out.DriverCapabilities + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUResource. +func (in *GPUResource) DeepCopy() *GPUResource { + if in == nil { + return nil + } + out := new(GPUResource) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RunscConfig) DeepCopyInto(out *RunscConfig) { *out = *in