diff --git a/pkg/service/layer_aware_puller_test.go b/pkg/service/layer_aware_puller_test.go new file mode 100644 index 0000000..d7f1230 --- /dev/null +++ b/pkg/service/layer_aware_puller_test.go @@ -0,0 +1,186 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/modelpack/modctl/pkg/backend" + pkgcodec "github.com/modelpack/modctl/pkg/codec" + "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/config/auth" + "github.com/modelpack/model-csi-driver/pkg/status" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/require" +) + +func TestLayerAwarePuller_ConcurrencyDedup(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "http"}, nil + }) + + d1 := digest.FromString("layer1") + layer1 := ocispec.Descriptor{ + Digest: d1, + Size: 9, + MediaType: "application/vnd.oci.image.layer.v1.tar", + Annotations: map[string]string{ + "org.cncf.model.filepath": "layer1.tar", + }, + } + manifest := ocispec.Manifest{Layers: []ocispec.Descriptor{layer1}} + manifestBytes, _ := json.Marshal(manifest) + + var fetchCalls atomic.Int32 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v2/" { + w.WriteHeader(http.StatusOK) + return + } + if r.URL.Path == "/v2/test/repo/manifests/latest" { + w.Header().Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + _, _ = w.Write(manifestBytes) + return + } + if r.URL.Path == "/v2/test/repo/blobs/"+d1.String() { + fetchCalls.Add(1) + time.Sleep(50 * time.Millisecond) // Simulate network IO + _, _ = w.Write([]byte("dummydata")) + 
return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + // Strip http:// prefix + registryHost := ts.URL[7:] + refStr := registryHost + "/test/repo:latest" + + patches.ApplyFunc(pkgcodec.New, func(string) (pkgcodec.Codec, error) { + return &dummyCodec{}, nil + }) + + lc := NewLayerCache(2) + puller := &layerAwarePuller{ + pullCfg: &config.PullConfig{Concurrency: 4}, + hook: status.NewHook(context.Background()), + layerCache: lc, + } + + b, _ := backend.New("") + artifact := &backend.InspectedModelArtifact{} + + const numPods = 10 + var wg sync.WaitGroup + + for i := 0; i < numPods; i++ { + wg.Add(1) + go func() { + defer wg.Done() + targetDir := filepath.Join(t.TempDir(), "model") + err := puller.layerAwarePull(context.Background(), b, artifact, refStr, targetDir, true) + require.NoError(t, err) + }() + } + + wg.Wait() + + require.Equal(t, int32(1), fetchCalls.Load(), "only one remote.Fetch should execute per digest") +} + +func TestLayerAwarePuller_SemaphoreCancelNoLeak(t *testing.T) { + patches := gomonkey.NewPatches() + defer patches.Reset() + + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "http"}, nil + }) + + d1 := digest.FromString("layer1") + layer1 := ocispec.Descriptor{ + Digest: d1, + Size: 9, + MediaType: "application/vnd.oci.image.layer.v1.tar", + Annotations: map[string]string{ + "org.cncf.model.filepath": "layer1.tar", + }, + } + manifestBytes, _ := json.Marshal(ocispec.Manifest{Layers: []ocispec.Descriptor{layer1}}) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v2/" { + w.WriteHeader(http.StatusOK) + return + } + if r.URL.Path == "/v2/test/repo/manifests/latest" { + w.Header().Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + _, _ = w.Write(manifestBytes) + return + } + if r.URL.Path == "/v2/test/repo/blobs/"+d1.String() { + // Simulate a hung connection or slow 
// dummyCodec is a stub pkgcodec.Codec used by the puller tests: it ignores
// the real encoded form and materializes every decoded layer as a tiny
// placeholder file, while fully draining the reader.
type dummyCodec struct{}

// Type reports a fixed codec name.
func (d *dummyCodec) Type() string { return "dummy" }

// Encode is unused in tests; it deliberately returns (nil, nil).
func (d *dummyCodec) Encode(string, string) (io.Reader, error) { return nil, nil }

// Decode writes a constant placeholder file at outputDir/fp. Errors are
// intentionally ignored — this is best-effort test scaffolding.
func (d *dummyCodec) Decode(outputDir, fp string, reader io.Reader, desc ocispec.Descriptor) error {
	fullPath := filepath.Join(outputDir, fp)
	_ = os.MkdirAll(filepath.Dir(fullPath), 0755)

	// Must consume reader to prevent connection aborts
	_, _ = io.Copy(io.Discard, reader)

	_ = os.WriteFile(fullPath, []byte("data"), 0644)
	return nil
}

// --- pkg/service/layer_cache.go ---

package service

import (
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/modelpack/model-csi-driver/pkg/config"
	"github.com/modelpack/model-csi-driver/pkg/logger"
	"github.com/modelpack/model-csi-driver/pkg/status"
	"github.com/opencontainers/go-digest"
	"golang.org/x/sync/semaphore"
	"golang.org/x/sync/singleflight"
)
// defaultMaxConcurrentLayers is the default maximum number of layers that
// can be pulled concurrently across all volumes on a node.
const defaultMaxConcurrentLayers int64 = 8

// LayerCache maintains an in-memory mapping of layer digest → file paths
// on disk, enabling layer-level deduplication via hardlinks.
//
// Thread-safety: all methods are safe for concurrent use.
type LayerCache struct {
	mu sync.RWMutex
	// layers maps a layer digest to all known on-disk file paths containing
	// that layer's content. Multiple volumes can reference the same layer.
	layers map[digest.Digest][]string

	// sfGroup deduplicates concurrent pull requests for the same layer digest.
	// The key is the digest string.
	sfGroup singleflight.Group

	// sem controls the maximum number of concurrently in-flight layer pulls
	// at the node level, preventing uncontrolled network and disk IO fan-out.
	sem *semaphore.Weighted
}

// NewLayerCache creates a new empty LayerCache with the given concurrency limit.
// A non-positive limit falls back to defaultMaxConcurrentLayers.
func NewLayerCache(maxConcurrentLayers int64) *LayerCache {
	if maxConcurrentLayers <= 0 {
		maxConcurrentLayers = defaultMaxConcurrentLayers
	}
	return &LayerCache{
		layers: make(map[digest.Digest][]string),
		sem:    semaphore.NewWeighted(maxConcurrentLayers),
	}
}

// Semaphore returns the weighted semaphore for node-level flow control.
func (lc *LayerCache) Semaphore() *semaphore.Weighted {
	return lc.sem
}

// SflightGroup returns the singleflight group for layer-level dedup.
func (lc *LayerCache) SflightGroup() *singleflight.Group {
	return &lc.sfGroup
}

// Register adds a file path for a given layer digest.
// If the path is already registered for the digest, this is a no-op.
func (lc *LayerCache) Register(d digest.Digest, path string) {
	lc.mu.Lock()
	defer lc.mu.Unlock()

	paths := lc.layers[d]
	for _, p := range paths {
		if p == path {
			return // already registered
		}
	}
	lc.layers[d] = append(paths, path)
}

// Lookup returns an existing, valid file path for the given layer digest.
// It verifies the file still exists on disk before returning it.
// Returns ("", false) if no valid path is found.
func (lc *LayerCache) Lookup(d digest.Digest) (string, bool) {
	// Copy the slice header under RLock; the os.Stat calls below run
	// without holding the lock so slow disks don't block writers.
	lc.mu.RLock()
	paths := lc.layers[d]
	lc.mu.RUnlock()

	for _, p := range paths {
		if _, err := os.Stat(p); err == nil {
			return p, true
		}
	}
	return "", false
}

