diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index e56794c5..a7e927a5 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -38,6 +38,12 @@ import ( func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) error { logrus.Infof("fetch: fetching from %s", target) + // Apply default hooks when caller leaves it unset to avoid nil deref. + if cfg.Hooks == nil { + defaults := config.NewFetch() + cfg.Hooks = defaults.Hooks + } + // fetchByDragonfly is called if a Dragonfly endpoint is specified in the configuration. if cfg.DragonflyEndpoint != "" { logrus.Infof("fetch: using dragonfly for %s", target) @@ -117,11 +123,19 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e } logrus.Debugf("fetch: processing layer %s", layer.Digest) + if cfg.Hooks.BeforePullLayer(layer, manifest) { + logrus.Debugf("fetch: layer %s skipped by hook", layer.Digest) + pb.Complete(layer.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), layer.Digest.String())) + cfg.Hooks.AfterPullLayer(layer, true, nil) + return nil + } if err := tracker.TrackTransfer(func() error { return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer, tracker) }); err != nil { + cfg.Hooks.AfterPullLayer(layer, false, err) return err } + cfg.Hooks.AfterPullLayer(layer, false, nil) logrus.Debugf("fetch: successfully processed layer %s", layer.Digest) return nil diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 76ec9bdb..d387296e 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -157,9 +157,14 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Fetch) error { err := 
retry.Do(func() error { logrus.Debugf("fetch: processing layer %s", desc.Digest) - cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook + if cfg.Hooks.BeforePullLayer(desc, manifest) { + logrus.Debugf("fetch: layer %s skipped by hook", desc.Digest) + pb.Complete(desc.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), desc.Digest.String())) + cfg.Hooks.AfterPullLayer(desc, true, nil) + return nil + } err := downloadAndExtractFetchLayer(ctx, pb, client, ref, desc, authToken, cfg) - cfg.Hooks.AfterPullLayer(desc, err) // Call after hook + cfg.Hooks.AfterPullLayer(desc, false, err) // Call after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) logrus.Error(err) diff --git a/pkg/backend/fetch_hooks_test.go b/pkg/backend/fetch_hooks_test.go new file mode 100644 index 00000000..ae8889cf --- /dev/null +++ b/pkg/backend/fetch_hooks_test.go @@ -0,0 +1,182 @@ +/* + * Copyright 2025 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package backend
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+
+	modelspec "github.com/modelpack/model-spec/specs-go/v1"
+	godigest "github.com/opencontainers/go-digest"
+	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/modelpack/modctl/pkg/config"
+)
+
+// recordingFetchHook tracks hook invocations and can request specific layers
+// to be skipped by digest.
+type recordingFetchHook struct {
+	mu          sync.Mutex // guards skipDigests reads and afterCalls appends
+	skipDigests map[string]bool
+	beforeCount int32 // updated with sync/atomic
+	afterCalls  []afterFetchCall
+}
+
+type afterFetchCall struct {
+	digest  string
+	skipped bool
+	err     error
+}
+
+func (r *recordingFetchHook) BeforePullLayer(desc ocispec.Descriptor, _ ocispec.Manifest) bool {
+	atomic.AddInt32(&r.beforeCount, 1)
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.skipDigests[desc.Digest.String()]
+}
+
+func (r *recordingFetchHook) AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.afterCalls = append(r.afterCalls, afterFetchCall{
+		digest:  desc.Digest.String(),
+		skipped: skipped,
+		err:     err,
+	})
+}
+
+// startFetchTestServer starts an httptest registry stub serving a two-layer manifest for test/model:latest and
+// returns the server, both layer digests, and per-digest counters of blob requests (read them atomically).
+func startFetchTestServer(t *testing.T) (server *httptest.Server, file1Digest, file2Digest godigest.Digest, blobHits map[string]*int32) {
+	t.Helper()
+
+	const (
+		file1Content = "file1 content..."
+		file2Content = "file2 content..."
+	)
+	file1Digest = godigest.FromString(file1Content)
+	file2Digest = godigest.FromString(file2Content)
+
+	hits := map[string]*int32{
+		file1Digest.String(): new(int32),
+		file2Digest.String(): new(int32),
+	}
+
+	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/v2/":
+			w.WriteHeader(http.StatusOK)
+		case "/v2/test/model/manifests/latest":
+			manifest := ocispec.Manifest{
+				Layers: []ocispec.Descriptor{
+					{
+						MediaType: "application/octet-stream.raw",
+						Digest:    file1Digest,
+						Size:      int64(len(file1Content)),
+						Annotations: map[string]string{
+							modelspec.AnnotationFilepath: "file1.txt",
+						},
+					},
+					{
+						MediaType: "application/octet-stream.raw",
+						Digest:    file2Digest,
+						Size:      int64(len(file2Content)),
+						Annotations: map[string]string{
+							modelspec.AnnotationFilepath: "file2.txt",
+						},
+					},
+				},
+			}
+			w.Header().Set("Content-Type", "application/json")
+			assert.NoError(t, json.NewEncoder(w).Encode(manifest)) // assert, not require: handlers run off the test goroutine, where FailNow is not allowed
+		case fmt.Sprintf("/v2/test/model/blobs/%s", file1Digest):
+			atomic.AddInt32(hits[file1Digest.String()], 1)
+			_, err := w.Write([]byte(file1Content))
+			assert.NoError(t, err) // assert, not require (handler goroutine)
+		case fmt.Sprintf("/v2/test/model/blobs/%s", file2Digest):
+			atomic.AddInt32(hits[file2Digest.String()], 1)
+			_, err := w.Write([]byte(file2Content))
+			assert.NoError(t, err) // assert, not require (handler goroutine)
+		default:
+			t.Logf("Unexpected request to %s", r.URL.Path)
+			w.WriteHeader(http.StatusNotFound)
+		}
+	}))
+
+	return server, file1Digest, file2Digest, hits
+}
+
+// TestFetch_HookSkipShortCircuitsLayer verifies that returning skip=true from
+// BeforePullLayer prevents the blob from being downloaded and that
+// AfterPullLayer is still invoked with skipped=true.
+func TestFetch_HookSkipShortCircuitsLayer(t *testing.T) { + tempDir, err := os.MkdirTemp("", "fetch-hook-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + server, file1Digest, file2Digest, hits := startFetchTestServer(t) + defer server.Close() + + hook := &recordingFetchHook{ + skipDigests: map[string]bool{file1Digest.String(): true}, + } + + b := &backend{} + url := strings.TrimPrefix(server.URL, "http://") + cfg := &config.Fetch{ + Output: tempDir, + Patterns: []string{"*.txt"}, + PlainHTTP: true, + Concurrency: 2, + Hooks: hook, + } + + require.NoError(t, b.Fetch(context.Background(), url+"/test/model:latest", cfg)) + + // file1 must NOT have been downloaded; file2 must have been. + assert.Equal(t, int32(0), atomic.LoadInt32(hits[file1Digest.String()]), + "skipped layer should not be fetched from remote") + assert.Equal(t, int32(1), atomic.LoadInt32(hits[file2Digest.String()]), + "non-skipped layer should be fetched once") + + // BeforePullLayer fires for both layers exactly once (no retries on success). + assert.Equal(t, int32(2), atomic.LoadInt32(&hook.beforeCount)) + + // AfterPullLayer must be invoked for both layers, with proper skipped flag. 
+ hook.mu.Lock() + defer hook.mu.Unlock() + require.Len(t, hook.afterCalls, 2) + + byDigest := map[string]afterFetchCall{} + for _, c := range hook.afterCalls { + byDigest[c.digest] = c + } + assert.True(t, byDigest[file1Digest.String()].skipped, "file1 should be marked skipped") + assert.NoError(t, byDigest[file1Digest.String()].err) + assert.False(t, byDigest[file2Digest.String()].skipped, "file2 should not be marked skipped") + assert.NoError(t, byDigest[file2Digest.String()].err) +} diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 777afa43..f96460a4 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -41,6 +41,12 @@ import ( func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) error { logrus.Infof("pull: pulling artifact %s", target) + // Apply default hooks when caller leaves it unset to avoid nil deref. + if cfg.Hooks == nil { + defaults := config.NewPull() + cfg.Hooks = defaults.Hooks + } + // pullByDragonfly is called if a Dragonfly endpoint is specified in the configuration. if cfg.DragonflyEndpoint != "" { logrus.Infof("pull: using dragonfly for %s", target) @@ -118,13 +124,18 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err return retry.Do(func() error { logrus.Debugf("pull: processing layer %s", layer.Digest) - // call the before hook. - cfg.Hooks.BeforePullLayer(layer, manifest) + // call the before hook; allow caller to skip this layer. + if cfg.Hooks.BeforePullLayer(layer, manifest) { + logrus.Debugf("pull: layer %s skipped by hook", layer.Digest) + pb.Complete(layer.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), layer.Digest.String())) + cfg.Hooks.AfterPullLayer(layer, true, nil) + return nil + } err := tracker.TrackTransfer(func() error { return fn(layer) }) // call the after hook. 
- cfg.Hooks.AfterPullLayer(layer, err) + cfg.Hooks.AfterPullLayer(layer, false, err) if err != nil { err = fmt.Errorf("pull: failed to process layer %s: %w", layer.Digest, err) logrus.Error(err) diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index 87c8aff1..b5855d47 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -181,9 +181,14 @@ func buildBlobURL(ref Referencer, plainHTTP bool, digest string) string { func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Pull) error { err := retry.Do(func() error { logrus.Debugf("pull: processing layer %s", desc.Digest) - cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook + if cfg.Hooks.BeforePullLayer(desc, manifest) { + logrus.Debugf("pull: layer %s skipped by hook", desc.Digest) + pb.Complete(desc.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), desc.Digest.String())) + cfg.Hooks.AfterPullLayer(desc, true, nil) + return nil + } err := downloadAndExtractLayer(ctx, pb, client, ref, desc, authToken, cfg) - cfg.Hooks.AfterPullLayer(desc, err) // Call after hook + cfg.Hooks.AfterPullLayer(desc, false, err) // Call after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) logrus.Error(err) diff --git a/pkg/config/pull.go b/pkg/config/pull.go index 1dfb6442..74c39d8f 100644 --- a/pkg/config/pull.go +++ b/pkg/config/pull.go @@ -78,16 +78,29 @@ func (p *Pull) Validate() error { } // PullHooks is the hook events during the pull operation. +// +// Note: every retry attempt re-invokes BeforePullLayer / AfterPullLayer. type PullHooks interface { - // BeforePullLayer will execute before pulling the layer described as desc, will carry the manifest as well. 
- BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) + // BeforePullLayer will execute before pulling the layer described as desc, + // will carry the manifest as well. + // + // If the hook returns skip=true, the backend will treat this layer as + // already satisfied and will NOT actually pull/extract it. The caller is + // responsible for ensuring the corresponding content already exists and + // matches the descriptor's digest. AfterPullLayer will still be invoked + // with skipped=true and a nil error. + BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) (skip bool) - // AfterPullLayer will execute after pulling the layer described as desc, the error will be nil if pulled successfully. - AfterPullLayer(desc ocispec.Descriptor, err error) + // AfterPullLayer will execute after pulling the layer described as desc. + // skipped indicates whether the layer was skipped by BeforePullLayer's + // decision. err will be nil if pulled (or skipped) successfully. + AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) } // emptyPullHook is the empty pull hook implementation with do nothing. type emptyPullHook struct{} -func (emptyPullHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) {} -func (emptyPullHook) AfterPullLayer(desc ocispec.Descriptor, err error) {} +func (emptyPullHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) bool { + return false +} +func (emptyPullHook) AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) {} diff --git a/pkg/config/pull_test.go b/pkg/config/pull_test.go new file mode 100644 index 00000000..9c9a891b --- /dev/null +++ b/pkg/config/pull_test.go @@ -0,0 +1,97 @@ +/* + * Copyright 2024 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "errors" + "testing" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" +) + +func TestNewPull_DefaultsAndValidate(t *testing.T) { + p := NewPull() + assert.NotNil(t, p.Hooks, "default hooks must be non-nil") + assert.NoError(t, p.Validate()) + + // emptyPullHook never asks the backend to skip and accepts results + // without panicking. + desc := ocispec.Descriptor{Digest: "sha256:deadbeef"} + assert.False(t, p.Hooks.BeforePullLayer(desc, ocispec.Manifest{}), + "default hook must not skip layers") + p.Hooks.AfterPullLayer(desc, false, nil) + p.Hooks.AfterPullLayer(desc, true, nil) + p.Hooks.AfterPullLayer(desc, false, errors.New("boom")) +} + +// recordingHook is a PullHooks implementation used to verify the interface +// contract: BeforePullLayer can request a skip, and AfterPullLayer reports +// whether the layer was skipped along with any error. 
+type recordingHook struct { + skipDigests map[string]bool + + beforeCalls []ocispec.Descriptor + afterCalls []afterCall +} + +type afterCall struct { + desc ocispec.Descriptor + skipped bool + err error +} + +func (r *recordingHook) BeforePullLayer(desc ocispec.Descriptor, _ ocispec.Manifest) bool { + r.beforeCalls = append(r.beforeCalls, desc) + return r.skipDigests[desc.Digest.String()] +} + +func (r *recordingHook) AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) { + r.afterCalls = append(r.afterCalls, afterCall{desc: desc, skipped: skipped, err: err}) +} + +func TestPullHooks_InterfaceContract(t *testing.T) { + // Compile-time check that recordingHook satisfies the PullHooks interface. + var _ PullHooks = (*recordingHook)(nil) + + hook := &recordingHook{ + skipDigests: map[string]bool{"sha256:aaa": true}, + } + a := ocispec.Descriptor{Digest: "sha256:aaa"} + b := ocispec.Descriptor{Digest: "sha256:bbb"} + + assert.True(t, hook.BeforePullLayer(a, ocispec.Manifest{}), "aaa should be skipped") + assert.False(t, hook.BeforePullLayer(b, ocispec.Manifest{}), "bbb should not be skipped") + + hook.AfterPullLayer(a, true, nil) + wantErr := errors.New("network down") + hook.AfterPullLayer(b, false, wantErr) + + assert.Equal(t, []ocispec.Descriptor{a, b}, hook.beforeCalls) + assert.Equal(t, []afterCall{ + {desc: a, skipped: true, err: nil}, + {desc: b, skipped: false, err: wantErr}, + }, hook.afterCalls) +} + +func TestNewFetch_DefaultHooks(t *testing.T) { + f := NewFetch() + assert.NotNil(t, f.Hooks) + desc := ocispec.Descriptor{Digest: "sha256:cafe"} + assert.False(t, f.Hooks.BeforePullLayer(desc, ocispec.Manifest{})) + f.Hooks.AfterPullLayer(desc, false, nil) +}