From d16cded91c007bc020235b27d9a7659fcbf02167 Mon Sep 17 00:00:00 2001 From: Rishi Jat Date: Sun, 10 May 2026 05:11:03 +0530 Subject: [PATCH 1/4] feat(service): add node-level LayerCache with rebuild support Introduces a thread-safe LayerCache that provides: - Singleflight concurrency deduplication per layer digest. - Node-level weighted semaphore for fetch flow control. - Restart-safe rebuild mechanisms via layer_digests.json. Signed-off-by: Rishi Jat --- pkg/service/layer_cache.go | 253 +++++++++++++ pkg/service/layer_cache_test.go | 627 ++++++++++++++++++++++++++++++++ pkg/service/layer_metadata.go | 46 +++ 3 files changed, 926 insertions(+) create mode 100644 pkg/service/layer_cache.go create mode 100644 pkg/service/layer_cache_test.go create mode 100644 pkg/service/layer_metadata.go diff --git a/pkg/service/layer_cache.go b/pkg/service/layer_cache.go new file mode 100644 index 0000000..1411334 --- /dev/null +++ b/pkg/service/layer_cache.go @@ -0,0 +1,253 @@ +package service + +import ( + "os" + "path/filepath" + "strings" + "sync" + + "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/logger" + "github.com/modelpack/model-csi-driver/pkg/status" + "github.com/opencontainers/go-digest" + "golang.org/x/sync/semaphore" + "golang.org/x/sync/singleflight" +) + +// defaultMaxConcurrentLayers is the default maximum number of layers that +// can be pulled concurrently across all volumes on a node. +const defaultMaxConcurrentLayers int64 = 8 + +// LayerCache maintains an in-memory mapping of layer digest → file paths +// on disk, enabling layer-level deduplication via hardlinks. +// +// Thread-safety: all methods are safe for concurrent use. +type LayerCache struct { + mu sync.RWMutex + // layers maps a layer digest to all known on-disk file paths containing + // that layer's content. Multiple volumes can reference the same layer. + layers map[digest.Digest][]string + + // sfGroup deduplicates concurrent pull requests for the same layer digest. + // The key is the digest string. + sfGroup singleflight.Group + + // sem controls the maximum number of concurrently in-flight layer pulls + // at the node level, preventing uncontrolled network and disk IO fan-out. + sem *semaphore.Weighted +} + +// NewLayerCache creates a new empty LayerCache with the given concurrency limit. +func NewLayerCache(maxConcurrentLayers int64) *LayerCache { + if maxConcurrentLayers <= 0 { + maxConcurrentLayers = defaultMaxConcurrentLayers + } + return &LayerCache{ + layers: make(map[digest.Digest][]string), + sem: semaphore.NewWeighted(maxConcurrentLayers), + } +} + +// Semaphore returns the weighted semaphore for node-level flow control. +func (lc *LayerCache) Semaphore() *semaphore.Weighted { + return lc.sem +} + +// SflightGroup returns the singleflight group for layer-level dedup. +func (lc *LayerCache) SflightGroup() *singleflight.Group { + return &lc.sfGroup +} + +// Register adds a file path for a given layer digest. +// If the path is already registered for the digest, this is a no-op. +func (lc *LayerCache) Register(d digest.Digest, path string) { + lc.mu.Lock() + defer lc.mu.Unlock() + + paths := lc.layers[d] + for _, p := range paths { + if p == path { + return // already registered + } + } + lc.layers[d] = append(paths, path) +} + +// Lookup returns an existing, valid file path for the given layer digest. +// It verifies the file still exists on disk before returning it. +// Returns ("", false) if no valid path is found. 
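+// Stale entries (paths whose files were deleted from disk) are skipped but
+// not evicted here; eviction is explicit via Remove and RemoveByPrefix.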
+func (lc *LayerCache) Lookup(d digest.Digest) (string, bool) {
+	lc.mu.RLock()
+	paths := lc.layers[d]
+	lc.mu.RUnlock()
+
+	for _, p := range paths {
+		if _, err := os.Stat(p); err == nil {
+			return p, true
+		}
+	}
+	return "", false
+}
+
+// Remove removes a specific path from all digest entries.
+func (lc *LayerCache) Remove(path string) {
+	lc.mu.Lock()
+	defer lc.mu.Unlock()
+
+	for d, paths := range lc.layers {
+		// Build a fresh slice rather than compacting in place: Lookup
+		// iterates a captured slice header outside the lock, so writing
+		// into the old backing array would be a data race.
+		filtered := make([]string, 0, len(paths))
+		for _, p := range paths {
+			if p != path {
+				filtered = append(filtered, p)
+			}
+		}
+		if len(filtered) == 0 {
+			delete(lc.layers, d)
+		} else {
+			lc.layers[d] = filtered
+		}
+	}
+}
+
+// RemoveByPrefix removes all paths that have the given prefix from all digest
+// entries. This is used during volume cleanup to evict all layer references
+// under a volume directory.
+func (lc *LayerCache) RemoveByPrefix(prefix string) {
+	lc.mu.Lock()
+	defer lc.mu.Unlock()
+
+	for d, paths := range lc.layers {
+		// As in Remove, allocate instead of reusing the backing array.
+		filtered := make([]string, 0, len(paths))
+		for _, p := range paths {
+			if !strings.HasPrefix(p, prefix) {
+				filtered = append(filtered, p)
+			}
+		}
+		if len(filtered) == 0 {
+			delete(lc.layers, d)
+		} else {
+			lc.layers[d] = filtered
+		}
+	}
+}
+
+// Len returns the number of unique digests tracked.
+func (lc *LayerCache) Len() int {
+	lc.mu.RLock()
+	defer lc.mu.RUnlock()
+	return len(lc.layers)
+}
+
+// PathCount returns the total number of paths tracked across all digests.
+func (lc *LayerCache) PathCount() int {
+	lc.mu.RLock()
+	defer lc.mu.RUnlock()
+	count := 0
+	for _, paths := range lc.layers {
+		count += len(paths)
+	}
+	return count
+}
+
+// Snapshot returns a copy of the current cache state for testing/debugging.
+func (lc *LayerCache) Snapshot() map[digest.Digest][]string {
+	lc.mu.RLock()
+	defer lc.mu.RUnlock()
+	snapshot := make(map[digest.Digest][]string, len(lc.layers))
+	for d, paths := range lc.layers {
+		cp := make([]string, len(paths))
+		copy(cp, paths)
+		snapshot[d] = cp
+	}
+	return snapshot
+}
+
+// Rebuild scans existing volume directories and re-populates the cache
+// from disk. This is called on startup to restore layer cache state after
+// a daemon restart or node reboot.
+//
+// For each volume with a successful pull status, the digest mapping is
+// restored from the layer_digests.json metadata written during previous
+// pulls; files that no longer exist on disk are skipped.
+func (lc *LayerCache) Rebuild(cfg *config.RawConfig, sm *status.StatusManager) {
+	volumesDir := cfg.GetVolumesDir()
+	volumeDirs, err := os.ReadDir(volumesDir)
+	if err != nil {
+		if !os.IsNotExist(err) {
+			logger.Logger().WithError(err).Errorf("layer cache rebuild: read volumes dir %s", volumesDir)
+		}
+		return
+	}
+
+	registered := 0
+	for _, volumeDir := range volumeDirs {
+		if !volumeDir.IsDir() {
+			continue
+		}
+		volumeName := volumeDir.Name()
+
+		if isStaticVolume(volumeName) {
+			n := lc.rebuildVolume(cfg.GetVolumeDir(volumeName), sm)
+			registered += n
+		}
+
+		if isDynamicVolume(volumeName) {
+			modelsDir := cfg.GetModelsDirForDynamic(volumeName)
+			modelDirs, err := os.ReadDir(modelsDir)
+			if err != nil {
+				continue
+			}
+			for _, modelDir := range modelDirs {
+				if !modelDir.IsDir() {
+					continue
+				}
+				mountID := modelDir.Name()
+				n := lc.rebuildVolume(cfg.GetMountIDDirForDynamic(volumeName, mountID), sm)
+				registered += n
+			}
+		}
+	}
+
+	logger.Logger().Infof("layer cache rebuild complete: %d layer-path entries registered, %d unique digests",
+		registered, lc.Len())
+}
+
+// rebuildVolume scans a single volume's model directory and registers
+// any files found.
It reads the layer digest metadata stored alongside +// the volume status to map files back to their digests. +func (lc *LayerCache) rebuildVolume(volumeDir string, sm *status.StatusManager) int { + statusPath := filepath.Join(volumeDir, "status.json") + s, err := sm.Get(statusPath) + if err != nil { + return 0 + } + + // Only rebuild from successfully pulled volumes. + if s.State != status.StatePullSucceeded && s.State != status.StateMounted { + return 0 + } + + // Read the layer digest metadata file. + metadataPath := filepath.Join(volumeDir, "layer_digests.json") + metadata, err := loadLayerMetadata(metadataPath) + if err != nil { + return 0 + } + + registered := 0 + modelDir := filepath.Join(volumeDir, "model") + for _, entry := range metadata { + filePath := filepath.Join(modelDir, entry.FilePath) + if _, err := os.Stat(filePath); err != nil { + continue // file no longer exists, skip + } + d, err := digest.Parse(entry.Digest) + if err != nil { + continue + } + lc.Register(d, filePath) + registered++ + } + + return registered +} diff --git a/pkg/service/layer_cache_test.go b/pkg/service/layer_cache_test.go new file mode 100644 index 0000000..a6f4b31 --- /dev/null +++ b/pkg/service/layer_cache_test.go @@ -0,0 +1,627 @@ +package service + +import ( + "context" + "os" + "path/filepath" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/status" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/require" +) + +// ─── LayerCache basic ops ───────────────────────────────────────────────────── + +func TestLayerCache_RegisterAndLookup(t *testing.T) { + lc := NewLayerCache(8) + + d := digest.FromString("layer-content-1") + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "layer1.bin") + require.NoError(t, os.WriteFile(path, []byte("data"), 0644)) + + lc.Register(d, path) + + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Equal(t, path, found) +} + +func TestLayerCache_LookupMissing(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("nonexistent") + _, ok := lc.Lookup(d) + require.False(t, ok) +} + +func TestLayerCache_LookupStaleFile(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("stale-content") + path := filepath.Join(t.TempDir(), "deleted.bin") + + // Register a path that doesn't exist on disk. 
+ lc.Register(d, path) + + _, ok := lc.Lookup(d) + require.False(t, ok, "should not find stale path") +} + +func TestLayerCache_RegisterIdempotent(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("dup-content") + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "layer.bin") + require.NoError(t, os.WriteFile(path, []byte("data"), 0644)) + + lc.Register(d, path) + lc.Register(d, path) + lc.Register(d, path) + + require.Equal(t, 1, lc.PathCount()) +} + +func TestLayerCache_MultiplePaths(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("shared-content") + tmpDir := t.TempDir() + + path1 := filepath.Join(tmpDir, "vol1", "layer.bin") + path2 := filepath.Join(tmpDir, "vol2", "layer.bin") + require.NoError(t, os.MkdirAll(filepath.Dir(path1), 0755)) + require.NoError(t, os.MkdirAll(filepath.Dir(path2), 0755)) + require.NoError(t, os.WriteFile(path1, []byte("data"), 0644)) + require.NoError(t, os.WriteFile(path2, []byte("data"), 0644)) + + lc.Register(d, path1) + lc.Register(d, path2) + + require.Equal(t, 2, lc.PathCount()) + require.Equal(t, 1, lc.Len()) + + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Contains(t, []string{path1, path2}, found) +} + +// ─── LayerCache Remove ──────────────────────────────────────────────────────── + +func TestLayerCache_Remove(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("remove-test") + tmpDir := t.TempDir() + + path1 := filepath.Join(tmpDir, "layer1.bin") + path2 := filepath.Join(tmpDir, "layer2.bin") + require.NoError(t, os.WriteFile(path1, []byte("data"), 0644)) + require.NoError(t, os.WriteFile(path2, []byte("data"), 0644)) + + lc.Register(d, path1) + lc.Register(d, path2) + + lc.Remove(path1) + require.Equal(t, 1, lc.PathCount()) + + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Equal(t, path2, found) +} + +func TestLayerCache_RemoveLastPath(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("remove-last") + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "only.bin") + require.NoError(t, os.WriteFile(path, []byte("data"), 0644)) + + lc.Register(d, path) + lc.Remove(path) + + require.Equal(t, 0, lc.Len()) + require.Equal(t, 0, lc.PathCount()) +} + +// ─── LayerCache RemoveByPrefix ──────────────────────────────────────────────── + +func TestLayerCache_RemoveByPrefix(t *testing.T) { + lc := NewLayerCache(8) + tmpDir := t.TempDir() + + // Create two volumes with layers. + d1 := digest.FromString("layer-a") + d2 := digest.FromString("layer-b") + d3 := digest.FromString("layer-c") + + vol1Path := filepath.Join(tmpDir, "vol1", "model", "a.bin") + vol1PathB := filepath.Join(tmpDir, "vol1", "model", "b.bin") + vol2Path := filepath.Join(tmpDir, "vol2", "model", "c.bin") + + for _, p := range []string{vol1Path, vol1PathB, vol2Path} { + require.NoError(t, os.MkdirAll(filepath.Dir(p), 0755)) + require.NoError(t, os.WriteFile(p, []byte("data"), 0644)) + } + + lc.Register(d1, vol1Path) + lc.Register(d2, vol1PathB) + lc.Register(d3, vol2Path) + + require.Equal(t, 3, lc.Len()) + + // Remove all entries under vol1. + lc.RemoveByPrefix(filepath.Join(tmpDir, "vol1")) + + require.Equal(t, 1, lc.Len()) + require.Equal(t, 1, lc.PathCount()) + + // vol2 entry should still be there. 
+ found, ok := lc.Lookup(d3) + require.True(t, ok) + require.Equal(t, vol2Path, found) +} + +// ─── LayerCache Snapshot ────────────────────────────────────────────────────── + +func TestLayerCache_Snapshot(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("snap-content") + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "snap.bin") + require.NoError(t, os.WriteFile(path, []byte("data"), 0644)) + + lc.Register(d, path) + + snap := lc.Snapshot() + require.Len(t, snap, 1) + require.Equal(t, []string{path}, snap[d]) + + // Modifying snapshot should not affect the cache. + snap[d] = append(snap[d], "extra") + require.Equal(t, 1, lc.PathCount()) +} + +// ─── LayerCache Concurrency ─────────────────────────────────────────────────── + +func TestLayerCache_ConcurrentRegisterAndLookup(t *testing.T) { + lc := NewLayerCache(8) + tmpDir := t.TempDir() + + const n = 50 + digests := make([]digest.Digest, n) + paths := make([]string, n) + for i := 0; i < n; i++ { + digests[i] = digest.FromString(string(rune(i))) + paths[i] = filepath.Join(tmpDir, string(rune('a'+i))) + require.NoError(t, os.WriteFile(paths[i], []byte("data"), 0644)) + } + + var wg sync.WaitGroup + // Concurrently register. + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + lc.Register(digests[i], paths[i]) + }(i) + } + wg.Wait() + + // Concurrently lookup. + var found atomic.Int32 + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + if _, ok := lc.Lookup(digests[i]); ok { + found.Add(1) + } + }(i) + } + wg.Wait() + + require.Equal(t, int32(n), found.Load()) +} + +func TestLayerCache_ConcurrentSameLayerRegister(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("same-layer") + tmpDir := t.TempDir() + + const n = 20 + paths := make([]string, n) + for i := 0; i < n; i++ { + paths[i] = filepath.Join(tmpDir, string(rune('a'+i))) + require.NoError(t, os.WriteFile(paths[i], []byte("data"), 0644)) + } + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + lc.Register(d, paths[i]) + }(i) + } + wg.Wait() + + // All n paths should be registered under one digest. + require.Equal(t, 1, lc.Len()) + require.Equal(t, n, lc.PathCount()) +} + +// ─── LayerCache Rebuild ─────────────────────────────────────────────────────── + +func TestLayerCache_Rebuild(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + // Set up a static volume with a successful pull. + volumeName := "pvc-rebuild-test" + volumeDir := cfg.Get().GetVolumeDir(volumeName) + modelDir := cfg.Get().GetModelDir(volumeName) + require.NoError(t, os.MkdirAll(modelDir, 0755)) + + // Write status. + statusPath := filepath.Join(volumeDir, "status.json") + _, err = sm.Set(statusPath, status.Status{ + VolumeName: volumeName, + Reference: "registry/model:v1", + State: status.StatePullSucceeded, + }) + require.NoError(t, err) + + // Write a model file. + layerFile := filepath.Join(modelDir, "weights.bin") + require.NoError(t, os.WriteFile(layerFile, []byte("layer-data"), 0644)) + + // Write layer metadata. + d := digest.FromString("weights-content") + metadataPath := filepath.Join(volumeDir, "layer_digests.json") + require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{ + {Digest: d.String(), FilePath: "weights.bin", Size: 10}, + })) + + // Rebuild. 
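+	// A fresh cache plus Rebuild should recover the digest-to-path mapping
+	// purely from status.json and layer_digests.json on disk.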
+ lc := NewLayerCache(8) + lc.Rebuild(cfg.Get(), sm) + + require.Equal(t, 1, lc.Len()) + + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Equal(t, layerFile, found) +} + +func TestLayerCache_RebuildSkipsFailedPull(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + volumeName := "pvc-failed-test" + volumeDir := cfg.Get().GetVolumeDir(volumeName) + modelDir := cfg.Get().GetModelDir(volumeName) + require.NoError(t, os.MkdirAll(modelDir, 0755)) + + statusPath := filepath.Join(volumeDir, "status.json") + _, err = sm.Set(statusPath, status.Status{ + VolumeName: volumeName, + Reference: "registry/model:v1", + State: status.StatePullFailed, + }) + require.NoError(t, err) + + lc := NewLayerCache(8) + lc.Rebuild(cfg.Get(), sm) + + require.Equal(t, 0, lc.Len(), "should not rebuild from failed pull") +} + +func TestLayerCache_RebuildDynamicVolume(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + volumeName := "csi-dyn-rebuild" + mountID := "mount-1" + mountDir := cfg.Get().GetMountIDDirForDynamic(volumeName, mountID) + modelDir := cfg.Get().GetModelDirForDynamic(volumeName, mountID) + require.NoError(t, os.MkdirAll(modelDir, 0755)) + + statusPath := filepath.Join(mountDir, "status.json") + _, err = sm.Set(statusPath, status.Status{ + Reference: "registry/model:dyn", + State: status.StateMounted, + }) + require.NoError(t, err) + + layerFile := filepath.Join(modelDir, "model.bin") + require.NoError(t, os.WriteFile(layerFile, []byte("dyn-data"), 0644)) + + d := digest.FromString("dyn-content") + metadataPath := filepath.Join(mountDir, "layer_digests.json") + require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{ + {Digest: d.String(), FilePath: "model.bin", Size: 8}, + })) + + lc := NewLayerCache(8) + lc.Rebuild(cfg.Get(), sm) + + require.Equal(t, 1, lc.Len()) + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Equal(t, layerFile, found) +} + +func TestLayerCache_RebuildSkipsMissingFile(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + volumeName := "pvc-missing-file" + volumeDir := cfg.Get().GetVolumeDir(volumeName) + modelDir := cfg.Get().GetModelDir(volumeName) + require.NoError(t, os.MkdirAll(modelDir, 0755)) + + statusPath := filepath.Join(volumeDir, "status.json") + _, err = sm.Set(statusPath, status.Status{ + VolumeName: volumeName, + Reference: "registry/model:v1", + State: status.StatePullSucceeded, + }) + require.NoError(t, err) + + d := digest.FromString("missing-file-content") + metadataPath := filepath.Join(volumeDir, "layer_digests.json") + require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{ + {Digest: d.String(), FilePath: "does_not_exist.bin", Size: 10}, + })) + + lc := NewLayerCache(8) + lc.Rebuild(cfg.Get(), sm) + + require.Equal(t, 0, lc.Len(), "should skip missing files during rebuild") +} + +// ─── LayerCache Cleanup Integration ─────────────────────────────────────────── + +func TestLayerCache_CleanupOnDeleteModel(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := 
config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + lc := NewLayerCache(8) + + worker, err := NewWorkerWithLayerCache(cfg, sm, lc) + require.NoError(t, err) + + // Simulate a pulled volume with cached layers. + volumeName := "pvc-cleanup-test" + modelDir := cfg.Get().GetModelDir(volumeName) + require.NoError(t, os.MkdirAll(modelDir, 0755)) + + d := digest.FromString("cleanup-content") + layerFile := filepath.Join(modelDir, "weights.bin") + require.NoError(t, os.WriteFile(layerFile, []byte("data"), 0644)) + lc.Register(d, layerFile) + + require.Equal(t, 1, lc.Len()) + + // Delete the volume. + err = worker.DeleteModel(context.Background(), true, volumeName, "") + require.NoError(t, err) + + // Layer cache should be cleaned. + require.Equal(t, 0, lc.Len(), "layer cache should be empty after delete") +} + +// ─── Hardlink fallback tests ────────────────────────────────────────────────── + +func TestLayerCache_HardlinkFallbackOnMissingSource(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("hardlink-source-missing") + + // Register a path that won't exist on disk. + lc.Register(d, "/nonexistent/path/layer.bin") + + // Lookup should fail (source doesn't exist). + _, ok := lc.Lookup(d) + require.False(t, ok, "should not find source that doesn't exist") +} + +func TestLayerCache_HardlinkFallbackUsesSecondPath(t *testing.T) { + lc := NewLayerCache(8) + d := digest.FromString("multi-path-fallback") + tmpDir := t.TempDir() + + // First path doesn't exist, second does. + path1 := filepath.Join(tmpDir, "deleted.bin") + path2 := filepath.Join(tmpDir, "exists.bin") + require.NoError(t, os.WriteFile(path2, []byte("data"), 0644)) + + lc.Register(d, path1) + lc.Register(d, path2) + + found, ok := lc.Lookup(d) + require.True(t, ok) + require.Equal(t, path2, found) +} + +// ─── Layer metadata persistence ─────────────────────────────────────────────── + +func TestLayerMetadata_SaveAndLoad(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "layer_digests.json") + + entries := []LayerMetadataEntry{ + {Digest: "sha256:abc", FilePath: "weights.bin", Size: 100}, + {Digest: "sha256:def", FilePath: "config.json", Size: 50}, + } + + require.NoError(t, saveLayerMetadata(path, entries)) + + loaded, err := loadLayerMetadata(path) + require.NoError(t, err) + require.Equal(t, entries, loaded) +} + +func TestLayerMetadata_LoadMissing(t *testing.T) { + _, err := loadLayerMetadata("/nonexistent/path") + require.Error(t, err) +} + +// ─── Semaphore / flow control ───────────────────────────────────────────────── + +func TestLayerCache_SemaphoreBasic(t *testing.T) { + lc := NewLayerCache(2) + sem := lc.Semaphore() + require.NotNil(t, sem) + + // Should be able to acquire 2 slots. + require.True(t, sem.TryAcquire(1)) + require.True(t, sem.TryAcquire(1)) + // Third should fail (limit is 2). + require.False(t, sem.TryAcquire(1)) + + sem.Release(1) + require.True(t, sem.TryAcquire(1)) +} + +func TestLayerCache_DefaultConcurrency(t *testing.T) { + lc := NewLayerCache(0) // Should use default. + sem := lc.Semaphore() + require.NotNil(t, sem) + + // Default is 8, should be able to acquire 8. 
+ for i := 0; i < 8; i++ { + require.True(t, sem.TryAcquire(1)) + } + require.False(t, sem.TryAcquire(1)) +} + +// ─── Worker with LayerCache integration ─────────────────────────────────────── + +func TestNewWorkerWithLayerCache(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + lc := NewLayerCache(8) + + worker, err := NewWorkerWithLayerCache(cfg, sm, lc) + require.NoError(t, err) + require.NotNil(t, worker) + require.NotNil(t, worker.layerCache) +} + +func TestNewWorkerWithNilLayerCache(t *testing.T) { + tmpDir := t.TempDir() + rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir} + cfg := config.NewWithRaw(rawCfg) + sm, err := status.NewStatusManager() + require.NoError(t, err) + + // Should use the standard puller when lc is nil. + worker, err := NewWorkerWithLayerCache(cfg, sm, nil) + require.NoError(t, err) + require.NotNil(t, worker) + require.Nil(t, worker.layerCache) +} + +// ─── Concurrent dedup via singleflight ──────────────────────────────────────── + +func TestLayerCache_SingleflightDedup(t *testing.T) { + lc := NewLayerCache(8) + sfg := lc.SflightGroup() + + var callCount atomic.Int32 + + d := digest.FromString("singleflight-test") + key := d.String() + + const n = 10 + var started atomic.Int32 + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + started.Add(1) + // Wait until all goroutines have started. + for started.Load() < n { + runtime.Gosched() + } + _, _, _ = sfg.Do(key, func() (interface{}, error) { + callCount.Add(1) + // Sleep to keep the function in-flight while other goroutines arrive. + time.Sleep(50 * time.Millisecond) + return nil, nil + }) + }() + } + + wg.Wait() + + // Due to singleflight, the function should be called exactly once + // because all goroutines overlap on the same key while blocked. 
+ require.Equal(t, int32(1), callCount.Load(), + "singleflight should deduplicate concurrent calls") +} + +// ─── getLayerFilePath ───────────────────────────────────────────────────────── + +func TestGetLayerFilePath_NoAnnotations(t *testing.T) { + desc := createTestDescriptor("sha256:abc123", nil) + require.Equal(t, "", getLayerFilePath(desc)) +} + +func TestGetLayerFilePath_CurrentSpec(t *testing.T) { + desc := createTestDescriptor("sha256:abc123", map[string]string{ + "org.cncf.model.filepath": "weights/model.safetensors", + }) + require.Equal(t, "weights/model.safetensors", getLayerFilePath(desc)) +} + +func TestGetLayerFilePath_LegacySpec(t *testing.T) { + desc := createTestDescriptor("sha256:abc123", map[string]string{ + "org.cnai.model.filepath": "config.json", + }) + require.Equal(t, "config.json", getLayerFilePath(desc)) +} + +func TestGetLayerFilePath_PrefersCurrentOverLegacy(t *testing.T) { + desc := createTestDescriptor("sha256:abc123", map[string]string{ + "org.cncf.model.filepath": "current.bin", + "org.cnai.model.filepath": "legacy.bin", + }) + require.Equal(t, "current.bin", getLayerFilePath(desc)) +} + +// ─── Test helpers ───────────────────────────────────────────────────────────── + +func createTestDescriptor(digestStr string, annotations map[string]string) ocispec.Descriptor { + d, _ := digest.Parse(digestStr) + return ocispec.Descriptor{ + Digest: d, + Size: 100, + Annotations: annotations, + } +} diff --git a/pkg/service/layer_metadata.go b/pkg/service/layer_metadata.go new file mode 100644 index 0000000..040924f --- /dev/null +++ b/pkg/service/layer_metadata.go @@ -0,0 +1,46 @@ +package service + +import ( + "encoding/json" + "os" + + "github.com/pkg/errors" +) + +// LayerMetadataEntry stores the digest and file path for a single layer, +// persisted alongside the volume so the LayerCache can be rebuilt on restart. +type LayerMetadataEntry struct { + Digest string `json:"digest"` + FilePath string `json:"file_path"` + Size int64 `json:"size"` +} + +// saveLayerMetadata writes the layer metadata to the given path. +func saveLayerMetadata(path string, entries []LayerMetadataEntry) error { + data, err := json.MarshalIndent(entries, "", " ") + if err != nil { + return errors.Wrap(err, "marshal layer metadata") + } + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0644); err != nil { + return errors.Wrap(err, "write temp layer metadata") + } + if err := os.Rename(tmpPath, path); err != nil { + _ = os.Remove(tmpPath) // clean up on failure + return errors.Wrap(err, "rename layer metadata") + } + return nil +} + +// loadLayerMetadata reads the layer metadata from the given path. +func loadLayerMetadata(path string) ([]LayerMetadataEntry, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, errors.Wrap(err, "read layer metadata") + } + var entries []LayerMetadataEntry + if err := json.Unmarshal(data, &entries); err != nil { + return nil, errors.Wrap(err, "unmarshal layer metadata") + } + return entries, nil +} From 9313b6df3853b784dd06504b4e7556ace8d4a91d Mon Sep 17 00:00:00 2001 From: Rishi Jat Date: Sun, 10 May 2026 05:11:11 +0530 Subject: [PATCH 2/4] feat(status): add silent layer caching to Hook for metric deduplication Adds the LayerCached method to status.Hook. This allows wait-paths and cache hits to increment their individual pod UI progress counters correctly to 100% without emitting duplicated Prometheus node metrics or generating superfluous network completion logs. 
Signed-off-by: Rishi Jat
---
 pkg/status/hook.go | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/pkg/status/hook.go b/pkg/status/hook.go
index 9ff40ff..f56de75 100644
--- a/pkg/status/hook.go
+++ b/pkg/status/hook.go
@@ -171,6 +171,41 @@ func (h *Hook) AfterPullLayer(desc ocispec.Descriptor, err error) {
 	progress.Span.End()
 }
 
+// LayerCached records a layer as successfully pulled without emitting
+// duplicate network metrics or "pulled layer" logs. This ensures status totals
+// remain accurate for waiting callers without spamming observability.
+func (h *Hook) LayerCached(desc ocispec.Descriptor, manifest ocispec.Manifest) {
+	h.mutex.Lock()
+	defer h.mutex.Unlock()
+
+	filePath := ""
+	if desc.Annotations != nil {
+		if desc.Annotations[modelspec.AnnotationFilepath] != "" {
+			filePath = fmt.Sprintf("/%s", desc.Annotations[modelspec.AnnotationFilepath])
+		} else if desc.Annotations[oldModelspec.AnnotationFilepath] != "" {
+			filePath = fmt.Sprintf("/%s", desc.Annotations[oldModelspec.AnnotationFilepath])
+		}
+	}
+
+	h.manifest = &manifest
+
+	_, span := tracing.Tracer.Start(h.ctx, "PullLayerCached")
+	span.End()
+
+	now := time.Now()
+	h.progress[desc.Digest] = &ProgressItem{
+		Digest:     desc.Digest,
+		Path:       filePath,
+		Size:       desc.Size,
+		StartedAt:  now,
+		FinishedAt: &now,
+		Error:      nil,
+		Span:       span,
+	}
+
+	h.pulled.Add(1)
+}
+
 func (h *Hook) getProgress() Progress {
 	items := []ProgressItem{}
 	for _, item := range h.progress {

From 21b23bdafb11fe4a1881e696a2a1f5a382fb0dee Mon Sep 17 00:00:00 2001
From: Rishi Jat
Date: Sun, 10 May 2026 05:11:22 +0530
Subject: [PATCH 3/4] feat(service): implement manual layer-aware singleflight orchestration

Replaces the high-level b.Fetch() wrapper with a native pull pipeline:

- Parses references and fetches layer descriptors directly via the
  modctl remote client.
- Hardlinks cached layers instead of re-downloading them.
- Wraps remote fetches in singleflight and a node-level weighted
  semaphore.
- Handles TOCTOU hardlink failures, cross-device (EXDEV) link errors,
  and transient network failures with bounded retries (three attempts,
  fixed one-second delay).
- Exercised under concurrent load by a new httptest-based registry
  testbed.
Signed-off-by: Rishi Jat --- pkg/service/layer_aware_puller_test.go | 186 ++++++++++ pkg/service/puller.go | 455 +++++++++++++++++++++++++ 2 files changed, 641 insertions(+) create mode 100644 pkg/service/layer_aware_puller_test.go diff --git a/pkg/service/layer_aware_puller_test.go b/pkg/service/layer_aware_puller_test.go new file mode 100644 index 0000000..d7f1230 --- /dev/null +++ b/pkg/service/layer_aware_puller_test.go @@ -0,0 +1,186 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/modelpack/modctl/pkg/backend" + pkgcodec "github.com/modelpack/modctl/pkg/codec" + "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/config/auth" + "github.com/modelpack/model-csi-driver/pkg/status" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/require" +) + +func TestLayerAwarePuller_ConcurrencyDedup(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "http"}, nil + }) + + d1 := digest.FromString("layer1") + layer1 := ocispec.Descriptor{ + Digest: d1, + Size: 9, + MediaType: "application/vnd.oci.image.layer.v1.tar", + Annotations: map[string]string{ + "org.cncf.model.filepath": "layer1.tar", + }, + } + manifest := ocispec.Manifest{Layers: []ocispec.Descriptor{layer1}} + manifestBytes, _ := json.Marshal(manifest) + + var fetchCalls atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v2/" { + w.WriteHeader(http.StatusOK) + return + } + if r.URL.Path == "/v2/test/repo/manifests/latest" { + w.Header().Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + _, _ = w.Write(manifestBytes) + return + } + if r.URL.Path == "/v2/test/repo/blobs/"+d1.String() { + fetchCalls.Add(1) + time.Sleep(50 * time.Millisecond) // Simulate network IO + _, _ = w.Write([]byte("dummydata")) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + // Strip http:// prefix + registryHost := ts.URL[7:] + refStr := registryHost + "/test/repo:latest" + + patches.ApplyFunc(pkgcodec.New, func(string) (pkgcodec.Codec, error) { + return &dummyCodec{}, nil + }) + + lc := NewLayerCache(2) + puller := &layerAwarePuller{ + pullCfg: &config.PullConfig{Concurrency: 4}, + hook: status.NewHook(context.Background()), + layerCache: lc, + } + + b, _ := backend.New("") + artifact := &backend.InspectedModelArtifact{} + + const numPods = 10 + var wg sync.WaitGroup + + for i := 0; i < numPods; i++ { + wg.Add(1) + go func() { + defer wg.Done() + targetDir := filepath.Join(t.TempDir(), "model") + err := puller.layerAwarePull(context.Background(), b, artifact, refStr, targetDir, true) + require.NoError(t, err) + }() + } + + wg.Wait() + + require.Equal(t, int32(1), fetchCalls.Load(), "only one remote.Fetch should execute per digest") +} + +func TestLayerAwarePuller_SemaphoreCancelNoLeak(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "http"}, nil + }) + + d1 := digest.FromString("layer1") + layer1 := ocispec.Descriptor{ + Digest: d1, + Size: 
9, + MediaType: "application/vnd.oci.image.layer.v1.tar", + Annotations: map[string]string{ + "org.cncf.model.filepath": "layer1.tar", + }, + } + manifestBytes, _ := json.Marshal(ocispec.Manifest{Layers: []ocispec.Descriptor{layer1}}) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v2/" { + w.WriteHeader(http.StatusOK) + return + } + if r.URL.Path == "/v2/test/repo/manifests/latest" { + w.Header().Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + _, _ = w.Write(manifestBytes) + return + } + if r.URL.Path == "/v2/test/repo/blobs/"+d1.String() { + // Simulate a hung connection or slow download + time.Sleep(2 * time.Second) + _, _ = w.Write([]byte("dummydata")) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + registryHost := ts.URL[7:] + refStr := registryHost + "/test/repo:latest" + + patches.ApplyFunc(pkgcodec.New, func(string) (pkgcodec.Codec, error) { + return &dummyCodec{}, nil + }) + + lc := NewLayerCache(2) + puller := &layerAwarePuller{ + pullCfg: &config.PullConfig{Concurrency: 4}, + hook: status.NewHook(context.Background()), + layerCache: lc, + } + + b, _ := backend.New("") + artifact := &backend.InspectedModelArtifact{} + targetDir := filepath.Join(t.TempDir(), "model") + + // Create a context that will cancel quickly. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := puller.layerAwarePull(ctx, b, artifact, refStr, targetDir, true) + require.Error(t, err) + + // Verify semaphore is fully released. + require.True(t, lc.Semaphore().TryAcquire(2), "semaphore should have 2 permits fully restored after failure") +} + +type dummyCodec struct{} + +func (d *dummyCodec) Type() string { return "dummy" } +func (d *dummyCodec) Encode(string, string) (io.Reader, error) { return nil, nil } +func (d *dummyCodec) Decode(outputDir, fp string, reader io.Reader, desc ocispec.Descriptor) error { + fullPath := filepath.Join(outputDir, fp) + _ = os.MkdirAll(filepath.Dir(fullPath), 0755) + + // Must consume reader to prevent connection aborts + _, _ = io.Copy(io.Discard, reader) + + _ = os.WriteFile(fullPath, []byte("data"), 0644) + return nil +} diff --git a/pkg/service/puller.go b/pkg/service/puller.go index 3e23f30..753e4cb 100644 --- a/pkg/service/puller.go +++ b/pkg/service/puller.go @@ -2,21 +2,31 @@ package service import ( "context" + "encoding/json" "io" "os" + "path/filepath" "strings" + "time" + oldModelspec "github.com/dragonflyoss/model-spec/specs-go/v1" "github.com/modelpack/modctl/pkg/backend" + "github.com/modelpack/modctl/pkg/backend/remote" + pkgcodec "github.com/modelpack/modctl/pkg/codec" modctlConfig "github.com/modelpack/modctl/pkg/config" "github.com/modelpack/model-csi-driver/pkg/config" "github.com/modelpack/model-csi-driver/pkg/config/auth" "github.com/modelpack/model-csi-driver/pkg/logger" "github.com/modelpack/model-csi-driver/pkg/status" + modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) type PullHook interface { + LayerCached(desc ocispec.Descriptor, manifest ocispec.Manifest) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) AfterPullLayer(desc ocispec.Descriptor, err error) } @@ -33,6 +43,18 @@ var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *stat } } +// NewLayerAwarePuller creates a puller that leverages 
the LayerCache for +// layer-level deduplication. It resolves the manifest ahead of time, +// hardlinks cached layers, and only pulls missing layers. +var NewLayerAwarePuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker, lc *LayerCache) Puller { + return &layerAwarePuller{ + pullCfg: pullCfg, + hook: hook, + diskQuotaChecker: diskQuotaChecker, + layerCache: lc, + } +} + type puller struct { pullCfg *config.PullConfig hook *status.Hook @@ -119,3 +141,436 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeM return nil } + +// layerAwarePuller wraps the pull flow with layer-level deduplication. +// It resolves the manifest first, checks the LayerCache for each layer, +// hardlinks cached layers, and only pulls missing ones via the Fetch path. +type layerAwarePuller struct { + pullCfg *config.PullConfig + hook *status.Hook + diskQuotaChecker *DiskQuotaChecker + layerCache *LayerCache +} + +func (p *layerAwarePuller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error { + keyChain, err := auth.GetKeyChainByRef(reference) + if err != nil { + return errors.Wrapf(err, "get auth for model: %s", reference) + } + plainHTTP := keyChain.ServerScheme == "http" + + b, err := backend.New("") + if err != nil { + return errors.Wrap(err, "create modctl backend") + } + + modelArtifact := NewModelArtifact(b, reference, plainHTTP) + + if p.diskQuotaChecker != nil { + if err := p.diskQuotaChecker.Check(ctx, modelArtifact, excludeModelWeights, excludeFilePatterns); err != nil { + return errors.Wrap(err, "check disk quota") + } + } + + if err := os.MkdirAll(targetDir, 0755); err != nil { + return errors.Wrapf(err, "create model dir: %s", targetDir) + } + + // If filtering is active, fall back to the standard puller (no dedup + // for partial pulls since layer content is subset-based). + if excludeModelWeights || len(excludeFilePatterns) > 0 { + return p.pullWithFiltering(ctx, b, modelArtifact, reference, targetDir, plainHTTP, excludeModelWeights, excludeFilePatterns) + } + + // Inspect the model to get layer descriptors. + artifact, err := modelArtifact.Inspect(ctx, reference) + if err != nil { + logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: failed to inspect model, falling back to standard pull") + return p.standardPull(ctx, b, reference, targetDir, plainHTTP) + } + + return p.layerAwarePull(ctx, b, artifact, reference, targetDir, plainHTTP) +} + +// layerAwarePull performs the deduplication-aware pull flow. +func (p *layerAwarePuller) layerAwarePull( + ctx context.Context, + b backend.Backend, + artifact *backend.InspectedModelArtifact, + reference, targetDir string, + plainHTTP bool, +) error { + if p.layerCache == nil { + return p.standardPull(ctx, b, reference, targetDir, plainHTTP) + } + + // 1. Resolve remote manifest directly to get full descriptors. 
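+	// Resolving the manifest ourselves (rather than going through b.Pull)
+	// exposes every layer's digest, size, and filepath annotation before any
+	// blob is downloaded, which is what makes per-layer cache lookups possible.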
+	ref, err := backend.ParseReference(reference)
+	if err != nil {
+		return errors.Wrap(err, "parse reference")
+	}
+
+	client, err := remote.New(ref.Repository(), remote.WithPlainHTTP(plainHTTP), remote.WithInsecure(true))
+	if err != nil {
+		return errors.Wrap(err, "create remote client")
+	}
+
+	_, manifestReader, err := client.Manifests().FetchReference(ctx, ref.Tag())
+	if err != nil {
+		return errors.Wrap(err, "fetch manifest")
+	}
+	defer func() { _ = manifestReader.Close() }()
+
+	var manifest ocispec.Manifest
+	if err := json.NewDecoder(manifestReader).Decode(&manifest); err != nil {
+		return errors.Wrap(err, "decode manifest")
+	}
+
+	// 2. Classify layers as cached or uncached.
+	var uncachedLayers []ocispec.Descriptor
+	var cachedLayers []cachedLayerInfo
+	var metadataEntries []LayerMetadataEntry
+
+	for _, layer := range manifest.Layers {
+		fp := getLayerFilePath(layer)
+		if fp == "" {
+			continue // skip layers without filepaths
+		}
+
+		metadataEntries = append(metadataEntries, LayerMetadataEntry{
+			Digest:   layer.Digest.String(),
+			FilePath: fp,
+			Size:     layer.Size,
+		})
+
+		sourcePath, found := p.layerCache.Lookup(layer.Digest)
+		if found {
+			cachedLayers = append(cachedLayers, cachedLayerInfo{
+				desc:       layer,
+				sourcePath: sourcePath,
+				filePath:   fp,
+			})
+		} else {
+			uncachedLayers = append(uncachedLayers, layer)
+		}
+	}
+
+	logger.WithContext(ctx).Infof(
+		"layer-aware pull: %s — %d cached, %d to pull (total %d layers)",
+		reference, len(cachedLayers), len(uncachedLayers), len(manifest.Layers),
+	)
+
+	// Count only layers with filepath annotations; layers skipped above never
+	// emit progress events, so they must not inflate the total.
+	p.hook.SetTotal(len(metadataEntries))
+
+	// 3. Hardlink cached layers.
+	for _, cl := range cachedLayers {
+		destPath := filepath.Join(targetDir, cl.filePath)
+		if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
+			logger.WithContext(ctx).WithError(err).Warnf(
+				"layer-aware pull: failed to create dir for hardlink %s, will pull instead", cl.filePath,
+			)
+			uncachedLayers = append(uncachedLayers, cl.desc)
+			continue
+		}
+
+		if err := os.Link(cl.sourcePath, destPath); err != nil {
+			logger.WithContext(ctx).WithError(err).Warnf(
+				"layer-aware pull: hardlink failed for %s (src=%s), will pull instead",
+				cl.filePath, cl.sourcePath,
+			)
+			uncachedLayers = append(uncachedLayers, cl.desc)
+			continue
+		}
+
+		logger.WithContext(ctx).Infof(
+			"layer-aware pull: reused layer %s via hardlink (%s → %s)",
+			cl.desc.Digest, cl.sourcePath, destPath,
+		)
+
+		// Immediately register the new path so others can use it.
+		p.layerCache.Register(cl.desc.Digest, destPath)
+		// Record progress silently for cached layers.
+		p.hook.LayerCached(cl.desc, manifest)
+	}
+
+	// 4. Pull uncached layers concurrently, synchronized by singleflight.
+	if len(uncachedLayers) > 0 {
+		g, gctx := errgroup.WithContext(ctx)
+		g.SetLimit(int(p.pullCfg.Concurrency))
+		sfg := p.layerCache.SflightGroup()
+		sem := p.layerCache.Semaphore()
+
+		for _, layerDesc := range uncachedLayers {
+			desc := layerDesc // capture for goroutine
+			g.Go(func() error {
+				for attempt := 0; attempt < 3; attempt++ {
+					select {
+					case <-gctx.Done():
+						return gctx.Err()
+					default:
+					}
+
+					// The singleflight key is the digest, so at most one download
+					// per digest is in flight across all concurrent pulls on this node.
+					_, sfgErr, _ := sfg.Do(desc.Digest.String(), func() (interface{}, error) {
+						// Apply node-level flow control around network operations.
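+						// The permit is held for the whole fetch-and-decode and
+						// released via defer, so neither errors nor cancellation
+						// can leak a slot.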
+						if sem != nil {
+							if err := sem.Acquire(gctx, 1); err != nil {
+								return nil, err
+							}
+							defer sem.Release(1)
+						}
+
+						p.hook.BeforePullLayer(desc, manifest)
+
+						// Defer the after hook to guarantee it runs on panic or error.
+						var pullErr error
+						defer func() {
+							p.hook.AfterPullLayer(desc, pullErr)
+						}()
+
+						pullErr = func() error {
+							// Open network stream to registry.
+							reader, err := client.Fetch(gctx, desc)
+							if err != nil {
+								return errors.Wrap(err, "fetch blob from remote")
+							}
+							defer func() { _ = reader.Close() }()
+
+							// Create codec to decode stream to disk.
+							codec, err := pkgcodec.New(pkgcodec.TypeFromMediaType(desc.MediaType))
+							if err != nil {
+								return errors.Wrapf(err, "create codec for media type %s", desc.MediaType)
+							}
+
+							fp := getLayerFilePath(desc)
+							if err := codec.Decode(targetDir, fp, reader, desc); err != nil {
+								// Check if another concurrent process (outside our driver) wrote it.
+								if errors.Is(err, pkgcodec.ErrAlreadyUpToDate) {
+									return nil
+								}
+								return errors.Wrap(err, "decode layer")
+							}
+
+							// Registration only occurs on successful completion.
+							fullPath := filepath.Join(targetDir, fp)
+							if _, statErr := os.Stat(fullPath); statErr == nil {
+								p.layerCache.Register(desc.Digest, fullPath)
+							}
+
+							return nil
+						}()
+
+						return nil, pullErr
+					})
+
+					// If the singleflight function returned an error, retry the loop.
+					if sfgErr != nil {
+						logger.WithContext(ctx).WithError(sfgErr).Warnf("layer-aware pull: network fetch failed for %s, retrying", desc.Digest)
+						time.Sleep(1 * time.Second)
+						continue
+					}
+
+					// If another goroutine downloaded the layer via singleflight, we
+					// still need to hardlink it into our own target directory, since
+					// our download logic was bypassed.
+					fp := getLayerFilePath(desc)
+					destPath := filepath.Join(targetDir, fp)
+
+					// If the file is already at our destination, we were the downloader.
+					if _, statErr := os.Stat(destPath); statErr == nil {
+						return nil
+					}
+
+					// We were a waiting caller: hardlink from the newly cached location.
+					sourcePath, found := p.layerCache.Lookup(desc.Digest)
+					if !found {
+						logger.WithContext(ctx).Warnf("layer-aware pull: singleflight completed but digest %s not in cache, retrying", desc.Digest)
+						time.Sleep(1 * time.Second)
+						continue
+					}
+
+					if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
+						return err
+					}
+					if err := os.Link(sourcePath, destPath); err != nil {
+						// Hardlink failed (EXDEV or the source was deleted). Drop the
+						// exact stale entry and retry the fetch.
+						logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: hardlink after singleflight failed for %s, retrying", desc.Digest)
+						p.layerCache.Remove(sourcePath)
+						time.Sleep(1 * time.Second)
+						continue
+					}
+					p.layerCache.Register(desc.Digest, destPath)
+
+					// Record progress silently for waiting pods.
+					p.hook.LayerCached(desc, manifest)
+
+					return nil
+				}
+				return errors.Errorf("layer-aware pull: failed to pull layer %s after 3 attempts", desc.Digest)
+			})
+		}
+
+		if err := g.Wait(); err != nil {
+			logger.WithContext(ctx).WithError(err).Errorf("layer-aware pull: failed to fetch uncached layers for %s", reference)
+			return err
+		}
+	}
+
+	// 5. Save layer metadata for restart rebuild.
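+	// The file lands next to the model directory (<volumeDir>/layer_digests.json),
+	// which is where rebuildVolume looks for it on startup; a write failure is
+	// non-fatal and only disables rebuild for this volume.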
+ if len(metadataEntries) > 0 { + metadataPath := filepath.Join(filepath.Dir(targetDir), "layer_digests.json") + if err := saveLayerMetadata(metadataPath, metadataEntries); err != nil { + logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: failed to save layer metadata") + } + } + + return nil +} + +// standardPull falls back to the normal full pull path. +func (p *layerAwarePuller) standardPull(ctx context.Context, b backend.Backend, reference, targetDir string, plainHTTP bool) error { + trackingHook := &layerTrackingHook{ + inner: p.hook, + cache: p.layerCache, + targetDir: targetDir, + sem: p.layerCache.Semaphore(), + } + + pullConfig := modctlConfig.NewPull() + pullConfig.Concurrency = int(p.pullCfg.Concurrency) + pullConfig.PlainHTTP = plainHTTP + pullConfig.Proxy = p.pullCfg.ProxyURL + pullConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint + pullConfig.Insecure = true + pullConfig.ExtractDir = targetDir + pullConfig.ExtractFromRemote = true + pullConfig.Hooks = trackingHook + pullConfig.ProgressWriter = io.Discard + pullConfig.DisableProgress = true + + if err := b.Pull(ctx, reference, pullConfig); err != nil { + logger.WithContext(ctx).WithError(err).Errorf("failed to pull model image: %s", reference) + return errors.Wrap(err, "pull model image") + } + + return nil +} + +// pullWithFiltering handles pulls with weight/file-pattern exclusions. +func (p *layerAwarePuller) pullWithFiltering( + ctx context.Context, + b backend.Backend, + modelArtifact *ModelArtifact, + reference, targetDir string, + plainHTTP bool, + excludeModelWeights bool, + excludeFilePatterns []string, +) error { + patterns, total, err := modelArtifact.GetPatterns(ctx, excludeModelWeights, excludeFilePatterns) + if err != nil { + return errors.Wrap(err, "get model file patterns without weights") + } + + if len(patterns) == 0 { + logger.WithContext(ctx).Infof("no files to fetch from model: %s", reference) + return nil + } + + logger.WithContext(ctx).Infof( + "fetching partial files from model: %s, files: %s (%d/%d)", + reference, strings.Join(patterns, ", "), len(patterns), total, + ) + p.hook.SetTotal(len(patterns)) + + fetchConfig := modctlConfig.NewFetch() + fetchConfig.Concurrency = int(p.pullCfg.Concurrency) + fetchConfig.PlainHTTP = plainHTTP + fetchConfig.Proxy = p.pullCfg.ProxyURL + fetchConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint + fetchConfig.Insecure = true + fetchConfig.Output = targetDir + fetchConfig.Hooks = p.hook + fetchConfig.ProgressWriter = io.Discard + fetchConfig.DisableProgress = true + fetchConfig.Patterns = patterns + + if err := b.Fetch(ctx, reference, fetchConfig); err != nil { + logger.WithContext(ctx).WithError(err).Errorf("failed to fetch model: %s", reference) + return errors.Wrap(err, "fetch model") + } + + return nil +} + +// cachedLayerInfo holds info about a layer found in the cache. +type cachedLayerInfo struct { + desc ocispec.Descriptor + sourcePath string + filePath string +} + +// layerTrackingHook wraps the status.Hook to additionally register pulled +// layers in the LayerCache after successful pull, and enforce node-level +// concurrency via the semaphore. +type layerTrackingHook struct { + inner *status.Hook + cache *LayerCache + targetDir string + sem *semaphore.Weighted +} + +func (h *layerTrackingHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) { + // Enforce node-level flow control by acquiring a semaphore slot. 
+ // We use 1 slot per layer (count-based, not size-based) to keep it simple + // and avoid potential issues with very large layers blocking all slots. + if h.sem != nil { + // Use a background context so we don't fail the pull if the parent + // context is cancelled while waiting — the pull itself handles cancellation. + _ = h.sem.Acquire(context.Background(), 1) + } + + h.inner.BeforePullLayer(desc, manifest) +} + +func (h *layerTrackingHook) AfterPullLayer(desc ocispec.Descriptor, err error) { + // Release the semaphore slot. + if h.sem != nil { + h.sem.Release(1) + } + + h.inner.AfterPullLayer(desc, err) + + // Only register on successful pull. + if err != nil || h.cache == nil { + return + } + + // Determine the file path from the layer descriptor annotations. + filePath := getLayerFilePath(desc) + if filePath == "" { + return + } + + fullPath := filepath.Join(h.targetDir, filePath) + if _, statErr := os.Stat(fullPath); statErr != nil { + return // file doesn't exist, cannot register + } + + h.cache.Register(desc.Digest, fullPath) +} + +// getLayerFilePath extracts the file path from a layer descriptor's annotations. +func getLayerFilePath(desc ocispec.Descriptor) string { + if desc.Annotations == nil { + return "" + } + // Try the current model-spec annotation first. + if fp := desc.Annotations[modelspec.AnnotationFilepath]; fp != "" { + return fp + } + // Fall back to legacy annotation. + if fp := desc.Annotations[oldModelspec.AnnotationFilepath]; fp != "" { + return fp + } + return "" +} From d1d297dae83f3dbff9f5c20e81881add63b27e07 Mon Sep 17 00:00:00 2001 From: Rishi Jat Date: Sun, 10 May 2026 05:11:31 +0530 Subject: [PATCH 4/4] feat(service): wire LayerCache into worker and service Initializes the global LayerCache at the service layer and threads it down to the Worker, passing it into the NewLayerAwarePuller factory. Registers lifecycle eviction hooks when volumes are deleted. Signed-off-by: Rishi Jat --- pkg/service/service.go | 8 +++++++- pkg/service/worker.go | 25 +++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/pkg/service/service.go b/pkg/service/service.go index 73b0257..fc1899d 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -71,7 +71,13 @@ func New(cfg *config.Config) (*Service, error) { if err != nil { return nil, errors.Wrap(err, "create status manager") } - worker, err := NewWorker(cfg, sm) + + // Create the node-level layer cache for layer deduplication. + lc := NewLayerCache(defaultMaxConcurrentLayers) + // Rebuild cache from existing volumes on disk (handles restarts). + lc.Rebuild(cfg.Get(), sm) + + worker, err := NewWorkerWithLayerCache(cfg, sm, lc) if err != nil { return nil, errors.Wrap(err, "create worker") } diff --git a/pkg/service/worker.go b/pkg/service/worker.go index 8628736..98a895d 100644 --- a/pkg/service/worker.go +++ b/pkg/service/worker.go @@ -57,17 +57,32 @@ type Worker struct { inflight singleflight.Group contextMap *ContextMap kmutex kmutex.KeyedLocker + layerCache *LayerCache } func NewWorker(cfg *config.Config, sm *status.StatusManager) (*Worker, error) { - return &Worker{ + return NewWorkerWithLayerCache(cfg, sm, nil) +} + +func NewWorkerWithLayerCache(cfg *config.Config, sm *status.StatusManager, lc *LayerCache) (*Worker, error) { + w := &Worker{ cfg: cfg, newPuller: NewPuller, sm: sm, inflight: singleflight.Group{}, contextMap: NewContextMap(), kmutex: kmutex.New(), - }, nil + layerCache: lc, + } + + // When a LayerCache is provided, use the layer-aware puller. 
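+	// With a nil cache (the plain NewWorker path), newPuller stays NewPuller,
+	// so layer deduplication remains strictly opt-in.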
+ if lc != nil { + w.newPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller { + return NewLayerAwarePuller(ctx, pullCfg, hook, diskQuotaChecker, lc) + } + } + + return w, nil } func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volumeName, mountID string) error { @@ -103,6 +118,12 @@ func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volu statusPath := filepath.Join(volumeDir, "status.json") worker.sm.HookManager.Delete(statusPath) + // Evict stale layer cache entries for the removed volume. + if worker.layerCache != nil { + worker.layerCache.RemoveByPrefix(volumeDir) + logger.WithContext(ctx).Infof("evicted layer cache entries for %s", volumeDir) + } + return nil, nil })
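
--
The singleflight-plus-semaphore flow at the heart of patch 3, reduced to a
minimal standalone sketch for review. Illustrative only: fetchBlob is a
hypothetical stand-in for the real client.Fetch plus codec.Decode, and no
caching or hardlinking is shown.

	package main

	import (
		"context"
		"fmt"
		"sync"

		"golang.org/x/sync/semaphore"
		"golang.org/x/sync/singleflight"
	)

	func main() {
		var sfg singleflight.Group
		sem := semaphore.NewWeighted(2) // node-level cap on in-flight downloads

		// fetchBlob stands in for the real remote fetch and decode.
		fetchBlob := func(d string) (string, error) {
			fmt.Println("fetching", d) // executes once per digest key
			return "/cache/" + d, nil
		}

		var wg sync.WaitGroup
		for i := 0; i < 10; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				// Concurrent callers share a single download per digest; the
				// semaphore bounds how many downloads run at once node-wide.
				path, err, _ := sfg.Do("sha256:abc", func() (interface{}, error) {
					if err := sem.Acquire(context.Background(), 1); err != nil {
						return nil, err
					}
					defer sem.Release(1)
					return fetchBlob("sha256:abc")
				})
				if err == nil {
					_ = path // a waiting caller would hardlink from this path
				}
			}()
		}
		wg.Wait()
	}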