// Remove removes a specific path from all digest entries.
func (lc *LayerCache) Remove(path string) {
	lc.mu.Lock()
	defer lc.mu.Unlock()

	for d, paths := range lc.layers {
		// In-place filter reusing the backing array; safe because the
		// write index never overtakes the read index.
		filtered := paths[:0]
		for _, p := range paths {
			if p != path {
				filtered = append(filtered, p)
			}
		}
		if len(filtered) == 0 {
			delete(lc.layers, d)
		} else {
			lc.layers[d] = filtered
		}
	}
}

// RemoveByPrefix removes all paths that have the given prefix from all digest
// entries. This is used during volume cleanup to evict all layer references
// under a volume directory.
func (lc *LayerCache) RemoveByPrefix(prefix string) {
	lc.mu.Lock()
	defer lc.mu.Unlock()

	for d, paths := range lc.layers {
		filtered := paths[:0]
		for _, p := range paths {
			if !strings.HasPrefix(p, prefix) {
				filtered = append(filtered, p)
			}
		}
		if len(filtered) == 0 {
			delete(lc.layers, d)
		} else {
			lc.layers[d] = filtered
		}
	}
}

// Len returns the number of unique digests tracked.
func (lc *LayerCache) Len() int {
	lc.mu.RLock()
	defer lc.mu.RUnlock()
	return len(lc.layers)
}

// PathCount returns the total number of paths tracked across all digests.
func (lc *LayerCache) PathCount() int {
	lc.mu.RLock()
	defer lc.mu.RUnlock()
	count := 0
	for _, paths := range lc.layers {
		count += len(paths)
	}
	return count
}

// Snapshot returns a copy of the current cache state for testing/debugging.
// The returned map and slices are deep copies; mutating them does not affect
// the cache.
func (lc *LayerCache) Snapshot() map[digest.Digest][]string {
	lc.mu.RLock()
	defer lc.mu.RUnlock()
	snapshot := make(map[digest.Digest][]string, len(lc.layers))
	for d, paths := range lc.layers {
		cp := make([]string, len(paths))
		copy(cp, paths)
		snapshot[d] = cp
	}
	return snapshot
}

// Rebuild scans existing volume directories and re-populates the cache
// from disk. This is called on startup to restore layer cache state after
// a daemon restart or node reboot.
//
// For each volume with a successful pull status, we scan the model directory
// and register all files. The digest mapping is restored from the status
// metadata stored during previous pulls.
func (lc *LayerCache) Rebuild(cfg *config.RawConfig, sm *status.StatusManager) {
	volumesDir := cfg.GetVolumesDir()
	volumeDirs, err := os.ReadDir(volumesDir)
	if err != nil {
		// A missing volumes dir just means a fresh node — not an error.
		if !os.IsNotExist(err) {
			logger.Logger().WithError(err).Errorf("layer cache rebuild: read volumes dir %s", volumesDir)
		}
		return
	}

	registered := 0
	for _, volumeDir := range volumeDirs {
		if !volumeDir.IsDir() {
			continue
		}
		volumeName := volumeDir.Name()

		if isStaticVolume(volumeName) {
			n := lc.rebuildVolume(cfg.GetVolumeDir(volumeName), sm)
			registered += n
		}

		if isDynamicVolume(volumeName) {
			// Dynamic volumes nest one model dir per mount ID.
			modelsDir := cfg.GetModelsDirForDynamic(volumeName)
			modelDirs, err := os.ReadDir(modelsDir)
			if err != nil {
				continue
			}
			for _, modelDir := range modelDirs {
				if !modelDir.IsDir() {
					continue
				}
				mountID := modelDir.Name()
				n := lc.rebuildVolume(cfg.GetMountIDDirForDynamic(volumeName, mountID), sm)
				registered += n
			}
		}
	}

	logger.Logger().Infof("layer cache rebuild complete: %d layer-path entries registered, %d unique digests",
		registered, lc.Len())
}
// rebuildVolume scans a single volume's model directory and registers
// any files found. It reads the layer digest metadata stored alongside
// the volume status to map files back to their digests.
// Returns the number of layer-path entries registered; all failure modes
// (missing status, wrong state, missing metadata) return 0 silently.
func (lc *LayerCache) rebuildVolume(volumeDir string, sm *status.StatusManager) int {
	statusPath := filepath.Join(volumeDir, "status.json")
	s, err := sm.Get(statusPath)
	if err != nil {
		return 0
	}

	// Only rebuild from successfully pulled volumes.
	if s.State != status.StatePullSucceeded && s.State != status.StateMounted {
		return 0
	}

	// Read the layer digest metadata file.
	metadataPath := filepath.Join(volumeDir, "layer_digests.json")
	metadata, err := loadLayerMetadata(metadataPath)
	if err != nil {
		return 0
	}

	registered := 0
	modelDir := filepath.Join(volumeDir, "model")
	for _, entry := range metadata {
		filePath := filepath.Join(modelDir, entry.FilePath)
		if _, err := os.Stat(filePath); err != nil {
			continue // file no longer exists, skip
		}
		d, err := digest.Parse(entry.Digest)
		if err != nil {
			continue
		}
		lc.Register(d, filePath)
		registered++
	}

	return registered
}

// --- pkg/service/layer_cache_test.go ---

package service

import (
	"context"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/modelpack/model-csi-driver/pkg/config"
	"github.com/modelpack/model-csi-driver/pkg/status"
	"github.com/opencontainers/go-digest"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
	"github.com/stretchr/testify/require"
)

// ─── LayerCache basic ops ─────────────────────────────────────────────────────

// A registered, on-disk path must be returned by Lookup.
func TestLayerCache_RegisterAndLookup(t *testing.T) {
	lc := NewLayerCache(8)

	d := digest.FromString("layer-content-1")
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "layer1.bin")
	require.NoError(t, os.WriteFile(path, []byte("data"), 0644))

	lc.Register(d, path)

	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Equal(t, path, found)
}
// Lookup of a never-registered digest must miss.
func TestLayerCache_LookupMissing(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("nonexistent")
	_, ok := lc.Lookup(d)
	require.False(t, ok)
}

// Lookup must not return a registered path whose file was deleted from disk.
func TestLayerCache_LookupStaleFile(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("stale-content")
	path := filepath.Join(t.TempDir(), "deleted.bin")

	// Register a path that doesn't exist on disk.
	lc.Register(d, path)

	_, ok := lc.Lookup(d)
	require.False(t, ok, "should not find stale path")
}

// Registering the same (digest, path) pair repeatedly keeps a single entry.
func TestLayerCache_RegisterIdempotent(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("dup-content")
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "layer.bin")
	require.NoError(t, os.WriteFile(path, []byte("data"), 0644))

	lc.Register(d, path)
	lc.Register(d, path)
	lc.Register(d, path)

	require.Equal(t, 1, lc.PathCount())
}

// One digest may map to multiple on-disk copies; Lookup returns any of them.
func TestLayerCache_MultiplePaths(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("shared-content")
	tmpDir := t.TempDir()

	path1 := filepath.Join(tmpDir, "vol1", "layer.bin")
	path2 := filepath.Join(tmpDir, "vol2", "layer.bin")
	require.NoError(t, os.MkdirAll(filepath.Dir(path1), 0755))
	require.NoError(t, os.MkdirAll(filepath.Dir(path2), 0755))
	require.NoError(t, os.WriteFile(path1, []byte("data"), 0644))
	require.NoError(t, os.WriteFile(path2, []byte("data"), 0644))

	lc.Register(d, path1)
	lc.Register(d, path2)

	require.Equal(t, 2, lc.PathCount())
	require.Equal(t, 1, lc.Len())

	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Contains(t, []string{path1, path2}, found)
}

// ─── LayerCache Remove ────────────────────────────────────────────────────────

// Removing one path leaves the digest's other paths intact.
func TestLayerCache_Remove(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("remove-test")
	tmpDir := t.TempDir()

	path1 := filepath.Join(tmpDir, "layer1.bin")
	path2 := filepath.Join(tmpDir, "layer2.bin")
	require.NoError(t, os.WriteFile(path1, []byte("data"), 0644))
	require.NoError(t, os.WriteFile(path2, []byte("data"), 0644))

	lc.Register(d, path1)
	lc.Register(d, path2)

	lc.Remove(path1)
	require.Equal(t, 1, lc.PathCount())

	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Equal(t, path2, found)
}
// Removing a digest's last path must drop the digest entry entirely.
func TestLayerCache_RemoveLastPath(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("remove-last")
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "only.bin")
	require.NoError(t, os.WriteFile(path, []byte("data"), 0644))

	lc.Register(d, path)
	lc.Remove(path)

	require.Equal(t, 0, lc.Len())
	require.Equal(t, 0, lc.PathCount())
}

// ─── LayerCache RemoveByPrefix ────────────────────────────────────────────────

// RemoveByPrefix evicts every path under one volume dir, leaving others.
func TestLayerCache_RemoveByPrefix(t *testing.T) {
	lc := NewLayerCache(8)
	tmpDir := t.TempDir()

	// Create two volumes with layers.
	d1 := digest.FromString("layer-a")
	d2 := digest.FromString("layer-b")
	d3 := digest.FromString("layer-c")

	vol1Path := filepath.Join(tmpDir, "vol1", "model", "a.bin")
	vol1PathB := filepath.Join(tmpDir, "vol1", "model", "b.bin")
	vol2Path := filepath.Join(tmpDir, "vol2", "model", "c.bin")

	for _, p := range []string{vol1Path, vol1PathB, vol2Path} {
		require.NoError(t, os.MkdirAll(filepath.Dir(p), 0755))
		require.NoError(t, os.WriteFile(p, []byte("data"), 0644))
	}

	lc.Register(d1, vol1Path)
	lc.Register(d2, vol1PathB)
	lc.Register(d3, vol2Path)

	require.Equal(t, 3, lc.Len())

	// Remove all entries under vol1.
	lc.RemoveByPrefix(filepath.Join(tmpDir, "vol1"))

	require.Equal(t, 1, lc.Len())
	require.Equal(t, 1, lc.PathCount())

	// vol2 entry should still be there.
	found, ok := lc.Lookup(d3)
	require.True(t, ok)
	require.Equal(t, vol2Path, found)
}
// ─── LayerCache Snapshot ──────────────────────────────────────────────────────

// Snapshot must be a deep copy: mutating it does not affect the cache.
func TestLayerCache_Snapshot(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("snap-content")
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "snap.bin")
	require.NoError(t, os.WriteFile(path, []byte("data"), 0644))

	lc.Register(d, path)

	snap := lc.Snapshot()
	require.Len(t, snap, 1)
	require.Equal(t, []string{path}, snap[d])

	// Modifying snapshot should not affect the cache.
	snap[d] = append(snap[d], "extra")
	require.Equal(t, 1, lc.PathCount())
}

// ─── LayerCache Concurrency ───────────────────────────────────────────────────

// Parallel Register/Lookup across distinct digests must be race-free and
// lose no entries (run with -race to catch violations).
func TestLayerCache_ConcurrentRegisterAndLookup(t *testing.T) {
	lc := NewLayerCache(8)
	tmpDir := t.TempDir()

	const n = 50
	digests := make([]digest.Digest, n)
	paths := make([]string, n)
	for i := 0; i < n; i++ {
		// NOTE(review): string(rune(i)) / string(rune('a'+i)) produce
		// control and non-ASCII characters for larger i; this works on
		// POSIX filesystems but decimal names would be more portable —
		// left as-is to preserve behavior.
		digests[i] = digest.FromString(string(rune(i)))
		paths[i] = filepath.Join(tmpDir, string(rune('a'+i)))
		require.NoError(t, os.WriteFile(paths[i], []byte("data"), 0644))
	}

	var wg sync.WaitGroup
	// Concurrently register.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			lc.Register(digests[i], paths[i])
		}(i)
	}
	wg.Wait()

	// Concurrently lookup.
	var found atomic.Int32
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if _, ok := lc.Lookup(digests[i]); ok {
				found.Add(1)
			}
		}(i)
	}
	wg.Wait()

	require.Equal(t, int32(n), found.Load())
}
// Many goroutines registering distinct paths under ONE digest must all land.
func TestLayerCache_ConcurrentSameLayerRegister(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("same-layer")
	tmpDir := t.TempDir()

	const n = 20
	paths := make([]string, n)
	for i := 0; i < n; i++ {
		paths[i] = filepath.Join(tmpDir, string(rune('a'+i)))
		require.NoError(t, os.WriteFile(paths[i], []byte("data"), 0644))
	}

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			lc.Register(d, paths[i])
		}(i)
	}
	wg.Wait()

	// All n paths should be registered under one digest.
	require.Equal(t, 1, lc.Len())
	require.Equal(t, n, lc.PathCount())
}

// ─── LayerCache Rebuild ───────────────────────────────────────────────────────

// A static volume with PullSucceeded status and valid layer metadata must be
// re-registered by Rebuild after a simulated restart.
func TestLayerCache_Rebuild(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	// Set up a static volume with a successful pull.
	volumeName := "pvc-rebuild-test"
	volumeDir := cfg.Get().GetVolumeDir(volumeName)
	modelDir := cfg.Get().GetModelDir(volumeName)
	require.NoError(t, os.MkdirAll(modelDir, 0755))

	// Write status.
	statusPath := filepath.Join(volumeDir, "status.json")
	_, err = sm.Set(statusPath, status.Status{
		VolumeName: volumeName,
		Reference:  "registry/model:v1",
		State:      status.StatePullSucceeded,
	})
	require.NoError(t, err)

	// Write a model file.
	layerFile := filepath.Join(modelDir, "weights.bin")
	require.NoError(t, os.WriteFile(layerFile, []byte("layer-data"), 0644))

	// Write layer metadata.
	d := digest.FromString("weights-content")
	metadataPath := filepath.Join(volumeDir, "layer_digests.json")
	require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{
		{Digest: d.String(), FilePath: "weights.bin", Size: 10},
	}))

	// Rebuild.
	lc := NewLayerCache(8)
	lc.Rebuild(cfg.Get(), sm)

	require.Equal(t, 1, lc.Len())

	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Equal(t, layerFile, found)
}
// A volume whose last pull failed must contribute nothing to the rebuilt cache.
func TestLayerCache_RebuildSkipsFailedPull(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	volumeName := "pvc-failed-test"
	volumeDir := cfg.Get().GetVolumeDir(volumeName)
	modelDir := cfg.Get().GetModelDir(volumeName)
	require.NoError(t, os.MkdirAll(modelDir, 0755))

	statusPath := filepath.Join(volumeDir, "status.json")
	_, err = sm.Set(statusPath, status.Status{
		VolumeName: volumeName,
		Reference:  "registry/model:v1",
		State:      status.StatePullFailed,
	})
	require.NoError(t, err)

	lc := NewLayerCache(8)
	lc.Rebuild(cfg.Get(), sm)

	require.Equal(t, 0, lc.Len(), "should not rebuild from failed pull")
}

// Rebuild must also walk dynamic volumes (nested per-mount-ID model dirs).
func TestLayerCache_RebuildDynamicVolume(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	volumeName := "csi-dyn-rebuild"
	mountID := "mount-1"
	mountDir := cfg.Get().GetMountIDDirForDynamic(volumeName, mountID)
	modelDir := cfg.Get().GetModelDirForDynamic(volumeName, mountID)
	require.NoError(t, os.MkdirAll(modelDir, 0755))

	statusPath := filepath.Join(mountDir, "status.json")
	_, err = sm.Set(statusPath, status.Status{
		Reference: "registry/model:dyn",
		State:     status.StateMounted,
	})
	require.NoError(t, err)

	layerFile := filepath.Join(modelDir, "model.bin")
	require.NoError(t, os.WriteFile(layerFile, []byte("dyn-data"), 0644))

	d := digest.FromString("dyn-content")
	metadataPath := filepath.Join(mountDir, "layer_digests.json")
	require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{
		{Digest: d.String(), FilePath: "model.bin", Size: 8},
	}))

	lc := NewLayerCache(8)
	lc.Rebuild(cfg.Get(), sm)

	require.Equal(t, 1, lc.Len())
	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Equal(t, layerFile, found)
}
// Metadata entries whose file vanished from disk must be skipped on rebuild.
func TestLayerCache_RebuildSkipsMissingFile(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	volumeName := "pvc-missing-file"
	volumeDir := cfg.Get().GetVolumeDir(volumeName)
	modelDir := cfg.Get().GetModelDir(volumeName)
	require.NoError(t, os.MkdirAll(modelDir, 0755))

	statusPath := filepath.Join(volumeDir, "status.json")
	_, err = sm.Set(statusPath, status.Status{
		VolumeName: volumeName,
		Reference:  "registry/model:v1",
		State:      status.StatePullSucceeded,
	})
	require.NoError(t, err)

	d := digest.FromString("missing-file-content")
	metadataPath := filepath.Join(volumeDir, "layer_digests.json")
	require.NoError(t, saveLayerMetadata(metadataPath, []LayerMetadataEntry{
		{Digest: d.String(), FilePath: "does_not_exist.bin", Size: 10},
	}))

	lc := NewLayerCache(8)
	lc.Rebuild(cfg.Get(), sm)

	require.Equal(t, 0, lc.Len(), "should skip missing files during rebuild")
}

// ─── LayerCache Cleanup Integration ───────────────────────────────────────────

// Deleting a volume through the worker must evict its layer-cache entries.
func TestLayerCache_CleanupOnDeleteModel(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	lc := NewLayerCache(8)

	worker, err := NewWorkerWithLayerCache(cfg, sm, lc)
	require.NoError(t, err)

	// Simulate a pulled volume with cached layers.
	volumeName := "pvc-cleanup-test"
	modelDir := cfg.Get().GetModelDir(volumeName)
	require.NoError(t, os.MkdirAll(modelDir, 0755))

	d := digest.FromString("cleanup-content")
	layerFile := filepath.Join(modelDir, "weights.bin")
	require.NoError(t, os.WriteFile(layerFile, []byte("data"), 0644))
	lc.Register(d, layerFile)

	require.Equal(t, 1, lc.Len())

	// Delete the volume.
	err = worker.DeleteModel(context.Background(), true, volumeName, "")
	require.NoError(t, err)

	// Layer cache should be cleaned.
	require.Equal(t, 0, lc.Len(), "layer cache should be empty after delete")
}
// ─── Hardlink fallback tests ──────────────────────────────────────────────────

// A cached path whose source file is gone must not be offered as a hardlink
// source.
func TestLayerCache_HardlinkFallbackOnMissingSource(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("hardlink-source-missing")

	// Register a path that won't exist on disk.
	lc.Register(d, "/nonexistent/path/layer.bin")

	// Lookup should fail (source doesn't exist).
	_, ok := lc.Lookup(d)
	require.False(t, ok, "should not find source that doesn't exist")
}

// When the first registered path is stale, Lookup must fall back to a later
// path that still exists.
func TestLayerCache_HardlinkFallbackUsesSecondPath(t *testing.T) {
	lc := NewLayerCache(8)
	d := digest.FromString("multi-path-fallback")
	tmpDir := t.TempDir()

	// First path doesn't exist, second does.
	path1 := filepath.Join(tmpDir, "deleted.bin")
	path2 := filepath.Join(tmpDir, "exists.bin")
	require.NoError(t, os.WriteFile(path2, []byte("data"), 0644))

	lc.Register(d, path1)
	lc.Register(d, path2)

	found, ok := lc.Lookup(d)
	require.True(t, ok)
	require.Equal(t, path2, found)
}
// ─── Layer metadata persistence ───────────────────────────────────────────────

// Round-trip: saveLayerMetadata then loadLayerMetadata yields equal entries.
func TestLayerMetadata_SaveAndLoad(t *testing.T) {
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "layer_digests.json")

	entries := []LayerMetadataEntry{
		{Digest: "sha256:abc", FilePath: "weights.bin", Size: 100},
		{Digest: "sha256:def", FilePath: "config.json", Size: 50},
	}

	require.NoError(t, saveLayerMetadata(path, entries))

	loaded, err := loadLayerMetadata(path)
	require.NoError(t, err)
	require.Equal(t, entries, loaded)
}

// Loading a non-existent metadata file must surface an error.
func TestLayerMetadata_LoadMissing(t *testing.T) {
	_, err := loadLayerMetadata("/nonexistent/path")
	require.Error(t, err)
}

// ─── Semaphore / flow control ─────────────────────────────────────────────────

// The semaphore honors the configured permit count and Release restores it.
func TestLayerCache_SemaphoreBasic(t *testing.T) {
	lc := NewLayerCache(2)
	sem := lc.Semaphore()
	require.NotNil(t, sem)

	// Should be able to acquire 2 slots.
	require.True(t, sem.TryAcquire(1))
	require.True(t, sem.TryAcquire(1))
	// Third should fail (limit is 2).
	require.False(t, sem.TryAcquire(1))

	sem.Release(1)
	require.True(t, sem.TryAcquire(1))
}

// A non-positive limit falls back to defaultMaxConcurrentLayers (8).
func TestLayerCache_DefaultConcurrency(t *testing.T) {
	lc := NewLayerCache(0) // Should use default.
	sem := lc.Semaphore()
	require.NotNil(t, sem)

	// Default is 8, should be able to acquire 8.
	for i := 0; i < 8; i++ {
		require.True(t, sem.TryAcquire(1))
	}
	require.False(t, sem.TryAcquire(1))
}
// ─── Worker with LayerCache integration ───────────────────────────────────────

// Constructing a worker with a LayerCache wires the cache into the worker.
func TestNewWorkerWithLayerCache(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	lc := NewLayerCache(8)

	worker, err := NewWorkerWithLayerCache(cfg, sm, lc)
	require.NoError(t, err)
	require.NotNil(t, worker)
	require.NotNil(t, worker.layerCache)
}

// Passing a nil cache must still build a valid worker (standard puller path).
func TestNewWorkerWithNilLayerCache(t *testing.T) {
	tmpDir := t.TempDir()
	rawCfg := &config.RawConfig{ServiceName: "test", RootDir: tmpDir}
	cfg := config.NewWithRaw(rawCfg)
	sm, err := status.NewStatusManager()
	require.NoError(t, err)

	// Should use the standard puller when lc is nil.
	worker, err := NewWorkerWithLayerCache(cfg, sm, nil)
	require.NoError(t, err)
	require.NotNil(t, worker)
	require.Nil(t, worker.layerCache)
}

// ─── Concurrent dedup via singleflight ────────────────────────────────────────

// All goroutines hit sfg.Do with the same key while the first call is still
// in flight, so the work function must run exactly once.
func TestLayerCache_SingleflightDedup(t *testing.T) {
	lc := NewLayerCache(8)
	sfg := lc.SflightGroup()

	var callCount atomic.Int32

	d := digest.FromString("singleflight-test")
	key := d.String()

	const n = 10
	var started atomic.Int32

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			started.Add(1)
			// Wait until all goroutines have started.
			for started.Load() < n {
				runtime.Gosched()
			}
			_, _, _ = sfg.Do(key, func() (interface{}, error) {
				callCount.Add(1)
				// Sleep to keep the function in-flight while other goroutines arrive.
				time.Sleep(50 * time.Millisecond)
				return nil, nil
			})
		}()
	}

	wg.Wait()

	// Due to singleflight, the function should be called exactly once
	// because all goroutines overlap on the same key while blocked.
	require.Equal(t, int32(1), callCount.Load(),
		"singleflight should deduplicate concurrent calls")
}
// ─── getLayerFilePath ─────────────────────────────────────────────────────────

// No annotations → empty file path.
func TestGetLayerFilePath_NoAnnotations(t *testing.T) {
	desc := createTestDescriptor("sha256:abc123", nil)
	require.Equal(t, "", getLayerFilePath(desc))
}

// The current-spec annotation key is honored.
func TestGetLayerFilePath_CurrentSpec(t *testing.T) {
	desc := createTestDescriptor("sha256:abc123", map[string]string{
		"org.cncf.model.filepath": "weights/model.safetensors",
	})
	require.Equal(t, "weights/model.safetensors", getLayerFilePath(desc))
}

// The legacy annotation key is honored when the current one is absent.
func TestGetLayerFilePath_LegacySpec(t *testing.T) {
	desc := createTestDescriptor("sha256:abc123", map[string]string{
		"org.cnai.model.filepath": "config.json",
	})
	require.Equal(t, "config.json", getLayerFilePath(desc))
}

// When both keys are present the current spec wins.
func TestGetLayerFilePath_PrefersCurrentOverLegacy(t *testing.T) {
	desc := createTestDescriptor("sha256:abc123", map[string]string{
		"org.cncf.model.filepath": "current.bin",
		"org.cnai.model.filepath": "legacy.bin",
	})
	require.Equal(t, "current.bin", getLayerFilePath(desc))
}

// ─── Test helpers ─────────────────────────────────────────────────────────────

// createTestDescriptor builds an OCI descriptor with the given digest string
// and annotation map; a malformed digest is silently dropped (parse error
// ignored — test helper only).
func createTestDescriptor(digestStr string, annotations map[string]string) ocispec.Descriptor {
	d, _ := digest.Parse(digestStr)
	return ocispec.Descriptor{
		Digest:      d,
		Size:        100,
		Annotations: annotations,
	}
}

// --- pkg/service/layer_metadata.go ---

package service

import (
	"encoding/json"
	"os"

	"github.com/pkg/errors"
)
// LayerMetadataEntry stores the digest and file path for a single layer,
// persisted alongside the volume so the LayerCache can be rebuilt on restart.
type LayerMetadataEntry struct {
	Digest   string `json:"digest"`
	FilePath string `json:"file_path"`
	Size     int64  `json:"size"`
}

// saveLayerMetadata writes the layer metadata to the given path.
// The write is atomic on POSIX: data is written to path+".tmp" and then
// renamed over the destination, so readers never observe a partial file.
func saveLayerMetadata(path string, entries []LayerMetadataEntry) error {
	data, err := json.MarshalIndent(entries, "", "  ")
	if err != nil {
		return errors.Wrap(err, "marshal layer metadata")
	}
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0644); err != nil {
		return errors.Wrap(err, "write temp layer metadata")
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath) // clean up on failure
		return errors.Wrap(err, "rename layer metadata")
	}
	return nil
}

// loadLayerMetadata reads the layer metadata from the given path.
// Returns an error if the file is missing or not valid JSON.
func loadLayerMetadata(path string) ([]LayerMetadataEntry, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, errors.Wrap(err, "read layer metadata")
	}
	var entries []LayerMetadataEntry
	if err := json.Unmarshal(data, &entries); err != nil {
		return nil, errors.Wrap(err, "unmarshal layer metadata")
	}
	return entries, nil
}

// --- pkg/service/puller.go (post-patch imports) ---

package service

import (
	"context"
	"encoding/json"
	"io"
	"os"
	"path/filepath"
	"strings"
	"time"

	oldModelspec "github.com/dragonflyoss/model-spec/specs-go/v1"
	"github.com/modelpack/modctl/pkg/backend"
	"github.com/modelpack/modctl/pkg/backend/remote"
	pkgcodec "github.com/modelpack/modctl/pkg/codec"
	modctlConfig "github.com/modelpack/modctl/pkg/config"
	"github.com/modelpack/model-csi-driver/pkg/config"
	"github.com/modelpack/model-csi-driver/pkg/config/auth"
	"github.com/modelpack/model-csi-driver/pkg/logger"
	"github.com/modelpack/model-csi-driver/pkg/status"
	modelspec "github.com/modelpack/model-spec/specs-go/v1"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
	"github.com/pkg/errors"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/semaphore"
)
 	"github.com/pkg/errors"
+	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/semaphore"
 )
 
 type PullHook interface {
+	// LayerCached reports a layer that was satisfied from the local layer
+	// cache (hardlinked), i.e. completed without a network pull.
+	LayerCached(desc ocispec.Descriptor, manifest ocispec.Manifest)
 	BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest)
 	AfterPullLayer(desc ocispec.Descriptor, err error)
 }
@@ -33,6 +43,18 @@ var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *stat
 	}
 }
 
+// NewLayerAwarePuller creates a puller that leverages the LayerCache for
+// layer-level deduplication. It resolves the manifest ahead of time,
+// hardlinks cached layers, and only pulls missing layers.
+var NewLayerAwarePuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker, lc *LayerCache) Puller {
+	return &layerAwarePuller{
+		pullCfg:          pullCfg,
+		hook:             hook,
+		diskQuotaChecker: diskQuotaChecker,
+		layerCache:       lc,
+	}
+}
+
 type puller struct {
 	pullCfg          *config.PullConfig
 	hook             *status.Hook
@@ -119,3 +141,436 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeM
 
 	return nil
 }
+
+// layerAwarePuller wraps the pull flow with layer-level deduplication.
+// It resolves the manifest first, checks the LayerCache for each layer,
+// hardlinks cached layers, and only pulls missing ones via the Fetch path.
+type layerAwarePuller struct {
+	pullCfg          *config.PullConfig
+	hook             *status.Hook
+	diskQuotaChecker *DiskQuotaChecker // optional; skipped when nil
+	layerCache       *LayerCache      // optional; nil disables dedup entirely
+}
+
+// Pull resolves auth for the reference, enforces the disk quota (when a
+// checker is configured), creates the target directory, and then dispatches:
+// filtered pulls go through pullWithFiltering, a failed Inspect falls back to
+// standardPull, and the normal path runs the dedup-aware layerAwarePull.
+func (p *layerAwarePuller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error {
+	keyChain, err := auth.GetKeyChainByRef(reference)
+	if err != nil {
+		return errors.Wrapf(err, "get auth for model: %s", reference)
+	}
+	plainHTTP := keyChain.ServerScheme == "http"
+
+	b, err := backend.New("")
+	if err != nil {
+		return errors.Wrap(err, "create modctl backend")
+	}
+
+	modelArtifact := NewModelArtifact(b, reference, plainHTTP)
+
+	if p.diskQuotaChecker != nil {
+		if err := p.diskQuotaChecker.Check(ctx, modelArtifact, excludeModelWeights, excludeFilePatterns); err != nil {
+			return errors.Wrap(err, "check disk quota")
+		}
+	}
+
+	if err := os.MkdirAll(targetDir, 0755); err != nil {
+		return errors.Wrapf(err, "create model dir: %s", targetDir)
+	}
+
+	// If filtering is active, fall back to the standard puller (no dedup
+	// for partial pulls since layer content is subset-based).
+	if excludeModelWeights || len(excludeFilePatterns) > 0 {
+		return p.pullWithFiltering(ctx, b, modelArtifact, reference, targetDir, plainHTTP, excludeModelWeights, excludeFilePatterns)
+	}
+
+	// Inspect the model to get layer descriptors.
+	artifact, err := modelArtifact.Inspect(ctx, reference)
+	if err != nil {
+		logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: failed to inspect model, falling back to standard pull")
+		return p.standardPull(ctx, b, reference, targetDir, plainHTTP)
+	}
+
+	return p.layerAwarePull(ctx, b, artifact, reference, targetDir, plainHTTP)
+}
+
+// layerAwarePull performs the deduplication-aware pull flow.
+func (p *layerAwarePuller) layerAwarePull(
+	ctx context.Context,
+	b backend.Backend,
+	// NOTE(review): artifact is currently unused in this body — confirm
+	// whether it is reserved for future use or can be dropped.
+	artifact *backend.InspectedModelArtifact,
+	reference, targetDir string,
+	plainHTTP bool,
+) error {
+	if p.layerCache == nil {
+		return p.standardPull(ctx, b, reference, targetDir, plainHTTP)
+	}
+
+	// 1. Resolve remote manifest directly to get full descriptors.
+	ref, err := backend.ParseReference(reference)
+	if err != nil {
+		return errors.Wrap(err, "parse reference")
+	}
+
+	client, err := remote.New(ref.Repository(), remote.WithPlainHTTP(plainHTTP), remote.WithInsecure(true))
+	if err != nil {
+		return errors.Wrap(err, "create remote client")
+	}
+
+	_, manifestReader, err := client.Manifests().FetchReference(ctx, ref.Tag())
+	if err != nil {
+		return errors.Wrap(err, "fetch manifest")
+	}
+	defer func() { _ = manifestReader.Close() }()
+
+	var manifest ocispec.Manifest
+	if err := json.NewDecoder(manifestReader).Decode(&manifest); err != nil {
+		return errors.Wrap(err, "decode manifest")
+	}
+
+	// 2. Classify layers as cached or uncached.
+	var uncachedLayers []ocispec.Descriptor
+	var cachedLayers []cachedLayerInfo
+	var metadataEntries []LayerMetadataEntry
+
+	for _, layer := range manifest.Layers {
+		fp := getLayerFilePath(layer)
+		if fp == "" {
+			continue // skip layers without filepaths
+		}
+
+		metadataEntries = append(metadataEntries, LayerMetadataEntry{
+			Digest:   layer.Digest.String(),
+			FilePath: fp,
+			Size:     layer.Size,
+		})
+
+		sourcePath, found := p.layerCache.Lookup(layer.Digest)
+		if found {
+			cachedLayers = append(cachedLayers, cachedLayerInfo{
+				desc:       layer,
+				sourcePath: sourcePath,
+				filePath:   fp,
+			})
+		} else {
+			uncachedLayers = append(uncachedLayers, layer)
+		}
+	}
+
+	logger.WithContext(ctx).Infof(
+		"layer-aware pull: %s — %d cached, %d to pull (total %d layers)",
+		reference, len(cachedLayers), len(uncachedLayers), len(manifest.Layers),
+	)
+
+	// FIX: progress is only ever reported for layers that carry a filepath
+	// annotation (the loop above skips the rest), so the total must be the
+	// filepath-bearing count. Using len(manifest.Layers) would leave the
+	// status permanently below 100% whenever any layer lacks the annotation.
+	p.hook.SetTotal(len(metadataEntries))
+
+	// Step 3: Hardlink cached layers.
+	for _, cl := range cachedLayers {
+		destPath := filepath.Join(targetDir, cl.filePath)
+		if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
+			logger.WithContext(ctx).WithError(err).Warnf(
+				"layer-aware pull: failed to create dir for hardlink %s, will pull instead", cl.filePath,
+			)
+			uncachedLayers = append(uncachedLayers, cl.desc)
+			continue
+		}
+
+		if err := os.Link(cl.sourcePath, destPath); err != nil {
+			logger.WithContext(ctx).WithError(err).Warnf(
+				"layer-aware pull: hardlink failed for %s (src=%s), will pull instead",
+				cl.filePath, cl.sourcePath,
+			)
+			uncachedLayers = append(uncachedLayers, cl.desc)
+			continue
+		}
+
+		logger.WithContext(ctx).Infof(
+			"layer-aware pull: reused layer %s via hardlink (%s → %s)",
+			cl.desc.Digest, cl.sourcePath, destPath,
+		)
+
+		// Immediately register the new path so others can use it.
+		p.layerCache.Register(cl.desc.Digest, destPath)
+		// Record progress silently for cached layers.
+		p.hook.LayerCached(cl.desc, manifest)
+	}
+
+	// Step 4: Pull uncached layers concurrently, synchronized by singleflight.
+	if len(uncachedLayers) > 0 {
+		g, gctx := errgroup.WithContext(ctx)
+		g.SetLimit(int(p.pullCfg.Concurrency))
+		sfg := p.layerCache.SflightGroup()
+		sem := p.layerCache.Semaphore()
+
+		// FIX: the retry back-off must honor cancellation; a bare time.Sleep
+		// would keep cancelled goroutines alive for up to a second each.
+		backoff := func() error {
+			select {
+			case <-gctx.Done():
+				return gctx.Err()
+			case <-time.After(1 * time.Second):
+				return nil
+			}
+		}
+
+		for _, layerDesc := range uncachedLayers {
+			desc := layerDesc // capture for goroutine
+			g.Go(func() error {
+				for attempt := 0; attempt < 3; attempt++ {
+					select {
+					case <-gctx.Done():
+						return gctx.Err()
+					default:
+					}
+
+					// The singleflight key is the digest, guaranteeing only one download across all pods.
+					_, sfgErr, _ := sfg.Do(desc.Digest.String(), func() (interface{}, error) {
+						// Apply node-level flow control securely around network operations.
+						if sem != nil {
+							if err := sem.Acquire(gctx, 1); err != nil {
+								return nil, err
+							}
+							defer sem.Release(1)
+						}
+
+						p.hook.BeforePullLayer(desc, manifest)
+
+						// Defer the after hook to guarantee it runs on panic or error.
+						var pullErr error
+						defer func() {
+							p.hook.AfterPullLayer(desc, pullErr)
+						}()
+
+						pullErr = func() error {
+							// Open network stream to registry.
+							reader, err := client.Fetch(gctx, desc)
+							if err != nil {
+								return errors.Wrap(err, "fetch blob from remote")
+							}
+							defer func() { _ = reader.Close() }()
+
+							// Create codec to decode stream to disk.
+							codec, err := pkgcodec.New(pkgcodec.TypeFromMediaType(desc.MediaType))
+							if err != nil {
+								return errors.Wrapf(err, "create codec for media type %s", desc.MediaType)
+							}
+
+							fp := getLayerFilePath(desc)
+							if err := codec.Decode(targetDir, fp, reader, desc); err != nil {
+								// Check if another concurrent process (outside our driver) wrote it.
+								if errors.Is(err, pkgcodec.ErrAlreadyUpToDate) {
+									return nil
+								}
+								return errors.Wrap(err, "decode layer")
+							}
+
+							// Registration only occurs on successful completion.
+							fullPath := filepath.Join(targetDir, fp)
+							if _, statErr := os.Stat(fullPath); statErr == nil {
+								p.layerCache.Register(desc.Digest, fullPath)
+							}
+
+							return nil
+						}()
+
+						return nil, pullErr
+					})
+
+					// If the singleflight function returned an error, retry the loop.
+					if sfgErr != nil {
+						logger.WithContext(ctx).WithError(sfgErr).Warnf("layer-aware pull: network fetch failed for %s, retrying", desc.Digest)
+						if err := backoff(); err != nil {
+							return err
+						}
+						continue
+					}
+
+					// If another goroutine successfully downloaded the file via singleflight,
+					// we still need to hardlink it into OUR target directory since we bypassed
+					// our own download logic.
+					fp := getLayerFilePath(desc)
+					destPath := filepath.Join(targetDir, fp)
+
+					// Verify if the file is already at our destination (e.g. if WE were the downloader).
+					if _, statErr := os.Stat(destPath); statErr == nil {
+						return nil // We did the download, or it's already there.
+					}
+
+					// We were a waiting caller. We must hardlink from the newly cached location.
+					sourcePath, found := p.layerCache.Lookup(desc.Digest)
+					if !found {
+						logger.WithContext(ctx).Warnf("layer-aware pull: singleflight completed but digest %s not in cache, retrying", desc.Digest)
+						if err := backoff(); err != nil {
+							return err
+						}
+						continue
+					}
+
+					if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
+						// FIX: wrap for context instead of returning the bare error.
+						return errors.Wrapf(err, "create dir for layer %s", desc.Digest)
+					}
+					if err := os.Link(sourcePath, destPath); err != nil {
+						// Hardlink failed (EXDEV or source was deleted). Remove stale entry and retry fetch.
+						logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: hardlink after singleflight failed for %s, retrying", desc.Digest)
+						p.layerCache.RemoveByPrefix(sourcePath)
+						if err := backoff(); err != nil {
+							return err
+						}
+						continue
+					}
+					p.layerCache.Register(desc.Digest, destPath)
+
+					// Record progress silently for waiting pods.
+					p.hook.LayerCached(desc, manifest)
+
+					return nil
+				}
+				return errors.Errorf("layer-aware pull: failed to pull layer %s after 3 attempts", desc.Digest)
+			})
+		}
+
+		if err := g.Wait(); err != nil {
+			logger.WithContext(ctx).WithError(err).Errorf("layer-aware pull: failed to fetch uncached layers for %s", reference)
+			return err
+		}
+	}
+
+	// Step 5: Save layer metadata for restart rebuild.
+	if len(metadataEntries) > 0 {
+		metadataPath := filepath.Join(filepath.Dir(targetDir), "layer_digests.json")
+		if err := saveLayerMetadata(metadataPath, metadataEntries); err != nil {
+			logger.WithContext(ctx).WithError(err).Warnf("layer-aware pull: failed to save layer metadata")
+		}
+	}
+
+	return nil
+}
+
+// standardPull falls back to the normal full pull path.
+// standardPull falls back to the normal full modctl pull path, wrapping the
+// status hook so successfully pulled layers are still registered in the
+// LayerCache (when one is configured) for future deduplication.
+func (p *layerAwarePuller) standardPull(ctx context.Context, b backend.Backend, reference, targetDir string, plainHTTP bool) error {
+	trackingHook := &layerTrackingHook{
+		inner:     p.hook,
+		cache:     p.layerCache,
+		targetDir: targetDir,
+	}
+	// FIX: layerAwarePull falls back to standardPull precisely when
+	// p.layerCache is nil, so dereferencing it unconditionally for the
+	// semaphore would panic. Without a cache there is no node-level
+	// semaphore to enforce, and layerTrackingHook already tolerates a nil sem.
+	if p.layerCache != nil {
+		trackingHook.sem = p.layerCache.Semaphore()
+	}
+
+	pullConfig := modctlConfig.NewPull()
+	pullConfig.Concurrency = int(p.pullCfg.Concurrency)
+	pullConfig.PlainHTTP = plainHTTP
+	pullConfig.Proxy = p.pullCfg.ProxyURL
+	pullConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint
+	pullConfig.Insecure = true
+	pullConfig.ExtractDir = targetDir
+	pullConfig.ExtractFromRemote = true
+	pullConfig.Hooks = trackingHook
+	pullConfig.ProgressWriter = io.Discard
+	pullConfig.DisableProgress = true
+
+	if err := b.Pull(ctx, reference, pullConfig); err != nil {
+		logger.WithContext(ctx).WithError(err).Errorf("failed to pull model image: %s", reference)
+		return errors.Wrap(err, "pull model image")
+	}
+
+	return nil
+}
+
+// pullWithFiltering handles pulls with weight/file-pattern exclusions.
+func (p *layerAwarePuller) pullWithFiltering(
+	ctx context.Context,
+	b backend.Backend,
+	modelArtifact *ModelArtifact,
+	reference, targetDir string,
+	plainHTTP bool,
+	excludeModelWeights bool,
+	excludeFilePatterns []string,
+) error {
+	patterns, total, err := modelArtifact.GetPatterns(ctx, excludeModelWeights, excludeFilePatterns)
+	if err != nil {
+		return errors.Wrap(err, "get model file patterns without weights")
+	}
+
+	// Nothing matched after exclusions: succeed without fetching.
+	if len(patterns) == 0 {
+		logger.WithContext(ctx).Infof("no files to fetch from model: %s", reference)
+		return nil
+	}
+
+	logger.WithContext(ctx).Infof(
+		"fetching partial files from model: %s, files: %s (%d/%d)",
+		reference, strings.Join(patterns, ", "), len(patterns), total,
+	)
+	// Progress total is the number of matched patterns, matching the units
+	// the fetch below reports per file.
+	p.hook.SetTotal(len(patterns))
+
+	fetchConfig := modctlConfig.NewFetch()
+	fetchConfig.Concurrency = int(p.pullCfg.Concurrency)
+	fetchConfig.PlainHTTP = plainHTTP
+	fetchConfig.Proxy = p.pullCfg.ProxyURL
+	fetchConfig.DragonflyEndpoint = p.pullCfg.DragonflyEndpoint
+	fetchConfig.Insecure = true
+	fetchConfig.Output = targetDir
+	fetchConfig.Hooks = p.hook
+	fetchConfig.ProgressWriter = io.Discard
+	fetchConfig.DisableProgress = true
+	fetchConfig.Patterns = patterns
+
+	if err := b.Fetch(ctx, reference, fetchConfig); err != nil {
+		logger.WithContext(ctx).WithError(err).Errorf("failed to fetch model: %s", reference)
+		return errors.Wrap(err, "fetch model")
+	}
+
+	return nil
+}
+
+// cachedLayerInfo holds info about a layer found in the cache.
+type cachedLayerInfo struct {
+	desc       ocispec.Descriptor // full OCI descriptor of the layer
+	sourcePath string             // existing on-disk file to hardlink from
+	filePath   string             // layer's relative path inside the volume
+}
+
+// layerTrackingHook wraps the status.Hook to additionally register pulled
+// layers in the LayerCache after successful pull, and enforce node-level
+// concurrency via the semaphore.
+type layerTrackingHook struct {
+	inner     *status.Hook        // delegate for all status reporting
+	cache     *LayerCache         // optional; nil disables registration
+	targetDir string              // volume dir layers are extracted into
+	sem       *semaphore.Weighted // optional node-level flow control
+}
+
+func (h *layerTrackingHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) {
+	// Enforce node-level flow control by acquiring a semaphore slot.
+	// We use 1 slot per layer (count-based, not size-based) to keep it simple
+	// and avoid potential issues with very large layers blocking all slots.
+	if h.sem != nil {
+		// Use a background context so we don't fail the pull if the parent
+		// context is cancelled while waiting — the pull itself handles cancellation.
+		_ = h.sem.Acquire(context.Background(), 1)
+	}
+
+	h.inner.BeforePullLayer(desc, manifest)
+}
+
+func (h *layerTrackingHook) AfterPullLayer(desc ocispec.Descriptor, err error) {
+	// Release the semaphore slot.
+	// NOTE(review): assumes the backend strictly pairs Before/AfterPullLayer;
+	// an unpaired After would panic in Release — confirm modctl's hook contract.
+	if h.sem != nil {
+		h.sem.Release(1)
+	}
+
+	h.inner.AfterPullLayer(desc, err)
+
+	// Only register on successful pull.
+	if err != nil || h.cache == nil {
+		return
+	}
+
+	// Determine the file path from the layer descriptor annotations.
+	filePath := getLayerFilePath(desc)
+	if filePath == "" {
+		return
+	}
+
+	fullPath := filepath.Join(h.targetDir, filePath)
+	if _, statErr := os.Stat(fullPath); statErr != nil {
+		return // file doesn't exist, cannot register
+	}
+
+	h.cache.Register(desc.Digest, fullPath)
+}
+
+// getLayerFilePath extracts the file path from a layer descriptor's annotations.
+// It returns "" for layers that carry no filepath annotation.
+func getLayerFilePath(desc ocispec.Descriptor) string {
+	if desc.Annotations == nil {
+		return ""
+	}
+	// Try the current model-spec annotation first.
+	if fp := desc.Annotations[modelspec.AnnotationFilepath]; fp != "" {
+		return fp
+	}
+	// Fall back to legacy annotation.
+	if fp := desc.Annotations[oldModelspec.AnnotationFilepath]; fp != "" {
+		return fp
+	}
+	return ""
+}
diff --git a/pkg/service/service.go b/pkg/service/service.go
index 73b0257..fc1899d 100644
--- a/pkg/service/service.go
+++ b/pkg/service/service.go
@@ -71,7 +71,13 @@ func New(cfg *config.Config) (*Service, error) {
 	if err != nil {
 		return nil, errors.Wrap(err, "create status manager")
 	}
-	worker, err := NewWorker(cfg, sm)
+
+	// Create the node-level layer cache for layer deduplication.
+	lc := NewLayerCache(defaultMaxConcurrentLayers)
+	// Rebuild cache from existing volumes on disk (handles restarts).
+	lc.Rebuild(cfg.Get(), sm)
+
+	worker, err := NewWorkerWithLayerCache(cfg, sm, lc)
 	if err != nil {
 		return nil, errors.Wrap(err, "create worker")
 	}
diff --git a/pkg/service/worker.go b/pkg/service/worker.go
index 8628736..98a895d 100644
--- a/pkg/service/worker.go
+++ b/pkg/service/worker.go
@@ -57,17 +57,32 @@ type Worker struct {
 	inflight   singleflight.Group
 	contextMap *ContextMap
 	kmutex     kmutex.KeyedLocker
+	layerCache *LayerCache
 }
 
 func NewWorker(cfg *config.Config, sm *status.StatusManager) (*Worker, error) {
-	return &Worker{
+	return NewWorkerWithLayerCache(cfg, sm, nil)
+}
+
+func NewWorkerWithLayerCache(cfg *config.Config, sm *status.StatusManager, lc *LayerCache) (*Worker, error) {
+	w := &Worker{
 		cfg:        cfg,
 		newPuller:  NewPuller,
 		sm:         sm,
 		inflight:   singleflight.Group{},
 		contextMap: NewContextMap(),
 		kmutex:     kmutex.New(),
-	}, nil
+		layerCache: lc,
+	}
+
+	// When a LayerCache is provided, use the layer-aware puller.
+	if lc != nil {
+		w.newPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller {
+			return NewLayerAwarePuller(ctx, pullCfg, hook, diskQuotaChecker, lc)
+		}
+	}
+
+	return w, nil
 }
 
 func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volumeName, mountID string) error {
@@ -103,6 +118,12 @@ func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volu
 	statusPath := filepath.Join(volumeDir, "status.json")
 	worker.sm.HookManager.Delete(statusPath)
 
+	// Evict stale layer cache entries for the removed volume.
+	if worker.layerCache != nil {
+		worker.layerCache.RemoveByPrefix(volumeDir)
+		logger.WithContext(ctx).Infof("evicted layer cache entries for %s", volumeDir)
+	}
+
 	return nil, nil
 })
diff --git a/pkg/status/hook.go b/pkg/status/hook.go
index 9ff40ff..f56de75 100644
--- a/pkg/status/hook.go
+++ b/pkg/status/hook.go
@@ -171,6 +171,41 @@ func (h *Hook) AfterPullLayer(desc ocispec.Descriptor, err error) {
 	progress.Span.End()
 }
 
+// LayerCached records a layer as successfully pulled without emitting
+// duplicate network metrics or "pulled layer" logs. This ensures status totals
+// remain accurate for waiting callers without spamming observability.
+//
+// NOTE(review): each call increments the pulled counter even when it
+// overwrites an existing progress entry for the same digest; callers are
+// expected to report each digest at most once per pull — confirm.
+func (h *Hook) LayerCached(desc ocispec.Descriptor, manifest ocispec.Manifest) {
+	h.mutex.Lock()
+	defer h.mutex.Unlock()
+
+	// Derive the displayed path the same way as a real pull: current
+	// model-spec annotation first, then the legacy one.
+	filePath := ""
+	if desc.Annotations != nil {
+		if desc.Annotations[modelspec.AnnotationFilepath] != "" {
+			filePath = fmt.Sprintf("/%s", desc.Annotations[modelspec.AnnotationFilepath])
+		} else if desc.Annotations[oldModelspec.AnnotationFilepath] != "" {
+			filePath = fmt.Sprintf("/%s", desc.Annotations[oldModelspec.AnnotationFilepath])
+		}
+	}
+
+	h.manifest = &manifest
+
+	// The span is opened and closed immediately: the layer completed with no
+	// measurable pull duration.
+	_, span := tracing.Tracer.Start(h.ctx, "PullLayerCached")
+	span.End()
+
+	now := time.Now()
+	h.progress[desc.Digest] = &ProgressItem{
+		Digest:     desc.Digest,
+		Path:       filePath,
+		Size:       desc.Size,
+		StartedAt:  now,
+		FinishedAt: &now,
+		Error:      nil,
+		Span:       span,
+	}
+
+	h.pulled.Add(1)
+}
+
 func (h *Hook) getProgress() Progress {
 	items := []ProgressItem{}
 	for _, item := range h.progress {