From e2c7e29aa3487597c6443187d00ac714f772fa6c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 23 Jun 2026 23:58:52 +0000 Subject: [PATCH 1/8] Initial plan From 3f67b3820b48b7b690b4f8b49e959327e87b2f7a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Jun 2026 00:15:23 +0000 Subject: [PATCH 2/8] Allow azd extension install to register source directly from a registry location Co-authored-by: JeffreyCA <9157833+JeffreyCA@users.noreply.github.com> --- cli/azd/CHANGELOG.md | 2 + cli/azd/cmd/extension.go | 167 ++++++++++++++- cli/azd/cmd/extension_install_source_test.go | 199 ++++++++++++++++++ cli/azd/cmd/testdata/TestFigSpec.ts | 2 +- .../TestUsage-azd-extension-install.snap | 2 +- .../docs/extensions/extension-framework.md | 8 +- 6 files changed, 376 insertions(+), 4 deletions(-) create mode 100644 cli/azd/cmd/extension_install_source_test.go diff --git a/cli/azd/CHANGELOG.md b/cli/azd/CHANGELOG.md index 2e227d98e57..e7ef3635362 100644 --- a/cli/azd/CHANGELOG.md +++ b/cli/azd/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- [[#8581]](https://github.com/Azure/azure-dev/issues/8581) `azd extension install -s ` now accepts a registry location (URL or file path) directly for `-s/--source`, registering it as a persisted source and installing in one step. azd prompts for the source name and confirms before registering an untrusted URL; under `--no-prompt` it directs you to add the source first with `azd extension source add`. + ### Breaking Changes ### Bugs Fixed diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index 1d57c423994..0ef40ac9248 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -13,6 +13,7 @@ import ( "log" "maps" "net" + "net/url" "os" "path/filepath" "slices" @@ -80,6 +81,12 @@ func extensionActions(root *actions.ActionDescriptor) *actions.ActionDescriptor Short: "Installs specified extensions.", Long: `Installs one or more extensions by id from a registered extension source. +The --source flag also accepts a registry location (a URL or file path) instead +of the name of a registered source. When a location is given, azd registers it as +a new source (prompting for a name, and confirming first for a URL) and then +installs from it. The source is persisted so later upgrade/list/show commands work +against it without re-specifying the location. + You can also pass the path to a self-contained extension bundle (.zip): azd extracts it and installs the bundled extension. Bundled extensions aren't tracked for updates; reinstall from a newer bundle to update.`, @@ -768,7 +775,9 @@ func newExtensionInstallFlags(cmd *cobra.Command, global *internal.GlobalCommand global: global, } - cmd.Flags().StringVarP(&flags.source, "source", "s", "", "The extension source to use for installs") + cmd.Flags().StringVarP(&flags.source, "source", "s", "", + "The extension source to use for installs. Accepts a registered source name "+ + "or a registry location (URL or file path) to register and install from.") cmd.Flags().StringVarP(&flags.version, "version", "v", "", "The version of the extension to install") cmd.Flags(). BoolVarP(&flags.force, "force", "f", false, "Force installation, including downgrades and reinstalls") @@ -850,6 +859,13 @@ func (a *extensionInstallAction) Run(ctx context.Context) (*actions.ActionResult } } + // If -s/--source points directly at a registry location (URL or file path) + // rather than an already-registered source name, register the source first so + // the install loop below can resolve extensions from it. + if err := a.resolveSourceLocation(ctx); err != nil { + return nil, err + } + azdVersion := currentAzdSemver() for index, extensionId := range extensionIds { @@ -1364,6 +1380,155 @@ func normalizeBundleSourceName(name string) string { return strings.Trim(sb.String(), "-") } +// resolveSourceLocation handles the case where -s/--source points directly at a +// registry location (URL or file path) rather than the name of an +// already-registered source. When a location is detected, it confirms before +// registering an untrusted URL, prompts for a source name, persists the source, +// and rewrites a.flags.source to the registered name so the install loop +// resolves extensions from it. Registered source names and values that do not +// look like a location are left unchanged. +func (a *extensionInstallAction) resolveSourceLocation(ctx context.Context) error { + if a.flags.source == "" { + return nil + } + + // If the value already names a registered source, keep current behavior. + _, err := a.sourceManager.Get(ctx, a.flags.source) + if err == nil { + return nil + } + if !errors.Is(err, extensions.ErrSourceNotFound) { + return fmt.Errorf("failed to resolve extension source %q: %w", a.flags.source, err) + } + + // Not a registered source — detect whether it is a registry location. + location := a.flags.source + kind, ok := inferSourceKind(location) + if !ok { + // Not a location; leave the value untouched so existing resolution and + // error messaging applies. + return nil + } + + // Registering a source is interactive (naming + trust confirmation), so in + // --no-prompt mode direct the user to add the source explicitly first. + if a.flags.global.NoPrompt { + return &internal.ErrorWithSuggestion{ + Err: fmt.Errorf( + "cannot register a new extension source from %q while --no-prompt is set", location), + Suggestion: fmt.Sprintf( + "Add the source first with %s, then install with %s.", + output.WithHighLightFormat( + "azd extension source add -n -t %s -l %q", kind, location), + output.WithHighLightFormat("azd extension install -s "), + ), + } + } + + // Confirm before registering a URL source, which may be untrusted. + if kind == extensions.SourceKindUrl { + a.console.Message(ctx, "") + confirm, err := a.console.Confirm(ctx, input.ConsoleOptions{ + Message: fmt.Sprintf( + "Register and install from the extension source at %s?", + output.WithHighLightFormat(location)), + DefaultValue: false, + }) + if err != nil { + return err + } + if !confirm { + return &internal.ErrorWithSuggestion{ + Err: errors.New("extension source registration declined"), + Suggestion: "Re-run and confirm to register the source, " + + "or add it explicitly with 'azd extension source add'.", + } + } + } + + // Prompt for a source name with a sensible default derived from the location. + defaultName := defaultSourceName(location) + sourceName, err := a.console.Prompt(ctx, input.ConsoleOptions{ + Message: "Enter a name for this extension source:", + DefaultValue: defaultName, + }) + if err != nil { + return err + } + sourceName = strings.TrimSpace(sourceName) + if sourceName == "" { + sourceName = defaultName + } + + sourceConfig := &extensions.SourceConfig{ + Name: sourceName, + Type: kind, + Location: location, + } + + spinnerMessage := fmt.Sprintf("Registering extension source %s", output.WithHighLightFormat(sourceName)) + a.console.ShowSpinner(ctx, spinnerMessage, input.Step) + + // Validate the source by hydrating it before persisting. + if _, err := a.sourceManager.CreateSource(ctx, sourceConfig); err != nil { + a.console.StopSpinner(ctx, spinnerMessage, input.StepFailed) + return fmt.Errorf("failed to validate extension source: %w", err) + } + + if err := a.sourceManager.Add(ctx, sourceName, sourceConfig); err != nil { + a.console.StopSpinner(ctx, spinnerMessage, input.StepFailed) + return fmt.Errorf("failed to add extension source: %w", err) + } + a.console.StopSpinner(ctx, spinnerMessage, input.StepDone) + + // Refresh manager caches so the new source is visible and the cached config + // snapshot is not clobbered when the install below saves. + a.extensionManager.InvalidateSourceCache() + if err := a.extensionManager.ReloadUserConfig(); err != nil { + return err + } + + // Add normalizes the persisted name; resolve extensions against that name. + a.flags.source = sourceConfig.Name + return nil +} + +// inferSourceKind infers the extension source kind from a registry location, +// matching the URL-vs-file heuristics used by 'azd extension source validate'. +// It reports false when the value does not look like a location and is more +// likely the name of a source. +func inferSourceKind(location string) (extensions.SourceKind, bool) { + if strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://") { + return extensions.SourceKindUrl, true + } + if info, err := os.Stat(location); err == nil && !info.IsDir() { + return extensions.SourceKindFile, true + } + if strings.ContainsAny(location, `/\`) || strings.EqualFold(filepath.Ext(location), ".json") { + return extensions.SourceKindFile, true + } + return "", false +} + +// defaultSourceName derives a config-safe default extension source name from a +// registry location: the host for URLs, otherwise the file name without its +// extension. It falls back to "custom" when nothing usable can be derived. +func defaultSourceName(location string) string { + base := "" + if u, err := url.Parse(location); err == nil && u.Host != "" { + base = u.Host + } else { + base = filepath.Base(location) + base = strings.TrimSuffix(base, filepath.Ext(base)) + } + + name := normalizeBundleSourceName(base) + if name == "" { + name = "custom" + } + return name +} + // azd extension uninstall type extensionUninstallFlags struct { all bool diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go new file mode 100644 index 00000000000..369c4abad00 --- /dev/null +++ b/cli/azd/cmd/extension_install_source_test.go @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockinput" + "github.com/stretchr/testify/require" +) + +func TestInferSourceKind(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + existing := filepath.Join(dir, "registry.json") + require.NoError(t, os.WriteFile(existing, []byte("{}"), 0600)) + + t.Run("HttpUrl", func(t *testing.T) { + kind, ok := inferSourceKind("http://example.com/registry.json") + require.True(t, ok) + require.Equal(t, extensions.SourceKindUrl, kind) + }) + + t.Run("HttpsUrl", func(t *testing.T) { + kind, ok := inferSourceKind("https://example.com/registry.json") + require.True(t, ok) + require.Equal(t, extensions.SourceKindUrl, kind) + }) + + t.Run("ExistingFile", func(t *testing.T) { + kind, ok := inferSourceKind(existing) + require.True(t, ok) + require.Equal(t, extensions.SourceKindFile, kind) + }) + + t.Run("JsonExtension", func(t *testing.T) { + kind, ok := inferSourceKind("missing-registry.json") + require.True(t, ok) + require.Equal(t, extensions.SourceKindFile, kind) + }) + + t.Run("PathSeparator", func(t *testing.T) { + kind, ok := inferSourceKind("./some/path") + require.True(t, ok) + require.Equal(t, extensions.SourceKindFile, kind) + }) + + t.Run("PlainNameIsNotLocation", func(t *testing.T) { + _, ok := inferSourceKind("my-source") + require.False(t, ok) + }) +} + +func TestDefaultSourceName(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "https://example.com/registry.json": "example-com", + "https://link/to/registry.json": "link", + "/path/to/registry.json": "registry", + "./custom.json": "custom", + } + + for location, expected := range cases { + require.Equal(t, expected, defaultSourceName(location), "location %q", location) + } +} + +func TestResolveSourceLocation_ExistingSourceUnchanged(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + require.NoError(t, action.sourceManager.Add(context.Background(), "my-source", &extensions.SourceConfig{ + Name: "my-source", + Type: extensions.SourceKindUrl, + Location: "https://example.com/registry.json", + })) + + action.flags.source = "my-source" + require.NoError(t, action.resolveSourceLocation(context.Background())) + require.Equal(t, "my-source", action.flags.source) +} + +func TestResolveSourceLocation_PlainNameUnchanged(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + action.flags.source = "not-a-location" + require.NoError(t, action.resolveSourceLocation(context.Background())) + require.Equal(t, "not-a-location", action.flags.source) +} + +func TestResolveSourceLocation_NoPromptDirectsToSourceAdd(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + action.flags.global.NoPrompt = true + action.flags.source = "https://example.com/registry.json" + + err := action.resolveSourceLocation(context.Background()) + require.Error(t, err) + require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) +} + +func TestResolveSourceLocation_FileRegistersSource(t *testing.T) { + t.Parallel() + + registryPath := writeRegistryFile(t) + + console := mockinput.NewMockConsole() + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("local-dev") + + action, _ := newBundleInstallTestAction(t) + action.console = console + action.flags.source = registryPath + + require.NoError(t, action.resolveSourceLocation(context.Background())) + require.Equal(t, "local-dev", action.flags.source) + + src, err := action.sourceManager.Get(context.Background(), "local-dev") + require.NoError(t, err) + require.Equal(t, extensions.SourceKindFile, src.Type) + require.Equal(t, registryPath, src.Location) +} + +func TestResolveSourceLocation_FileUsesDefaultNameWhenBlank(t *testing.T) { + t.Parallel() + + registryPath := writeRegistryFile(t) + + console := mockinput.NewMockConsole() + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("") + + action, _ := newBundleInstallTestAction(t) + action.console = console + action.flags.source = registryPath + + require.NoError(t, action.resolveSourceLocation(context.Background())) + require.Equal(t, "registry", action.flags.source) + + _, err := action.sourceManager.Get(context.Background(), "registry") + require.NoError(t, err) +} + +func TestResolveSourceLocation_UrlDeclinedReturnsError(t *testing.T) { + t.Parallel() + + console := mockinput.NewMockConsole() + console.WhenConfirm(func(input.ConsoleOptions) bool { return true }).Respond(false) + + action, _ := newBundleInstallTestAction(t) + action.console = console + action.flags.source = "https://example.com/registry.json" + + err := action.resolveSourceLocation(context.Background()) + require.Error(t, err) + require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) + + // No source should have been registered. + _, getErr := action.sourceManager.Get(context.Background(), "example-com") + require.ErrorIs(t, getErr, extensions.ErrSourceNotFound) +} + +// writeRegistryFile writes a minimal valid registry.json to a temp dir and +// returns its absolute path. +func writeRegistryFile(t *testing.T) string { + t.Helper() + + dir := t.TempDir() + registry := &extensions.Registry{ + SchemaVersion: extensions.CurrentRegistrySchemaVersion, + Extensions: []*extensions.ExtensionMetadata{ + { + Id: "test.ext", + DisplayName: "Test Extension", + Versions: []extensions.ExtensionVersion{ + {Version: "1.0.0", Artifacts: map[string]extensions.ExtensionArtifact{ + "linux/amd64": {URL: "artifacts/ext.tar.gz"}, + }}, + }, + }, + }, + } + data, err := json.Marshal(registry) + require.NoError(t, err) + + registryPath := filepath.Join(dir, "registry.json") + require.NoError(t, os.WriteFile(registryPath, data, 0600)) + return registryPath +} diff --git a/cli/azd/cmd/testdata/TestFigSpec.ts b/cli/azd/cmd/testdata/TestFigSpec.ts index 5dccd531413..81d00513966 100644 --- a/cli/azd/cmd/testdata/TestFigSpec.ts +++ b/cli/azd/cmd/testdata/TestFigSpec.ts @@ -5668,7 +5668,7 @@ const completionSpec: Fig.Spec = { }, { name: ['--source', '-s'], - description: 'The extension source to use for installs', + description: 'The extension source to use for installs. Accepts a registered source name or a registry location (URL or file path) to register and install from.', args: [ { name: 'source', diff --git a/cli/azd/cmd/testdata/TestUsage-azd-extension-install.snap b/cli/azd/cmd/testdata/TestUsage-azd-extension-install.snap index bebe6f104c8..85cd72088c0 100644 --- a/cli/azd/cmd/testdata/TestUsage-azd-extension-install.snap +++ b/cli/azd/cmd/testdata/TestUsage-azd-extension-install.snap @@ -6,7 +6,7 @@ Usage Flags -f, --force : Force installation, including downgrades and reinstalls - -s, --source string : The extension source to use for installs + -s, --source string : The extension source to use for installs. Accepts a registered source name or a registry location (URL or file path) to register and install from. -v, --version string : The version of the extension to install Global Flags diff --git a/cli/azd/docs/extensions/extension-framework.md b/cli/azd/docs/extensions/extension-framework.md index 527a591eddf..36c457ff5db 100644 --- a/cli/azd/docs/extensions/extension-framework.md +++ b/cli/azd/docs/extensions/extension-framework.md @@ -132,7 +132,13 @@ Shows detailed information for a specific extension, including description, tags Installs one or more extensions from any configured extension source. - `-v, --version` Specifies the exact version to install. -- `-s, --source` Specifies the extension source used for installations. +- `-s, --source` Specifies the extension source used for installations. In addition to the name of a registered source, this accepts a registry location (a URL or file path). When a location is provided, `azd` registers it as a new persisted source — prompting for a source name, and confirming first when the location is a URL — and then installs from it. This lets you install in one step without a separate `azd extension source add`: + + ```bash + azd extension install -s https://link/to/registry.json + ``` + + Under `--no-prompt`, registering a source from a location is not allowed; add the source first with `azd extension source add`. #### `azd extension uninstall [flags]` From 1b594747dc0f62f54d286e54dc140ea7f9f489d0 Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Thu, 25 Jun 2026 17:57:43 +0000 Subject: [PATCH 3/8] Add direct registry source support to extension commands --- cli/azd/cmd/extension.go | 261 +++++++++++++----- cli/azd/cmd/extension_install_source_test.go | 169 +++++++++++- cli/azd/cmd/extension_source_location_test.go | 257 +++++++++++++++++ cli/azd/cmd/extension_test.go | 2 + cli/azd/cmd/extension_upgrade_test.go | 16 +- cli/azd/cmd/testdata/TestFigSpec.ts | 8 +- .../TestUsage-azd-extension-list.snap | 2 +- .../TestUsage-azd-extension-show.snap | 2 +- .../TestUsage-azd-extension-upgrade.snap | 2 +- .../docs/extensions/extension-framework.md | 10 +- cli/azd/pkg/extensions/manager.go | 20 +- 11 files changed, 654 insertions(+), 95 deletions(-) create mode 100644 cli/azd/cmd/extension_source_location_test.go diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index 0ef40ac9248..6a643752b79 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -55,6 +55,11 @@ func extensionActions(root *actions.ActionDescriptor) *actions.ActionDescriptor Command: &cobra.Command{ Use: "list [--installed]", Short: "List available extensions.", + Long: `List available extensions from registered extension sources. + +The --source flag accepts a registered source name or registry location (URL or +file path). Locations are queried read-only and are not registered. Extensions +from an unregistered location show the location itself in the SOURCE column.`, }, OutputFormats: []output.Format{output.JsonFormat, output.TableFormat}, DefaultFormat: output.TableFormat, @@ -67,6 +72,10 @@ func extensionActions(root *actions.ActionDescriptor) *actions.ActionDescriptor Command: &cobra.Command{ Use: "show ", Short: "Show details for a specific extension.", + Long: `Show details for a specific extension from a registered extension source. + +The --source flag accepts a registered source name or registry location (URL or +file path). Locations are queried read-only and are not registered.`, }, OutputFormats: []output.Format{output.JsonFormat, output.NoneFormat}, DefaultFormat: output.NoneFormat, @@ -81,11 +90,10 @@ func extensionActions(root *actions.ActionDescriptor) *actions.ActionDescriptor Short: "Installs specified extensions.", Long: `Installs one or more extensions by id from a registered extension source. -The --source flag also accepts a registry location (a URL or file path) instead -of the name of a registered source. When a location is given, azd registers it as -a new source (prompting for a name, and confirming first for a URL) and then -installs from it. The source is persisted so later upgrade/list/show commands work -against it without re-specifying the location. +The --source flag also accepts a registry location (URL or file path). When a +location is given, azd registers it as a source (prompting for a name, and +confirming first for a URL) and then installs from it. If the location is already +registered, azd reuses that source. You can also pass the path to a self-contained extension bundle (.zip): azd extracts it and installs the bundled extension. Bundled extensions aren't @@ -117,9 +125,12 @@ source is unavailable, falls back to the main (azd) registry. Extensions that were installed from a non-main registry (e.g., dev) are automatically promoted to the main registry when a newer version is available there. -Use --source to explicitly override the registry source for the upgrade. Use ---all to upgrade all installed extensions in a single batch; failures in one -extension do not prevent the remaining extensions from being upgraded. +Use --source to override the registry source for the upgrade. It accepts a +registered source name or registry location (URL or file path); locations are +registered first and the upgraded extension's stored source is updated. Because +registration is interactive, locations are rejected under --no-prompt. Use --all +to upgrade all installed extensions in a single batch; failures in one extension +do not prevent the remaining extensions from being upgraded. When upgrading an extension that has dependencies, any installed dependencies are automatically upgraded too, to the highest version @@ -201,7 +212,8 @@ type extensionListFlags struct { func newExtensionListFlags(cmd *cobra.Command) *extensionListFlags { flags := &extensionListFlags{} cmd.Flags().BoolVar(&flags.installed, "installed", false, "List installed extensions") - cmd.Flags().StringVar(&flags.source, "source", "", "Filter extensions by source") + cmd.Flags().StringVarP(&flags.source, "source", "s", "", + "Filter extensions by registered source name or registry location (URL or file path).") cmd.Flags().StringSliceVar(&flags.tags, "tags", nil, "Filter extensions by tags") return flags @@ -255,7 +267,14 @@ func (a *extensionListAction) Run(ctx context.Context) (*actions.ActionResult, e Tags: a.flags.tags, } - if options.Source != "" { + sourceConfig, err := resolveReadOnlySourceFilter(ctx, a.sourceManager, a.flags.source) + if err != nil { + return nil, err + } + if sourceConfig != nil { + options.SourceConfig = sourceConfig + options.Source = "" + } else if options.Source != "" { if _, err := a.sourceManager.Get(ctx, options.Source); err != nil { return nil, fmt.Errorf("extension source '%s' not found: %w", options.Source, err) } @@ -527,7 +546,8 @@ func newExtensionShowFlags(cmd *cobra.Command, global *internal.GlobalCommandOpt flags := &extensionShowFlags{ global: global, } - cmd.Flags().StringVarP(&flags.source, "source", "s", "", "The extension source to use.") + cmd.Flags().StringVarP(&flags.source, "source", "s", "", + "The registered source name or registry location (URL or file path) to use.") return flags } @@ -537,6 +557,7 @@ type extensionShowAction struct { console input.Console formatter output.Formatter writer io.Writer + sourceManager *extensions.SourceManager extensionManager *extensions.Manager } @@ -546,6 +567,7 @@ func newExtensionShowAction( console input.Console, formatter output.Formatter, writer io.Writer, + sourceManager *extensions.SourceManager, extensionManager *extensions.Manager, ) actions.Action { return &extensionShowAction{ @@ -554,6 +576,7 @@ func newExtensionShowAction( console: console, formatter: formatter, writer: writer, + sourceManager: sourceManager, extensionManager: extensionManager, } } @@ -709,6 +732,15 @@ func (a *extensionShowAction) Run(ctx context.Context) (*actions.ActionResult, e Id: extensionId, } + sourceConfig, err := resolveReadOnlySourceFilter(ctx, a.sourceManager, a.flags.source) + if err != nil { + return nil, err + } + if sourceConfig != nil { + filterOptions.SourceConfig = sourceConfig + filterOptions.Source = "" + } + extensionMatches, err := a.extensionManager.FindExtensions(ctx, filterOptions) if err != nil { return nil, fmt.Errorf("failed to find extension: %w", err) @@ -1380,65 +1412,88 @@ func normalizeBundleSourceName(name string) string { return strings.Trim(sb.String(), "-") } -// resolveSourceLocation handles the case where -s/--source points directly at a -// registry location (URL or file path) rather than the name of an -// already-registered source. When a location is detected, it confirms before -// registering an untrusted URL, prompts for a source name, persists the source, -// and rewrites a.flags.source to the registered name so the install loop -// resolves extensions from it. Registered source names and values that do not -// look like a location are left unchanged. +// resolveSourceLocation registers a direct --source location and rewrites it to +// the registered source name. func (a *extensionInstallAction) resolveSourceLocation(ctx context.Context) error { - if a.flags.source == "" { - return nil + resolved, err := registerSourceFromLocation( + ctx, a.console, a.sourceManager, a.extensionManager, a.flags.source, a.flags.global.NoPrompt) + if err != nil { + return err + } + a.flags.source = resolved + return nil +} + +// registerSourceFromLocation persists a direct --source location for mutating +// commands, reusing an existing source with the same location when possible. +// Registered names and non-location values are returned unchanged. +func registerSourceFromLocation( + ctx context.Context, + console input.Console, + sourceManager *extensions.SourceManager, + extensionManager *extensions.Manager, + source string, + noPrompt bool, +) (string, error) { + if source == "" { + return source, nil } // If the value already names a registered source, keep current behavior. - _, err := a.sourceManager.Get(ctx, a.flags.source) + _, err := sourceManager.Get(ctx, source) if err == nil { - return nil + return source, nil } if !errors.Is(err, extensions.ErrSourceNotFound) { - return fmt.Errorf("failed to resolve extension source %q: %w", a.flags.source, err) + return "", fmt.Errorf("failed to resolve extension source %q: %w", source, err) } - // Not a registered source — detect whether it is a registry location. - location := a.flags.source + location := source kind, ok := inferSourceKind(location) if !ok { - // Not a location; leave the value untouched so existing resolution and - // error messaging applies. - return nil + return source, nil } - // Registering a source is interactive (naming + trust confirmation), so in - // --no-prompt mode direct the user to add the source explicitly first. - if a.flags.global.NoPrompt { - return &internal.ErrorWithSuggestion{ + if kind == extensions.SourceKindFile { + if abs, err := filepath.Abs(location); err == nil { + location = abs + } + } + + existing, err := findSourceByLocation(ctx, sourceManager, kind, location) + if err != nil { + return "", err + } + if existing != nil { + return existing.Name, nil + } + + if noPrompt { + return "", &internal.ErrorWithSuggestion{ Err: fmt.Errorf( "cannot register a new extension source from %q while --no-prompt is set", location), Suggestion: fmt.Sprintf( - "Add the source first with %s, then install with %s.", + "Add the source first with %s, then re-run with %s.", output.WithHighLightFormat( "azd extension source add -n -t %s -l %q", kind, location), - output.WithHighLightFormat("azd extension install -s "), + output.WithHighLightFormat("-s "), ), } } - // Confirm before registering a URL source, which may be untrusted. if kind == extensions.SourceKindUrl { - a.console.Message(ctx, "") - confirm, err := a.console.Confirm(ctx, input.ConsoleOptions{ + console.Message(ctx, "") + confirm, err := console.Confirm(ctx, input.ConsoleOptions{ Message: fmt.Sprintf( - "Register and install from the extension source at %s?", + "Register and use the extension source at %s?", output.WithHighLightFormat(location)), DefaultValue: false, }) if err != nil { - return err + return "", err } if !confirm { - return &internal.ErrorWithSuggestion{ + return "", &internal.ErrorWithSuggestion{ Err: errors.New("extension source registration declined"), Suggestion: "Re-run and confirm to register the source, " + "or add it explicitly with 'azd extension source add'.", @@ -1446,14 +1501,13 @@ func (a *extensionInstallAction) resolveSourceLocation(ctx context.Context) erro } } - // Prompt for a source name with a sensible default derived from the location. defaultName := defaultSourceName(location) - sourceName, err := a.console.Prompt(ctx, input.ConsoleOptions{ + sourceName, err := console.Prompt(ctx, input.ConsoleOptions{ Message: "Enter a name for this extension source:", DefaultValue: defaultName, }) if err != nil { - return err + return "", err } sourceName = strings.TrimSpace(sourceName) if sourceName == "" { @@ -1467,30 +1521,101 @@ func (a *extensionInstallAction) resolveSourceLocation(ctx context.Context) erro } spinnerMessage := fmt.Sprintf("Registering extension source %s", output.WithHighLightFormat(sourceName)) - a.console.ShowSpinner(ctx, spinnerMessage, input.Step) + console.ShowSpinner(ctx, spinnerMessage, input.Step) - // Validate the source by hydrating it before persisting. - if _, err := a.sourceManager.CreateSource(ctx, sourceConfig); err != nil { - a.console.StopSpinner(ctx, spinnerMessage, input.StepFailed) - return fmt.Errorf("failed to validate extension source: %w", err) + if _, err := sourceManager.CreateSource(ctx, sourceConfig); err != nil { + console.StopSpinner(ctx, spinnerMessage, input.StepFailed) + return "", fmt.Errorf("failed to validate extension source: %w", err) } - if err := a.sourceManager.Add(ctx, sourceName, sourceConfig); err != nil { - a.console.StopSpinner(ctx, spinnerMessage, input.StepFailed) - return fmt.Errorf("failed to add extension source: %w", err) + if err := sourceManager.Add(ctx, sourceName, sourceConfig); err != nil { + console.StopSpinner(ctx, spinnerMessage, input.StepFailed) + return "", fmt.Errorf("failed to add extension source: %w", err) } - a.console.StopSpinner(ctx, spinnerMessage, input.StepDone) + console.StopSpinner(ctx, spinnerMessage, input.StepDone) - // Refresh manager caches so the new source is visible and the cached config - // snapshot is not clobbered when the install below saves. - a.extensionManager.InvalidateSourceCache() - if err := a.extensionManager.ReloadUserConfig(); err != nil { - return err + extensionManager.InvalidateSourceCache() + if err := extensionManager.ReloadUserConfig(); err != nil { + return "", err } - // Add normalizes the persisted name; resolve extensions against that name. - a.flags.source = sourceConfig.Name - return nil + return sourceConfig.Name, nil +} + +// findSourceByLocation returns the registered source for location, if any. +func findSourceByLocation( + ctx context.Context, + sourceManager *extensions.SourceManager, + kind extensions.SourceKind, + location string, +) (*extensions.SourceConfig, error) { + sources, err := sourceManager.List(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list extension sources: %w", err) + } + + for _, source := range sources { + if source.Type == kind && locationsEqual(kind, source.Location, location) { + return source, nil + } + } + return nil, nil +} + +// locationsEqual reports whether two locations refer to the same source. +func locationsEqual(kind extensions.SourceKind, a, b string) bool { + switch kind { + case extensions.SourceKindUrl: + return strings.EqualFold(a, b) + case extensions.SourceKindFile: + return absPath(a) == absPath(b) + default: + return a == b + } +} + +// absPath returns path as an absolute path when possible. +func absPath(path string) string { + if abs, err := filepath.Abs(path); err == nil { + return abs + } + return path +} + +// resolveReadOnlySourceFilter returns a temporary source config for read-only +// commands when --source is a registry location. Registered names return nil. +func resolveReadOnlySourceFilter( + ctx context.Context, + sourceManager *extensions.SourceManager, + source string, +) (*extensions.SourceConfig, error) { + if source == "" { + return nil, nil + } + + _, err := sourceManager.Get(ctx, source) + if err == nil { + return nil, nil + } + if !errors.Is(err, extensions.ErrSourceNotFound) { + return nil, fmt.Errorf("failed to resolve extension source %q: %w", source, err) + } + + kind, ok := inferSourceKind(source) + if !ok { + return nil, nil + } + + location := source + if kind == extensions.SourceKindFile { + location = absPath(location) + } + + return &extensions.SourceConfig{ + Name: location, + Type: kind, + Location: location, + }, nil } // inferSourceKind infers the extension source kind from a registry location, @@ -1498,7 +1623,8 @@ func (a *extensionInstallAction) resolveSourceLocation(ctx context.Context) erro // It reports false when the value does not look like a location and is more // likely the name of a source. func inferSourceKind(location string) (extensions.SourceKind, bool) { - if strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://") { + lower := strings.ToLower(location) + if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") { return extensions.SourceKindUrl, true } if info, err := os.Stat(location); err == nil && !info.IsDir() { @@ -1648,7 +1774,8 @@ func newExtensionUpgradeFlags(cmd *cobra.Command, global *internal.GlobalCommand global: global, } cmd.Flags().StringVarP(&flags.version, "version", "v", "", "The version of the extension to upgrade to") - cmd.Flags().StringVarP(&flags.source, "source", "s", "", "The extension source to use for upgrades") + cmd.Flags().StringVarP(&flags.source, "source", "s", "", + "The registered source name or registry location (URL or file path) to use for upgrades.") cmd.Flags().BoolVar(&flags.all, "all", false, "Upgrade all installed extensions") cmd.Flags().BoolVar(&flags.noDependencyUpgrades, "no-dependency-upgrades", false, "Do not upgrade dependencies when upgrading an extension that has dependencies") @@ -1663,6 +1790,7 @@ type extensionUpgradeAction struct { formatter output.Formatter writer io.Writer console input.Console + sourceManager *extensions.SourceManager extensionManager *extensions.Manager } @@ -1672,6 +1800,7 @@ func newExtensionUpgradeAction( formatter output.Formatter, writer io.Writer, console input.Console, + sourceManager *extensions.SourceManager, extensionManager *extensions.Manager, ) actions.Action { return &extensionUpgradeAction{ @@ -1680,6 +1809,7 @@ func newExtensionUpgradeAction( formatter: formatter, writer: writer, console: console, + sourceManager: sourceManager, extensionManager: extensionManager, } } @@ -1733,6 +1863,13 @@ func (a *extensionUpgradeAction) Run( }) } + resolvedSource, err := registerSourceFromLocation( + ctx, a.console, a.sourceManager, a.extensionManager, a.flags.source, a.flags.global.NoPrompt) + if err != nil { + return nil, err + } + a.flags.source = resolvedSource + azdVersion := currentAzdSemver() extensionIds := a.args diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go index 369c4abad00..5186252f9d5 100644 --- a/cli/azd/cmd/extension_install_source_test.go +++ b/cli/azd/cmd/extension_install_source_test.go @@ -4,15 +4,18 @@ package cmd import ( - "context" "encoding/json" + "net/http" "os" "path/filepath" "testing" "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/pkg/config" "github.com/azure/azure-dev/cli/azd/pkg/extensions" "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/lazy" + "github.com/azure/azure-dev/cli/azd/test/mocks" "github.com/azure/azure-dev/cli/azd/test/mocks/mockinput" "github.com/stretchr/testify/require" ) @@ -36,6 +39,18 @@ func TestInferSourceKind(t *testing.T) { require.Equal(t, extensions.SourceKindUrl, kind) }) + t.Run("MixedCaseHttpsUrl", func(t *testing.T) { + kind, ok := inferSourceKind("HTTPS://Example.com/registry.json") + require.True(t, ok) + require.Equal(t, extensions.SourceKindUrl, kind) + }) + + t.Run("UpperCaseHttpUrl", func(t *testing.T) { + kind, ok := inferSourceKind("HTTP://example.com/registry.json") + require.True(t, ok) + require.Equal(t, extensions.SourceKindUrl, kind) + }) + t.Run("ExistingFile", func(t *testing.T) { kind, ok := inferSourceKind(existing) require.True(t, ok) @@ -79,14 +94,14 @@ func TestResolveSourceLocation_ExistingSourceUnchanged(t *testing.T) { t.Parallel() action, _ := newBundleInstallTestAction(t) - require.NoError(t, action.sourceManager.Add(context.Background(), "my-source", &extensions.SourceConfig{ + require.NoError(t, action.sourceManager.Add(t.Context(), "my-source", &extensions.SourceConfig{ Name: "my-source", Type: extensions.SourceKindUrl, Location: "https://example.com/registry.json", })) action.flags.source = "my-source" - require.NoError(t, action.resolveSourceLocation(context.Background())) + require.NoError(t, action.resolveSourceLocation(t.Context())) require.Equal(t, "my-source", action.flags.source) } @@ -95,7 +110,7 @@ func TestResolveSourceLocation_PlainNameUnchanged(t *testing.T) { action, _ := newBundleInstallTestAction(t) action.flags.source = "not-a-location" - require.NoError(t, action.resolveSourceLocation(context.Background())) + require.NoError(t, action.resolveSourceLocation(t.Context())) require.Equal(t, "not-a-location", action.flags.source) } @@ -106,7 +121,7 @@ func TestResolveSourceLocation_NoPromptDirectsToSourceAdd(t *testing.T) { action.flags.global.NoPrompt = true action.flags.source = "https://example.com/registry.json" - err := action.resolveSourceLocation(context.Background()) + err := action.resolveSourceLocation(t.Context()) require.Error(t, err) require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) } @@ -123,10 +138,10 @@ func TestResolveSourceLocation_FileRegistersSource(t *testing.T) { action.console = console action.flags.source = registryPath - require.NoError(t, action.resolveSourceLocation(context.Background())) + require.NoError(t, action.resolveSourceLocation(t.Context())) require.Equal(t, "local-dev", action.flags.source) - src, err := action.sourceManager.Get(context.Background(), "local-dev") + src, err := action.sourceManager.Get(t.Context(), "local-dev") require.NoError(t, err) require.Equal(t, extensions.SourceKindFile, src.Type) require.Equal(t, registryPath, src.Location) @@ -144,10 +159,10 @@ func TestResolveSourceLocation_FileUsesDefaultNameWhenBlank(t *testing.T) { action.console = console action.flags.source = registryPath - require.NoError(t, action.resolveSourceLocation(context.Background())) + require.NoError(t, action.resolveSourceLocation(t.Context())) require.Equal(t, "registry", action.flags.source) - _, err := action.sourceManager.Get(context.Background(), "registry") + _, err := action.sourceManager.Get(t.Context(), "registry") require.NoError(t, err) } @@ -161,17 +176,143 @@ func TestResolveSourceLocation_UrlDeclinedReturnsError(t *testing.T) { action.console = console action.flags.source = "https://example.com/registry.json" - err := action.resolveSourceLocation(context.Background()) + err := action.resolveSourceLocation(t.Context()) require.Error(t, err) require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) - // No source should have been registered. - _, getErr := action.sourceManager.Get(context.Background(), "example-com") + _, getErr := action.sourceManager.Get(t.Context(), "example-com") require.ErrorIs(t, getErr, extensions.ErrSourceNotFound) } -// writeRegistryFile writes a minimal valid registry.json to a temp dir and -// returns its absolute path. +func TestResolveSourceLocation_ExistingUrlLocationReused(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + require.NoError(t, action.sourceManager.Add(t.Context(), "myreg", &extensions.SourceConfig{ + Name: "myreg", + Type: extensions.SourceKindUrl, + Location: "https://example.com/registry.json", + })) + + action.flags.source = "HTTPS://example.com/registry.json" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "myreg", action.flags.source) + + sources, err := action.sourceManager.List(t.Context()) + require.NoError(t, err) + matches := 0 + for _, src := range sources { + if src.Location == "https://example.com/registry.json" { + matches++ + } + } + require.Equal(t, 1, matches) +} + +func TestResolveSourceLocation_ExistingFileLocationReusedFromRelativePath(t *testing.T) { + registryPath := writeRegistryFile(t) + dir := filepath.Dir(registryPath) + + action, _ := newBundleInstallTestAction(t) + require.NoError(t, action.sourceManager.Add(t.Context(), "filereg", &extensions.SourceConfig{ + Name: "filereg", + Type: extensions.SourceKindFile, + Location: registryPath, + })) + + t.Chdir(dir) + action.flags.source = "registry.json" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "filereg", action.flags.source) +} + +func TestResolveSourceLocation_FilePersistsAbsolutePath(t *testing.T) { + registryPath := writeRegistryFile(t) + dir := filepath.Dir(registryPath) + + console := mockinput.NewMockConsole() + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("local-dev") + + action, _ := newBundleInstallTestAction(t) + action.console = console + + t.Chdir(dir) + action.flags.source = "registry.json" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "local-dev", action.flags.source) + + src, err := action.sourceManager.Get(t.Context(), "local-dev") + require.NoError(t, err) + require.True(t, filepath.IsAbs(src.Location), "location %q should be absolute", src.Location) + require.Equal(t, registryPath, src.Location) +} + +func TestResolveSourceLocation_UrlAcceptedRegistersSource(t *testing.T) { + t.Parallel() + + action, mockContext := newInstallSourceTestAction(t) + mockContext.HttpClient.When(func(req *http.Request) bool { + return req.URL.String() == "https://example.com/registry.json" + }).RespondFn(func(req *http.Request) (*http.Response, error) { + return mocks.CreateHttpResponseWithBody(req, http.StatusOK, extensions.Registry{ + SchemaVersion: extensions.CurrentRegistrySchemaVersion, + }) + }) + + mockContext.Console.WhenConfirm(func(input.ConsoleOptions) bool { return true }).Respond(true) + mockContext.Console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("example-registry") + + action.flags.source = "https://example.com/registry.json" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "example-registry", action.flags.source) + + src, err := action.sourceManager.Get(t.Context(), "example-registry") + require.NoError(t, err) + require.Equal(t, extensions.SourceKindUrl, src.Type) + require.Equal(t, "https://example.com/registry.json", src.Location) +} + +func TestResolveSourceLocation_NoPromptFileDirectsToSourceAdd(t *testing.T) { + t.Parallel() + + registryPath := writeRegistryFile(t) + + action, _ := newBundleInstallTestAction(t) + action.flags.global.NoPrompt = true + action.flags.source = registryPath + + err := action.resolveSourceLocation(t.Context()) + require.Error(t, err) + require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) + + sources, err := action.sourceManager.List(t.Context()) + require.NoError(t, err) + for _, src := range sources { + require.NotEqual(t, registryPath, src.Location, "the file source must not be registered") + } +} + +func newInstallSourceTestAction(t *testing.T) (*extensionInstallAction, *mocks.MockContext) { + t.Helper() + + mockContext := mocks.NewMockContext(t.Context()) + userConfigManager := config.NewUserConfigManager(mockContext.ConfigManager) + sourceManager := extensions.NewSourceManager(mockContext.Container, userConfigManager, mockContext.HttpClient) + lazyRunner := lazy.NewLazy(func() (*extensions.Runner, error) { + return extensions.NewRunner(mockContext.CommandRunner), nil + }) + manager, err := extensions.NewManager(userConfigManager, sourceManager, lazyRunner, mockContext.HttpClient) + require.NoError(t, err) + + action := &extensionInstallAction{ + console: mockContext.Console, + extensionManager: manager, + sourceManager: sourceManager, + flags: &extensionInstallFlags{global: &internal.GlobalCommandOptions{}}, + } + return action, mockContext +} + func writeRegistryFile(t *testing.T) string { t.Helper() diff --git a/cli/azd/cmd/extension_source_location_test.go b/cli/azd/cmd/extension_source_location_test.go new file mode 100644 index 00000000000..1e63067fe95 --- /dev/null +++ b/cli/azd/cmd/extension_source_location_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "bytes" + "encoding/json" + "net/http" + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/lazy" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/require" +) + +const sourceLocationRegistryURL = "https://example.com/registry.json" + +func stubRegistryHTTP(mockContext *mocks.MockContext) { + mockContext.HttpClient.When(func(req *http.Request) bool { + return req.URL.String() == sourceLocationRegistryURL + }).RespondFn(func(req *http.Request) (*http.Response, error) { + return mocks.CreateHttpResponseWithBody(req, http.StatusOK, extensions.Registry{ + SchemaVersion: extensions.CurrentRegistrySchemaVersion, + Extensions: []*extensions.ExtensionMetadata{ + { + Id: "test.ext", + DisplayName: "Test Extension", + Versions: []extensions.ExtensionVersion{ + {Version: "1.0.0"}, + }, + }, + }, + }) + }) +} + +func newSourceLocationTestManager( + t *testing.T, +) (*mocks.MockContext, *extensions.Manager, *extensions.SourceManager) { + t.Helper() + + mockContext := mocks.NewMockContext(t.Context()) + userConfigManager := config.NewUserConfigManager(mockContext.ConfigManager) + sourceManager := extensions.NewSourceManager(mockContext.Container, userConfigManager, mockContext.HttpClient) + lazyRunner := lazy.NewLazy(func() (*extensions.Runner, error) { + return extensions.NewRunner(mockContext.CommandRunner), nil + }) + manager, err := extensions.NewManager(userConfigManager, sourceManager, lazyRunner, mockContext.HttpClient) + require.NoError(t, err) + + return mockContext, manager, sourceManager +} + +func TestExtensionList_DirectUrlSource(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{source: sourceLocationRegistryURL}, + formatter: &output.JsonFormatter{}, + console: mockContext.Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) + + var rows []extensionListItem + require.NoError(t, json.Unmarshal(buf.Bytes(), &rows)) + require.Len(t, rows, 1) + require.Equal(t, "test.ext", rows[0].Id) + require.Equal(t, sourceLocationRegistryURL, rows[0].Source) + + requireLocationNotRegistered(t, sourceManager, sourceLocationRegistryURL) +} + +func TestExtensionList_DirectUrlSourceDoesNotPrompt(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{ + source: sourceLocationRegistryURL, + }, + formatter: &output.JsonFormatter{}, + console: mockContext.Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) + + var rows []extensionListItem + require.NoError(t, json.Unmarshal(buf.Bytes(), &rows)) + require.Len(t, rows, 1) +} + +func TestExtensionList_UnknownRegisteredNameErrors(t *testing.T) { + t.Parallel() + + _, manager, sourceManager := newSourceLocationTestManager(t) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{source: "not-a-registered-source"}, + formatter: &output.JsonFormatter{}, + console: mocks.NewMockContext(t.Context()).Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.Error(t, err) +} + +func TestExtensionList_DirectRelativeFileSource(t *testing.T) { + registryPath := writeRegistryFile(t) + t.Chdir(filepath.Dir(registryPath)) + + _, manager, sourceManager := newSourceLocationTestManager(t) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{source: "registry.json"}, + formatter: &output.JsonFormatter{}, + console: mocks.NewMockContext(t.Context()).Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) + + var rows []extensionListItem + require.NoError(t, json.Unmarshal(buf.Bytes(), &rows)) + require.Len(t, rows, 1) + require.Equal(t, registryPath, rows[0].Source) + requireLocationNotRegistered(t, sourceManager, registryPath) +} + +func TestExtensionShow_DirectUrlSource(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + + action := &extensionShowAction{ + args: []string{"test.ext"}, + flags: &extensionShowFlags{ + source: sourceLocationRegistryURL, + global: &internal.GlobalCommandOptions{}, + }, + console: mockContext.Console, + formatter: &output.NoneFormatter{}, + writer: &bytes.Buffer{}, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) + + requireLocationNotRegistered(t, sourceManager, sourceLocationRegistryURL) +} + +func TestExtensionUpgrade_UrlSourceRegistersSource(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + + mockContext.Console.WhenConfirm(func(input.ConsoleOptions) bool { return true }).Respond(true) + mockContext.Console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("example-registry") + + var buf bytes.Buffer + action := &extensionUpgradeAction{ + args: []string{"test.ext"}, + flags: &extensionUpgradeFlags{ + source: sourceLocationRegistryURL, + global: &internal.GlobalCommandOptions{}, + }, + formatter: &output.JsonFormatter{}, + writer: &buf, + console: mockContext.Console, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.Error(t, err) + + src, err := sourceManager.Get(t.Context(), "example-registry") + require.NoError(t, err) + require.Equal(t, extensions.SourceKindUrl, src.Type) + require.Equal(t, sourceLocationRegistryURL, src.Location) + + require.Equal(t, "example-registry", action.flags.source) +} + +func TestExtensionUpgrade_UrlSourceBlockedUnderNoPrompt(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + + var buf bytes.Buffer + action := &extensionUpgradeAction{ + args: []string{"test.ext"}, + flags: &extensionUpgradeFlags{ + source: sourceLocationRegistryURL, + global: &internal.GlobalCommandOptions{NoPrompt: true}, + }, + formatter: &output.JsonFormatter{}, + writer: &buf, + console: mockContext.Console, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.Error(t, err) + require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) + + requireLocationNotRegistered(t, sourceManager, sourceLocationRegistryURL) +} + +func requireLocationNotRegistered( + t *testing.T, + sourceManager *extensions.SourceManager, + location string, +) { + t.Helper() + + sources, err := sourceManager.List(t.Context()) + require.NoError(t, err) + for _, src := range sources { + require.NotEqual(t, location, src.Location, "location %q must not be registered", location) + } +} diff --git a/cli/azd/cmd/extension_test.go b/cli/azd/cmd/extension_test.go index 7e9d2ce5930..fe048a30a42 100644 --- a/cli/azd/cmd/extension_test.go +++ b/cli/azd/cmd/extension_test.go @@ -902,6 +902,7 @@ func Test_NewExtensionShowAction(t *testing.T) { mockinput.NewMockConsole(), &output.JsonFormatter{}, &bytes.Buffer{}, + nil, // sourceManager nil, // extensionManager ) require.NotNil(t, action) @@ -938,6 +939,7 @@ func Test_NewExtensionUpgradeAction(t *testing.T) { &output.NoneFormatter{}, &bytes.Buffer{}, mockinput.NewMockConsole(), + nil, // sourceManager nil, // extensionManager ) require.NotNil(t, action) diff --git a/cli/azd/cmd/extension_upgrade_test.go b/cli/azd/cmd/extension_upgrade_test.go index d0c6d436214..ee7110113fe 100644 --- a/cli/azd/cmd/extension_upgrade_test.go +++ b/cli/azd/cmd/extension_upgrade_test.go @@ -54,7 +54,7 @@ func createUpgradeTestManager( installed map[string]*extensions.Extension, registryURL string, registry extensions.Registry, -) *extensions.Manager { +) (*extensions.Manager, *extensions.SourceManager) { t.Helper() userConfigManager := config.NewUserConfigManager(mockCtx.ConfigManager) @@ -95,7 +95,7 @@ func createUpgradeTestManager( ) require.NoError(t, err) - return manager + return manager, sourceManager } // --------------------------------------------------------------------------- @@ -121,7 +121,7 @@ func TestUpgradeAction_ContextCancellation(t *testing.T) { testExtMeta("ext-c", "2.0.0", "test"), ) - manager := createUpgradeTestManager( + manager, sourceManager := createUpgradeTestManager( t, mockCtx, installed, registryURL, registry, ) @@ -139,6 +139,7 @@ func TestUpgradeAction_ContextCancellation(t *testing.T) { &output.JsonFormatter{}, &buf, mockinput.NewMockConsole(), + sourceManager, manager, ) @@ -250,7 +251,7 @@ func TestUpgradeOneExtension(t *testing.T) { t.Parallel() mockCtx := mocks.NewMockContext(context.Background()) - manager := createUpgradeTestManager( + manager, sourceManager := createUpgradeTestManager( t, mockCtx, tt.installed, registryURL, tt.registry, ) @@ -260,6 +261,7 @@ func TestUpgradeOneExtension(t *testing.T) { formatter: &output.JsonFormatter{}, writer: &bytes.Buffer{}, console: mockinput.NewMockConsole(), + sourceManager: sourceManager, extensionManager: manager, } @@ -305,7 +307,7 @@ func TestUpgradeAction_MixedBatch(t *testing.T) { // "missing" not in registry ) - manager := createUpgradeTestManager( + manager, sourceManager := createUpgradeTestManager( t, mockCtx, installed, registryURL, registry, ) @@ -319,6 +321,7 @@ func TestUpgradeAction_MixedBatch(t *testing.T) { &output.JsonFormatter{}, &buf, mockinput.NewMockConsole(), + sourceManager, manager, ) @@ -462,7 +465,7 @@ func TestUpgradeOneExtension_DelistedSkipped(t *testing.T) { // Empty registry — extension no longer listed registry := testRegistry() - manager := createUpgradeTestManager( + manager, sourceManager := createUpgradeTestManager( t, mockCtx, installed, registryURL, registry, ) @@ -474,6 +477,7 @@ func TestUpgradeOneExtension_DelistedSkipped(t *testing.T) { formatter: &output.JsonFormatter{}, writer: &bytes.Buffer{}, console: mockinput.NewMockConsole(), + sourceManager: sourceManager, extensionManager: manager, } diff --git a/cli/azd/cmd/testdata/TestFigSpec.ts b/cli/azd/cmd/testdata/TestFigSpec.ts index 81d00513966..ec81bde92a4 100644 --- a/cli/azd/cmd/testdata/TestFigSpec.ts +++ b/cli/azd/cmd/testdata/TestFigSpec.ts @@ -5700,8 +5700,8 @@ const completionSpec: Fig.Spec = { description: 'List installed extensions', }, { - name: ['--source'], - description: 'Filter extensions by source', + name: ['--source', '-s'], + description: 'Filter extensions by registered source name or registry location (URL or file path).', args: [ { name: 'source', @@ -5726,7 +5726,7 @@ const completionSpec: Fig.Spec = { options: [ { name: ['--source', '-s'], - description: 'The extension source to use.', + description: 'The registered source name or registry location (URL or file path) to use.', args: [ { name: 'source', @@ -5831,7 +5831,7 @@ const completionSpec: Fig.Spec = { }, { name: ['--source', '-s'], - description: 'The extension source to use for upgrades', + description: 'The registered source name or registry location (URL or file path) to use for upgrades.', args: [ { name: 'source', diff --git a/cli/azd/cmd/testdata/TestUsage-azd-extension-list.snap b/cli/azd/cmd/testdata/TestUsage-azd-extension-list.snap index 3be9cd4a5a0..c63817888b8 100644 --- a/cli/azd/cmd/testdata/TestUsage-azd-extension-list.snap +++ b/cli/azd/cmd/testdata/TestUsage-azd-extension-list.snap @@ -6,7 +6,7 @@ Usage Flags --installed : List installed extensions - --source string : Filter extensions by source + -s, --source string : Filter extensions by registered source name or registry location (URL or file path). --tags strings : Filter extensions by tags Global Flags diff --git a/cli/azd/cmd/testdata/TestUsage-azd-extension-show.snap b/cli/azd/cmd/testdata/TestUsage-azd-extension-show.snap index d94d916b984..42bf727949b 100644 --- a/cli/azd/cmd/testdata/TestUsage-azd-extension-show.snap +++ b/cli/azd/cmd/testdata/TestUsage-azd-extension-show.snap @@ -5,7 +5,7 @@ Usage azd extension show [flags] Flags - -s, --source string : The extension source to use. + -s, --source string : The registered source name or registry location (URL or file path) to use. Global Flags -C, --cwd string : Sets the current working directory. diff --git a/cli/azd/cmd/testdata/TestUsage-azd-extension-upgrade.snap b/cli/azd/cmd/testdata/TestUsage-azd-extension-upgrade.snap index fe6a99bacc5..6ffeb2f433b 100644 --- a/cli/azd/cmd/testdata/TestUsage-azd-extension-upgrade.snap +++ b/cli/azd/cmd/testdata/TestUsage-azd-extension-upgrade.snap @@ -7,7 +7,7 @@ Usage Flags --all : Upgrade all installed extensions --no-dependency-upgrades : Do not upgrade dependencies when upgrading an extension that has dependencies - -s, --source string : The extension source to use for upgrades + -s, --source string : The registered source name or registry location (URL or file path) to use for upgrades. -v, --version string : The version of the extension to upgrade to Global Flags diff --git a/cli/azd/docs/extensions/extension-framework.md b/cli/azd/docs/extensions/extension-framework.md index 36c457ff5db..03af43d7ea1 100644 --- a/cli/azd/docs/extensions/extension-framework.md +++ b/cli/azd/docs/extensions/extension-framework.md @@ -118,26 +118,28 @@ Extensions are a collection of executable artifacts that extend or enhance funct Lists matching extensions from one or more extension sources. - `--installed` When set displays a list of installed extensions. -- `--source` When set will only list extensions from the specified source. +- `-s, --source` Filters by registered source name or registry location (URL or file path). Locations are queried read-only and are not registered. Extensions from an unregistered location show the location itself in the `SOURCE` column. - `--tags` Allows filtering extensions by tags (e.g., AI, test) #### `azd extension show [flags]` Shows detailed information for a specific extension, including description, tags, versions, and installation status. -- `-s, --source` The extension source to use. Use this flag when the same extension ID exists in multiple sources. +- `-s, --source` Uses a registered source name or registry location (URL or file path). Locations are queried read-only and are not registered. #### `azd extension install [flags]` Installs one or more extensions from any configured extension source. - `-v, --version` Specifies the exact version to install. -- `-s, --source` Specifies the extension source used for installations. In addition to the name of a registered source, this accepts a registry location (a URL or file path). When a location is provided, `azd` registers it as a new persisted source — prompting for a source name, and confirming first when the location is a URL — and then installs from it. This lets you install in one step without a separate `azd extension source add`: +- `-s, --source` Specifies the extension source used for installations. In addition to registered source names, this accepts a registry location (URL or file path). `azd` registers the location as a source, prompting for a name and confirming first for URLs, then installs from it: ```bash azd extension install -s https://link/to/registry.json ``` + If the same location is already registered, `azd` reuses that source. File paths are stored as absolute paths. + Under `--no-prompt`, registering a source from a location is not allowed; add the source first with `azd extension source add`. #### `azd extension uninstall [flags]` @@ -152,7 +154,7 @@ Upgrades one or more extensions to the latest versions. - `--all` Upgrades all previously installed extensions when specified. - `-v, --version` Upgrades a specified extension to an exact version, if provided. -- `-s, --source` Specifies the extension source used for installations. +- `-s, --source` Specifies the source used for the upgrade. In addition to registered source names, this accepts a registry location (URL or file path). `azd` registers the location as a source, updates the extension's stored source, and rejects locations under `--no-prompt`; add the source first with `azd extension source add`. - `--no-dependency-upgrades` Skips upgrading dependencies declared by extension packs. ## Developing Extensions diff --git a/cli/azd/pkg/extensions/manager.go b/cli/azd/pkg/extensions/manager.go index 1341cebee7d..07f0ea23699 100644 --- a/cli/azd/pkg/extensions/manager.go +++ b/cli/azd/pkg/extensions/manager.go @@ -111,6 +111,9 @@ type FilterOptions struct { Version string // Source is used to specify the source of the extension to install Source string + // SourceConfig restricts lookup to one source that is not persisted or cached. + // It takes precedence over Source. + SourceConfig *SourceConfig // Tags is used to specify the tags of the extension to install Tags []string // Capability is used to filter extensions by capability type @@ -440,10 +443,23 @@ func (m *Manager) FindExtensions(ctx context.Context, options *FilterOptions) ([ } } + filterOptions := options + if options.SourceConfig != nil { + filterOptionsCopy := *options + filterOptionsCopy.Source = "" + filterOptions = &filterOptionsCopy + } + // Use the centralized extension filter - extensionFilter := createExtensionFilter(options) + extensionFilter := createExtensionFilter(filterOptions) - sources, err := m.getSources(ctx, sourceFilterPredicate) + var sources []Source + var err error + if options.SourceConfig != nil { + sources, err = m.createSourcesFromConfig(ctx, []*SourceConfig{options.SourceConfig}, nil) + } else { + sources, err = m.getSources(ctx, sourceFilterPredicate) + } if err != nil { return nil, fmt.Errorf("failed listing extensions: %w", err) } From 18bc0568a495226e35afd40dbf253e7b3e79bcb1 Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Thu, 25 Jun 2026 18:56:42 +0000 Subject: [PATCH 4/8] Polish changes --- cli/azd/CHANGELOG.md | 2 - cli/azd/cmd/extension.go | 80 +++++++++-------- cli/azd/cmd/extension_install_source_test.go | 90 +++++++++++++++----- 3 files changed, 115 insertions(+), 57 deletions(-) diff --git a/cli/azd/CHANGELOG.md b/cli/azd/CHANGELOG.md index e7ef3635362..2e227d98e57 100644 --- a/cli/azd/CHANGELOG.md +++ b/cli/azd/CHANGELOG.md @@ -4,8 +4,6 @@ ### Features Added -- [[#8581]](https://github.com/Azure/azure-dev/issues/8581) `azd extension install -s ` now accepts a registry location (URL or file path) directly for `-s/--source`, registering it as a persisted source and installing in one step. azd prompts for the source name and confirms before registering an untrusted URL; under `--no-prompt` it directs you to add the source first with `azd extension source add`. - ### Breaking Changes ### Bugs Fixed diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index 6a643752b79..db522b2c1b7 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -13,7 +13,6 @@ import ( "log" "maps" "net" - "net/url" "os" "path/filepath" "slices" @@ -1482,7 +1481,6 @@ func registerSourceFromLocation( } if kind == extensions.SourceKindUrl { - console.Message(ctx, "") confirm, err := console.Confirm(ctx, input.ConsoleOptions{ Message: fmt.Sprintf( "Register and use the extension source at %s?", @@ -1495,23 +1493,38 @@ func registerSourceFromLocation( if !confirm { return "", &internal.ErrorWithSuggestion{ Err: errors.New("extension source registration declined"), - Suggestion: "Re-run and confirm to register the source, " + - "or add it explicitly with 'azd extension source add'.", + Suggestion: fmt.Sprintf( + "Re-run and confirm to register the source, or add it explicitly with %s.", + output.WithHighLightFormat("azd extension source add"), + ), } } } - defaultName := defaultSourceName(location) - sourceName, err := console.Prompt(ctx, input.ConsoleOptions{ - Message: "Enter a name for this extension source:", - DefaultValue: defaultName, - }) - if err != nil { - return "", err - } - sourceName = strings.TrimSpace(sourceName) - if sourceName == "" { - sourceName = defaultName + var sourceName string + for { + sourceNameInput, err := console.Prompt(ctx, input.ConsoleOptions{ + Message: "Enter a name for this extension source", + }) + if err != nil { + return "", err + } + sourceName = strings.TrimSpace(sourceNameInput) + if sourceName == "" { + console.Message(ctx, output.WithErrorFormat("Extension source name cannot be empty")) + continue + } + if err := validateSourceName(sourceName); err != nil { + console.Message(ctx, output.WithErrorFormat(err.Error())) + continue + } + if _, err := sourceManager.Get(ctx, normalizeSourceKey(sourceName)); err == nil { + console.Message(ctx, output.WithErrorFormat("Extension source '%s' already exists", sourceName)) + continue + } else if !errors.Is(err, extensions.ErrSourceNotFound) { + return "", fmt.Errorf("failed to resolve extension source %q: %w", sourceName, err) + } + break } sourceConfig := &extensions.SourceConfig{ @@ -1520,6 +1533,7 @@ func registerSourceFromLocation( Location: location, } + console.Message(ctx, "") spinnerMessage := fmt.Sprintf("Registering extension source %s", output.WithHighLightFormat(sourceName)) console.ShowSpinner(ctx, spinnerMessage, input.Step) @@ -1582,6 +1596,23 @@ func absPath(path string) string { return path } +func normalizeSourceKey(name string) string { + return strings.ReplaceAll(strings.ToLower(name), " ", "-") +} + +func validateSourceName(name string) error { + if strings.Contains(name, ".") { + return errors.New("Extension source name cannot contain '.'") + } + if strings.ContainsAny(name, `/\`) { + return errors.New("Extension source name cannot contain path separators") + } + if strings.EqualFold(normalizeSourceKey(name), extensions.BundleSourceName) { + return fmt.Errorf("Extension source name '%s' is reserved", extensions.BundleSourceName) + } + return nil +} + // resolveReadOnlySourceFilter returns a temporary source config for read-only // commands when --source is a registry location. Registered names return nil. func resolveReadOnlySourceFilter( @@ -1636,25 +1667,6 @@ func inferSourceKind(location string) (extensions.SourceKind, bool) { return "", false } -// defaultSourceName derives a config-safe default extension source name from a -// registry location: the host for URLs, otherwise the file name without its -// extension. It falls back to "custom" when nothing usable can be derived. -func defaultSourceName(location string) string { - base := "" - if u, err := url.Parse(location); err == nil && u.Host != "" { - base = u.Host - } else { - base = filepath.Base(location) - base = strings.TrimSuffix(base, filepath.Ext(base)) - } - - name := normalizeBundleSourceName(base) - if name == "" { - name = "custom" - } - return name -} - // azd extension uninstall type extensionUninstallFlags struct { all bool diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go index 5186252f9d5..f21698072aa 100644 --- a/cli/azd/cmd/extension_install_source_test.go +++ b/cli/azd/cmd/extension_install_source_test.go @@ -15,6 +15,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/extensions" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/lazy" + "github.com/azure/azure-dev/cli/azd/pkg/output" "github.com/azure/azure-dev/cli/azd/test/mocks" "github.com/azure/azure-dev/cli/azd/test/mocks/mockinput" "github.com/stretchr/testify/require" @@ -75,21 +76,6 @@ func TestInferSourceKind(t *testing.T) { }) } -func TestDefaultSourceName(t *testing.T) { - t.Parallel() - - cases := map[string]string{ - "https://example.com/registry.json": "example-com", - "https://link/to/registry.json": "link", - "/path/to/registry.json": "registry", - "./custom.json": "custom", - } - - for location, expected := range cases { - require.Equal(t, expected, defaultSourceName(location), "location %q", location) - } -} - func TestResolveSourceLocation_ExistingSourceUnchanged(t *testing.T) { t.Parallel() @@ -145,25 +131,85 @@ func TestResolveSourceLocation_FileRegistersSource(t *testing.T) { require.NoError(t, err) require.Equal(t, extensions.SourceKindFile, src.Type) require.Equal(t, registryPath, src.Location) + require.Contains(t, console.Output(), "") } -func TestResolveSourceLocation_FileUsesDefaultNameWhenBlank(t *testing.T) { +func TestResolveSourceLocation_BlankSourceNamePromptsAgain(t *testing.T) { t.Parallel() registryPath := writeRegistryFile(t) console := mockinput.NewMockConsole() - console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("") + promptCount := 0 + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).RespondFn(func(input.ConsoleOptions) (any, error) { + promptCount++ + if promptCount == 1 { + return "", nil + } + return "local-dev", nil + }) action, _ := newBundleInstallTestAction(t) action.console = console action.flags.source = registryPath require.NoError(t, action.resolveSourceLocation(t.Context())) - require.Equal(t, "registry", action.flags.source) + require.Equal(t, "local-dev", action.flags.source) + require.Equal(t, 2, promptCount) +} - _, err := action.sourceManager.Get(t.Context(), "registry") - require.NoError(t, err) +func TestResolveSourceLocation_InvalidSourceNamePromptsAgain(t *testing.T) { + t.Parallel() + + registryPath := writeRegistryFile(t) + + console := mockinput.NewMockConsole() + promptCount := 0 + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).RespondFn(func(input.ConsoleOptions) (any, error) { + promptCount++ + if promptCount == 1 { + return "my.registry", nil + } + return "local-dev", nil + }) + + action, _ := newBundleInstallTestAction(t) + action.console = console + action.flags.source = registryPath + + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "local-dev", action.flags.source) + require.Equal(t, 2, promptCount) + require.Contains(t, console.Output(), output.WithErrorFormat("Extension source name cannot contain '.'")) +} + +func TestResolveSourceLocation_ExistingSourceNamePromptsAgain(t *testing.T) { + t.Parallel() + + registryPath := writeRegistryFile(t) + + console := mockinput.NewMockConsole() + promptCount := 0 + console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).RespondFn(func(input.ConsoleOptions) (any, error) { + promptCount++ + if promptCount == 1 { + return "existing source", nil + } + return "local-dev", nil + }) + + action, _ := newBundleInstallTestAction(t) + action.console = console + action.flags.source = registryPath + require.NoError(t, action.sourceManager.Add(t.Context(), "existing-source", &extensions.SourceConfig{ + Name: "existing-source", + Type: extensions.SourceKindUrl, + Location: "https://example.com/registry.json", + })) + + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "local-dev", action.flags.source) + require.Equal(t, 2, promptCount) } func TestResolveSourceLocation_UrlDeclinedReturnsError(t *testing.T) { @@ -259,7 +305,9 @@ func TestResolveSourceLocation_UrlAcceptedRegistersSource(t *testing.T) { }) }) - mockContext.Console.WhenConfirm(func(input.ConsoleOptions) bool { return true }).Respond(true) + mockContext.Console.WhenConfirm(func(options input.ConsoleOptions) bool { + return options.DefaultValue == false + }).Respond(true) mockContext.Console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("example-registry") action.flags.source = "https://example.com/registry.json" From e5a867c75650fe3a0b176980df0c8c9b26daf663 Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Thu, 25 Jun 2026 21:18:35 +0000 Subject: [PATCH 5/8] Address feedback --- cli/azd/cmd/extension.go | 18 +++++++++------- cli/azd/cmd/extension_install_source_test.go | 7 ++++++- cli/azd/cmd/extension_source_location_test.go | 21 +++++++++++++++++++ .../docs/extensions/extension-framework.md | 2 +- cli/azd/pkg/extensions/manager.go | 17 ++++++++++++++- cli/azd/pkg/extensions/source_manager.go | 8 +++---- 6 files changed, 58 insertions(+), 15 deletions(-) diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index db522b2c1b7..fc41aeec6ba 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -15,6 +15,7 @@ import ( "net" "os" "path/filepath" + "runtime" "slices" "strings" "text/tabwriter" @@ -1518,7 +1519,7 @@ func registerSourceFromLocation( console.Message(ctx, output.WithErrorFormat(err.Error())) continue } - if _, err := sourceManager.Get(ctx, normalizeSourceKey(sourceName)); err == nil { + if _, err := sourceManager.Get(ctx, extensions.NormalizeSourceKey(sourceName)); err == nil { console.Message(ctx, output.WithErrorFormat("Extension source '%s' already exists", sourceName)) continue } else if !errors.Is(err, extensions.ErrSourceNotFound) { @@ -1582,7 +1583,12 @@ func locationsEqual(kind extensions.SourceKind, a, b string) bool { case extensions.SourceKindUrl: return strings.EqualFold(a, b) case extensions.SourceKindFile: - return absPath(a) == absPath(b) + a = filepath.Clean(absPath(a)) + b = filepath.Clean(absPath(b)) + if runtime.GOOS == "windows" { + return strings.EqualFold(a, b) + } + return a == b default: return a == b } @@ -1596,10 +1602,6 @@ func absPath(path string) string { return path } -func normalizeSourceKey(name string) string { - return strings.ReplaceAll(strings.ToLower(name), " ", "-") -} - func validateSourceName(name string) error { if strings.Contains(name, ".") { return errors.New("Extension source name cannot contain '.'") @@ -1607,7 +1609,7 @@ func validateSourceName(name string) error { if strings.ContainsAny(name, `/\`) { return errors.New("Extension source name cannot contain path separators") } - if strings.EqualFold(normalizeSourceKey(name), extensions.BundleSourceName) { + if strings.EqualFold(extensions.NormalizeSourceKey(name), extensions.BundleSourceName) { return fmt.Errorf("Extension source name '%s' is reserved", extensions.BundleSourceName) } return nil @@ -1661,7 +1663,7 @@ func inferSourceKind(location string) (extensions.SourceKind, bool) { if info, err := os.Stat(location); err == nil && !info.IsDir() { return extensions.SourceKindFile, true } - if strings.ContainsAny(location, `/\`) || strings.EqualFold(filepath.Ext(location), ".json") { + if strings.ContainsAny(location, `/\`) { return extensions.SourceKindFile, true } return "", false diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go index f21698072aa..89c989bca59 100644 --- a/cli/azd/cmd/extension_install_source_test.go +++ b/cli/azd/cmd/extension_install_source_test.go @@ -59,11 +59,16 @@ func TestInferSourceKind(t *testing.T) { }) t.Run("JsonExtension", func(t *testing.T) { - kind, ok := inferSourceKind("missing-registry.json") + kind, ok := inferSourceKind(existing) require.True(t, ok) require.Equal(t, extensions.SourceKindFile, kind) }) + t.Run("MissingJsonSourceName", func(t *testing.T) { + _, ok := inferSourceKind("missing-registry.json") + require.False(t, ok) + }) + t.Run("PathSeparator", func(t *testing.T) { kind, ok := inferSourceKind("./some/path") require.True(t, ok) diff --git a/cli/azd/cmd/extension_source_location_test.go b/cli/azd/cmd/extension_source_location_test.go index 1e63067fe95..579905587c3 100644 --- a/cli/azd/cmd/extension_source_location_test.go +++ b/cli/azd/cmd/extension_source_location_test.go @@ -157,6 +157,27 @@ func TestExtensionList_DirectRelativeFileSource(t *testing.T) { requireLocationNotRegistered(t, sourceManager, registryPath) } +func TestExtensionList_DirectMissingFileSourceReturnsError(t *testing.T) { + t.Parallel() + + _, manager, sourceManager := newSourceLocationTestManager(t) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{source: "./missing-registry.json"}, + formatter: &output.JsonFormatter{}, + console: mocks.NewMockContext(t.Context()).Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.Error(t, err) + require.ErrorContains(t, err, "failed listing extensions from registry") + require.ErrorContains(t, err, "failed initializing extension source") +} + func TestExtensionShow_DirectUrlSource(t *testing.T) { t.Parallel() diff --git a/cli/azd/docs/extensions/extension-framework.md b/cli/azd/docs/extensions/extension-framework.md index 03af43d7ea1..6cb390088c1 100644 --- a/cli/azd/docs/extensions/extension-framework.md +++ b/cli/azd/docs/extensions/extension-framework.md @@ -154,7 +154,7 @@ Upgrades one or more extensions to the latest versions. - `--all` Upgrades all previously installed extensions when specified. - `-v, --version` Upgrades a specified extension to an exact version, if provided. -- `-s, --source` Specifies the source used for the upgrade. In addition to registered source names, this accepts a registry location (URL or file path). `azd` registers the location as a source, updates the extension's stored source, and rejects locations under `--no-prompt`; add the source first with `azd extension source add`. +- `-s, --source` Specifies the source used for the upgrade. In addition to registered source names, this accepts a registry location (URL or file path). `azd` registers the location as a source before resolving the extension, updates the extension's stored source after a successful upgrade, and rejects locations under `--no-prompt`; add the source first with `azd extension source add`. - `--no-dependency-upgrades` Skips upgrading dependencies declared by extension packs. ## Developing Extensions diff --git a/cli/azd/pkg/extensions/manager.go b/cli/azd/pkg/extensions/manager.go index 07f0ea23699..ccab7fc3e41 100644 --- a/cli/azd/pkg/extensions/manager.go +++ b/cli/azd/pkg/extensions/manager.go @@ -456,7 +456,22 @@ func (m *Manager) FindExtensions(ctx context.Context, options *FilterOptions) ([ var sources []Source var err error if options.SourceConfig != nil { - sources, err = m.createSourcesFromConfig(ctx, []*SourceConfig{options.SourceConfig}, nil) + source, err := m.sourceManager.CreateSource(ctx, options.SourceConfig) + if err != nil { + if schemaErr, ok := errors.AsType[*ErrUnsupportedRegistrySchema](err); ok { + return nil, &errorhandler.ErrorWithSuggestion{ + Err: schemaErr, + Message: schemaErr.Error(), + Suggestion: "Upgrade azd to the latest version to use this registry", + Links: []errorhandler.ErrorLink{{ + URL: "https://aka.ms/azd/install", + Title: "Install/upgrade azd", + }}, + } + } + return nil, fmt.Errorf("failed initializing extension source: %w", err) + } + sources = []Source{source} } else { sources, err = m.getSources(ctx, sourceFilterPredicate) } diff --git a/cli/azd/pkg/extensions/source_manager.go b/cli/azd/pkg/extensions/source_manager.go index 2d562b39e72..bda5abda667 100644 --- a/cli/azd/pkg/extensions/source_manager.go +++ b/cli/azd/pkg/extensions/source_manager.go @@ -88,7 +88,7 @@ func (sm *SourceManager) Get(ctx context.Context, name string) (*SourceConfig, e // Add adds a new extension source. func (sm *SourceManager) Add(ctx context.Context, name string, source *SourceConfig) error { - newKey := normalizeKey(name) + newKey := NormalizeSourceKey(name) if strings.EqualFold(newKey, BundleSourceName) { return fmt.Errorf( @@ -113,7 +113,7 @@ func (sm *SourceManager) Add(ctx context.Context, name string, source *SourceCon // Remove removes an extension source. func (sm *SourceManager) Remove(ctx context.Context, name string) error { - name = normalizeKey(name) + name = NormalizeSourceKey(name) _, err := sm.Get(ctx, name) if err != nil && errors.Is(err, ErrSourceNotFound) { @@ -247,8 +247,8 @@ func (sm *SourceManager) addInternal(source *SourceConfig) error { return nil } -// normalizeKey normalizes a key for use in the configuration. -func normalizeKey(key string) string { +// NormalizeSourceKey normalizes an extension source name for use in configuration keys. +func NormalizeSourceKey(key string) string { key = strings.ToLower(key) key = strings.ReplaceAll(key, " ", "-") From f94aba9d8d247869f4ee3ff177fd85bcf5400be4 Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Thu, 25 Jun 2026 21:38:15 +0000 Subject: [PATCH 6/8] Address feedback --- cli/azd/cmd/extension.go | 67 +++++++++++++++---- cli/azd/cmd/extension_install_source_test.go | 15 +++++ cli/azd/cmd/extension_source_location_test.go | 58 ++++++++++++++++ 3 files changed, 127 insertions(+), 13 deletions(-) diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index fc41aeec6ba..978bbe9978d 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -275,9 +275,14 @@ func (a *extensionListAction) Run(ctx context.Context) (*actions.ActionResult, e options.SourceConfig = sourceConfig options.Source = "" } else if options.Source != "" { - if _, err := a.sourceManager.Get(ctx, options.Source); err != nil { - return nil, fmt.Errorf("extension source '%s' not found: %w", options.Source, err) + resolvedSource, ok, err := resolveRegisteredSourceName(ctx, a.sourceManager, options.Source) + if err != nil { + return nil, err + } + if !ok { + return nil, fmt.Errorf("extension source '%s' not found: %w", options.Source, extensions.ErrSourceNotFound) } + options.Source = resolvedSource } registryExtensions, err := a.extensionManager.FindExtensions(ctx, options) @@ -739,6 +744,14 @@ func (a *extensionShowAction) Run(ctx context.Context) (*actions.ActionResult, e if sourceConfig != nil { filterOptions.SourceConfig = sourceConfig filterOptions.Source = "" + } else if filterOptions.Source != "" { + resolvedSource, ok, err := resolveRegisteredSourceName(ctx, a.sourceManager, filterOptions.Source) + if err != nil { + return nil, err + } + if ok { + filterOptions.Source = resolvedSource + } } extensionMatches, err := a.extensionManager.FindExtensions(ctx, filterOptions) @@ -1439,13 +1452,12 @@ func registerSourceFromLocation( return source, nil } - // If the value already names a registered source, keep current behavior. - _, err := sourceManager.Get(ctx, source) - if err == nil { - return source, nil + resolvedSource, ok, err := resolveRegisteredSourceName(ctx, sourceManager, source) + if err != nil { + return "", err } - if !errors.Is(err, extensions.ErrSourceNotFound) { - return "", fmt.Errorf("failed to resolve extension source %q: %w", source, err) + if ok { + return resolvedSource, nil } location := source @@ -1557,6 +1569,35 @@ func registerSourceFromLocation( return sourceConfig.Name, nil } +func resolveRegisteredSourceName( + ctx context.Context, + sourceManager *extensions.SourceManager, + source string, +) (string, bool, error) { + _, err := sourceManager.Get(ctx, source) + if err == nil { + return source, true, nil + } + if !errors.Is(err, extensions.ErrSourceNotFound) { + return "", false, fmt.Errorf("failed to resolve extension source %q: %w", source, err) + } + + normalizedSource := extensions.NormalizeSourceKey(source) + if normalizedSource == source { + return "", false, nil + } + + _, err = sourceManager.Get(ctx, normalizedSource) + if err == nil { + return normalizedSource, true, nil + } + if !errors.Is(err, extensions.ErrSourceNotFound) { + return "", false, fmt.Errorf("failed to resolve extension source %q: %w", source, err) + } + + return "", false, nil +} + // findSourceByLocation returns the registered source for location, if any. func findSourceByLocation( ctx context.Context, @@ -1626,12 +1667,12 @@ func resolveReadOnlySourceFilter( return nil, nil } - _, err := sourceManager.Get(ctx, source) - if err == nil { - return nil, nil + _, ok, err := resolveRegisteredSourceName(ctx, sourceManager, source) + if err != nil { + return nil, err } - if !errors.Is(err, extensions.ErrSourceNotFound) { - return nil, fmt.Errorf("failed to resolve extension source %q: %w", source, err) + if ok { + return nil, nil } kind, ok := inferSourceKind(source) diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go index 89c989bca59..f911fa675fc 100644 --- a/cli/azd/cmd/extension_install_source_test.go +++ b/cli/azd/cmd/extension_install_source_test.go @@ -96,6 +96,21 @@ func TestResolveSourceLocation_ExistingSourceUnchanged(t *testing.T) { require.Equal(t, "my-source", action.flags.source) } +func TestResolveSourceLocation_NormalizedExistingSourceUsed(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + require.NoError(t, action.sourceManager.Add(t.Context(), "my-source", &extensions.SourceConfig{ + Name: "my-source", + Type: extensions.SourceKindUrl, + Location: "https://example.com/registry.json", + })) + + action.flags.source = "my source" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "my-source", action.flags.source) +} + func TestResolveSourceLocation_PlainNameUnchanged(t *testing.T) { t.Parallel() diff --git a/cli/azd/cmd/extension_source_location_test.go b/cli/azd/cmd/extension_source_location_test.go index 579905587c3..3231f8016c3 100644 --- a/cli/azd/cmd/extension_source_location_test.go +++ b/cli/azd/cmd/extension_source_location_test.go @@ -131,6 +131,36 @@ func TestExtensionList_UnknownRegisteredNameErrors(t *testing.T) { require.Error(t, err) } +func TestExtensionList_NormalizedRegisteredSourceName(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + require.NoError(t, sourceManager.Add(t.Context(), "my-source", &extensions.SourceConfig{ + Name: "my-source", + Type: extensions.SourceKindUrl, + Location: sourceLocationRegistryURL, + })) + + var buf bytes.Buffer + action := &extensionListAction{ + flags: &extensionListFlags{source: "my source"}, + formatter: &output.JsonFormatter{}, + console: mockContext.Console, + writer: &buf, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) + + var rows []extensionListItem + require.NoError(t, json.Unmarshal(buf.Bytes(), &rows)) + require.Len(t, rows, 1) + require.Equal(t, "my-source", rows[0].Source) +} + func TestExtensionList_DirectRelativeFileSource(t *testing.T) { registryPath := writeRegistryFile(t) t.Chdir(filepath.Dir(registryPath)) @@ -203,6 +233,34 @@ func TestExtensionShow_DirectUrlSource(t *testing.T) { requireLocationNotRegistered(t, sourceManager, sourceLocationRegistryURL) } +func TestExtensionShow_NormalizedRegisteredSourceName(t *testing.T) { + t.Parallel() + + mockContext, manager, sourceManager := newSourceLocationTestManager(t) + stubRegistryHTTP(mockContext) + require.NoError(t, sourceManager.Add(t.Context(), "my-source", &extensions.SourceConfig{ + Name: "my-source", + Type: extensions.SourceKindUrl, + Location: sourceLocationRegistryURL, + })) + + action := &extensionShowAction{ + args: []string{"test.ext"}, + flags: &extensionShowFlags{ + source: "my source", + global: &internal.GlobalCommandOptions{}, + }, + console: mockContext.Console, + formatter: &output.NoneFormatter{}, + writer: &bytes.Buffer{}, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.NoError(t, err) +} + func TestExtensionUpgrade_UrlSourceRegistersSource(t *testing.T) { t.Parallel() From 69befef01320fed68dc03475b38989d9b0237016 Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Fri, 26 Jun 2026 00:29:11 +0000 Subject: [PATCH 7/8] Address feedback --- cli/azd/cmd/extension.go | 122 ++++++++++-------- cli/azd/cmd/extension_install_source_test.go | 37 ++++++ cli/azd/cmd/extension_source_location_test.go | 23 ++++ cli/azd/cmd/telemetry_test.go | 47 ++++--- cli/azd/internal/tracing/fields/fields.go | 6 + cli/azd/pkg/extensions/manager.go | 22 +--- cli/azd/pkg/extensions/registry_version.go | 15 +++ docs/reference/telemetry-data.md | 1 + .../metrics-audit/feature-telemetry-matrix.md | 4 +- docs/specs/metrics-audit/telemetry-schema.md | 1 + 10 files changed, 187 insertions(+), 91 deletions(-) diff --git a/cli/azd/cmd/extension.go b/cli/azd/cmd/extension.go index 978bbe9978d..fc027b1312b 100644 --- a/cli/azd/cmd/extension.go +++ b/cli/azd/cmd/extension.go @@ -13,6 +13,7 @@ import ( "log" "maps" "net" + "net/url" "os" "path/filepath" "runtime" @@ -167,6 +168,9 @@ Use --output json for a structured report of all upgrade results.`, Command: &cobra.Command{ Use: "add", Short: "Add an extension source with the specified name", + Long: "Add an extension source with the specified name.\n\n" + + "`azd extension install --source` and `azd extension upgrade --source` also accept " + + "a registry URL or file path directly.", }, ActionResolver: newExtensionSourceAddAction, FlagsResolver: newExtensionSourceAddFlags, @@ -262,27 +266,22 @@ type extensionListItem struct { } func (a *extensionListAction) Run(ctx context.Context) (*actions.ActionResult, error) { + tracing.SetUsageAttributes(fields.ExtensionSourceKind.String(sourceArgKind(a.flags.source))) options := &extensions.FilterOptions{ Source: a.flags.source, Tags: a.flags.tags, } - sourceConfig, err := resolveReadOnlySourceFilter(ctx, a.sourceManager, a.flags.source) + sourceFilter, err := resolveSourceFilter(ctx, a.sourceManager, a.flags.source) if err != nil { return nil, err } - if sourceConfig != nil { - options.SourceConfig = sourceConfig + options.Source = sourceFilter.source + if sourceFilter.config != nil { + options.SourceConfig = sourceFilter.config options.Source = "" - } else if options.Source != "" { - resolvedSource, ok, err := resolveRegisteredSourceName(ctx, a.sourceManager, options.Source) - if err != nil { - return nil, err - } - if !ok { - return nil, fmt.Errorf("extension source '%s' not found: %w", options.Source, extensions.ErrSourceNotFound) - } - options.Source = resolvedSource + } else if options.Source != "" && !sourceFilter.registered { + return nil, fmt.Errorf("extension source '%s' not found: %w", a.flags.source, extensions.ErrSourceNotFound) } registryExtensions, err := a.extensionManager.FindExtensions(ctx, options) @@ -719,6 +718,7 @@ func (t *extensionShowItem) Display(writer io.Writer) error { } func (a *extensionShowAction) Run(ctx context.Context) (*actions.ActionResult, error) { + tracing.SetUsageAttributes(fields.ExtensionSourceKind.String(sourceArgKind(a.flags.source))) if len(a.args) == 0 { return nil, &internal.ErrorWithSuggestion{ Err: internal.ErrNoArgsProvided, @@ -737,21 +737,17 @@ func (a *extensionShowAction) Run(ctx context.Context) (*actions.ActionResult, e Id: extensionId, } - sourceConfig, err := resolveReadOnlySourceFilter(ctx, a.sourceManager, a.flags.source) + sourceFilter, err := resolveSourceFilter(ctx, a.sourceManager, a.flags.source) if err != nil { return nil, err } - if sourceConfig != nil { - filterOptions.SourceConfig = sourceConfig + filterOptions.Source = sourceFilter.source + if sourceFilter.config != nil { + filterOptions.SourceConfig = sourceFilter.config filterOptions.Source = "" - } else if filterOptions.Source != "" { - resolvedSource, ok, err := resolveRegisteredSourceName(ctx, a.sourceManager, filterOptions.Source) - if err != nil { - return nil, err - } - if ok { - filterOptions.Source = resolvedSource - } + } else if filterOptions.Source != "" && !sourceFilter.registered { + return nil, fmt.Errorf( + "extension source '%s' not found: %w", a.flags.source, extensions.ErrSourceNotFound) } extensionMatches, err := a.extensionManager.FindExtensions(ctx, filterOptions) @@ -863,6 +859,7 @@ func newExtensionInstallAction( } func (a *extensionInstallAction) Run(ctx context.Context) (*actions.ActionResult, error) { + sourceKind := sourceArgKind(a.flags.source) a.console.MessageUxItem(ctx, &ux.MessageTitle{ Title: "Install an azd extension (azd extension install)", TitleNote: "Installs the specified extension onto the local machine", @@ -879,6 +876,7 @@ func (a *extensionInstallAction) Run(ctx context.Context) (*actions.ActionResult } defer a.cleanupBundleInstall(ctx) } + tracing.SetUsageAttributes(fields.ExtensionSourceKind.String(sourceKind)) extensionIds := a.args if len(extensionIds) == 0 { @@ -1452,25 +1450,16 @@ func registerSourceFromLocation( return source, nil } - resolvedSource, ok, err := resolveRegisteredSourceName(ctx, sourceManager, source) + sourceFilter, err := resolveSourceFilter(ctx, sourceManager, source) if err != nil { return "", err } - if ok { - return resolvedSource, nil - } - - location := source - kind, ok := inferSourceKind(location) - if !ok { - return source, nil + if sourceFilter.config == nil { + return sourceFilter.source, nil } - if kind == extensions.SourceKindFile { - if abs, err := filepath.Abs(location); err == nil { - location = abs - } - } + location := sourceFilter.config.Location + kind := sourceFilter.config.Type existing, err := findSourceByLocation(ctx, sourceManager, kind, location) if err != nil { @@ -1552,6 +1541,9 @@ func registerSourceFromLocation( if _, err := sourceManager.CreateSource(ctx, sourceConfig); err != nil { console.StopSpinner(ctx, spinnerMessage, input.StepFailed) + if schemaErr, ok := errors.AsType[*extensions.ErrUnsupportedRegistrySchema](err); ok { + return "", extensions.NewUnsupportedRegistrySchemaError(schemaErr) + } return "", fmt.Errorf("failed to validate extension source: %w", err) } @@ -1622,7 +1614,7 @@ func findSourceByLocation( func locationsEqual(kind extensions.SourceKind, a, b string) bool { switch kind { case extensions.SourceKindUrl: - return strings.EqualFold(a, b) + return strings.EqualFold(normalizeUrlLocation(a), normalizeUrlLocation(b)) case extensions.SourceKindFile: a = filepath.Clean(absPath(a)) b = filepath.Clean(absPath(b)) @@ -1643,6 +1635,17 @@ func absPath(path string) string { return path } +func normalizeUrlLocation(location string) string { + parsed, err := url.Parse(location) + if err != nil { + return location + } + parsed.Scheme = strings.ToLower(parsed.Scheme) + parsed.Host = strings.ToLower(parsed.Host) + parsed.Path = strings.TrimRight(parsed.Path, "/") + return parsed.String() +} + func validateSourceName(name string) error { if strings.Contains(name, ".") { return errors.New("Extension source name cannot contain '.'") @@ -1656,28 +1659,42 @@ func validateSourceName(name string) error { return nil } -// resolveReadOnlySourceFilter returns a temporary source config for read-only -// commands when --source is a registry location. Registered names return nil. -func resolveReadOnlySourceFilter( +func sourceArgKind(source string) string { + if source == "" { + return "none" + } + if _, ok := inferSourceKind(source); ok { + return "location" + } + return "registered" +} + +type sourceFilterResolution struct { + source string + config *extensions.SourceConfig + registered bool +} + +func resolveSourceFilter( ctx context.Context, sourceManager *extensions.SourceManager, source string, -) (*extensions.SourceConfig, error) { +) (sourceFilterResolution, error) { if source == "" { - return nil, nil + return sourceFilterResolution{}, nil } - _, ok, err := resolveRegisteredSourceName(ctx, sourceManager, source) + resolvedSource, ok, err := resolveRegisteredSourceName(ctx, sourceManager, source) if err != nil { - return nil, err + return sourceFilterResolution{}, err } if ok { - return nil, nil + return sourceFilterResolution{source: resolvedSource, registered: true}, nil } kind, ok := inferSourceKind(source) if !ok { - return nil, nil + return sourceFilterResolution{source: source}, nil } location := source @@ -1685,10 +1702,12 @@ func resolveReadOnlySourceFilter( location = absPath(location) } - return &extensions.SourceConfig{ - Name: location, - Type: kind, - Location: location, + return sourceFilterResolution{ + config: &extensions.SourceConfig{ + Name: location, + Type: kind, + Location: location, + }, }, nil } @@ -1872,6 +1891,7 @@ func newExtensionUpgradeAction( func (a *extensionUpgradeAction) Run( ctx context.Context, ) (*actions.ActionResult, error) { + tracing.SetUsageAttributes(fields.ExtensionSourceKind.String(sourceArgKind(a.flags.source))) if len(a.args) > 0 && a.flags.all { return nil, &internal.ErrorWithSuggestion{ Err: fmt.Errorf( diff --git a/cli/azd/cmd/extension_install_source_test.go b/cli/azd/cmd/extension_install_source_test.go index f911fa675fc..7934aaedbb2 100644 --- a/cli/azd/cmd/extension_install_source_test.go +++ b/cli/azd/cmd/extension_install_source_test.go @@ -275,6 +275,21 @@ func TestResolveSourceLocation_ExistingUrlLocationReused(t *testing.T) { require.Equal(t, 1, matches) } +func TestResolveSourceLocation_ExistingUrlLocationReusedWithTrailingSlash(t *testing.T) { + t.Parallel() + + action, _ := newBundleInstallTestAction(t) + require.NoError(t, action.sourceManager.Add(t.Context(), "myreg", &extensions.SourceConfig{ + Name: "myreg", + Type: extensions.SourceKindUrl, + Location: "https://example.com/registry.json", + })) + + action.flags.source = "https://example.com/registry.json/" + require.NoError(t, action.resolveSourceLocation(t.Context())) + require.Equal(t, "myreg", action.flags.source) +} + func TestResolveSourceLocation_ExistingFileLocationReusedFromRelativePath(t *testing.T) { registryPath := writeRegistryFile(t) dir := filepath.Dir(registryPath) @@ -340,6 +355,28 @@ func TestResolveSourceLocation_UrlAcceptedRegistersSource(t *testing.T) { require.Equal(t, "https://example.com/registry.json", src.Location) } +func TestResolveSourceLocation_UnsupportedSchemaReturnsSuggestion(t *testing.T) { + t.Parallel() + + action, mockContext := newInstallSourceTestAction(t) + mockContext.HttpClient.When(func(req *http.Request) bool { + return req.URL.String() == "https://example.com/registry.json" + }).RespondFn(func(req *http.Request) (*http.Response, error) { + return mocks.CreateHttpResponseWithBody(req, http.StatusOK, extensions.Registry{ + SchemaVersion: "2.0", + }) + }) + + mockContext.Console.WhenConfirm(func(input.ConsoleOptions) bool { return true }).Respond(true) + mockContext.Console.WhenPrompt(func(input.ConsoleOptions) bool { return true }).Respond("example-registry") + + action.flags.source = "https://example.com/registry.json" + err := action.resolveSourceLocation(t.Context()) + require.Error(t, err) + require.ErrorAs(t, err, new(*internal.ErrorWithSuggestion)) + require.ErrorAs(t, err, new(*extensions.ErrUnsupportedRegistrySchema)) +} + func TestResolveSourceLocation_NoPromptFileDirectsToSourceAdd(t *testing.T) { t.Parallel() diff --git a/cli/azd/cmd/extension_source_location_test.go b/cli/azd/cmd/extension_source_location_test.go index 3231f8016c3..2e873db4732 100644 --- a/cli/azd/cmd/extension_source_location_test.go +++ b/cli/azd/cmd/extension_source_location_test.go @@ -261,6 +261,29 @@ func TestExtensionShow_NormalizedRegisteredSourceName(t *testing.T) { require.NoError(t, err) } +func TestExtensionShow_UnknownRegisteredNameErrors(t *testing.T) { + t.Parallel() + + _, manager, sourceManager := newSourceLocationTestManager(t) + + action := &extensionShowAction{ + args: []string{"test.ext"}, + flags: &extensionShowFlags{ + source: "not-a-registered-source", + global: &internal.GlobalCommandOptions{}, + }, + console: mocks.NewMockContext(t.Context()).Console, + formatter: &output.NoneFormatter{}, + writer: &bytes.Buffer{}, + sourceManager: sourceManager, + extensionManager: manager, + } + + _, err := action.Run(t.Context()) + require.Error(t, err) + require.ErrorIs(t, err, extensions.ErrSourceNotFound) +} + func TestExtensionUpgrade_UrlSourceRegistersSource(t *testing.T) { t.Parallel() diff --git a/cli/azd/cmd/telemetry_test.go b/cli/azd/cmd/telemetry_test.go index 85703f21163..1cac149d5a9 100644 --- a/cli/azd/cmd/telemetry_test.go +++ b/cli/azd/cmd/telemetry_test.go @@ -167,6 +167,13 @@ func TestTelemetryFieldConstants(t *testing.T) { kvUpdates := fields.ToolCheckUpdatesAvailableKey.Int(3) require.Equal(t, "tool.check.updates_available", string(kvUpdates.Key)) }) + + t.Run("ExtensionFields", func(t *testing.T) { + t.Parallel() + kv := fields.ExtensionSourceKind.String("location") + require.Equal(t, "extension.source.kind", string(kv.Key)) + require.Equal(t, "location", kv.Value.AsString()) + }) } // TestCommandTelemetryCoverage ensures every user-facing command is explicitly categorized @@ -189,24 +196,28 @@ func TestCommandTelemetryCoverage(t *testing.T) { // When adding a command here, ensure the command's action sets at least one // command-specific attribute (e.g., auth.method, config.operation, env.operation). commandsWithSpecificTelemetry := []string{ - "auth login", // auth.method - "build", // (via hooks middleware) - "deploy", // infra.provider, service attributes (via hooks middleware) - "down", // infra.provider (via hooks middleware) - "env list", // env.count - "hooks run", // hooks.name, hooks.type - "infra generate", // infra.provider - "init", // init.method, appinit.* fields - "package", // (via hooks middleware) - "pipeline config", // pipeline.provider, pipeline.auth - "provision", // infra.provider (via hooks middleware) - "restore", // (via hooks middleware) - "tool check", // tool.check.updates_available - "tool install", // tool.id(s), tool.dry_run, tool.install.* aggregate + per-tool fields - "tool show", // tool.id - "tool upgrade", // tool.id(s), tool.dry_run, tool.install.* aggregate + tool.upgrade.* versions - "up", // infra.provider (via hooks middleware, composes provision+deploy) - "update", // update.* fields + "auth login", // auth.method + "build", // (via hooks middleware) + "deploy", // infra.provider, service attributes (via hooks middleware) + "down", // infra.provider (via hooks middleware) + "env list", // env.count + "extension install", // extension.source.kind + "extension list", // extension.source.kind + "extension show", // extension.source.kind + "extension upgrade", // extension.source.kind + extension upgrade spans + "hooks run", // hooks.name, hooks.type + "infra generate", // infra.provider + "init", // init.method, appinit.* fields + "package", // (via hooks middleware) + "pipeline config", // pipeline.provider, pipeline.auth + "provision", // infra.provider (via hooks middleware) + "restore", // (via hooks middleware) + "tool check", // tool.check.updates_available + "tool install", // tool.id(s), tool.dry_run, tool.install.* aggregate + per-tool fields + "tool show", // tool.id + "tool upgrade", // tool.id(s), tool.dry_run, tool.install.* aggregate + tool.upgrade.* versions + "up", // infra.provider (via hooks middleware, composes provision+deploy) + "update", // update.* fields } // Commands that rely ONLY on global middleware telemetry (command name, flags, diff --git a/cli/azd/internal/tracing/fields/fields.go b/cli/azd/internal/tracing/fields/fields.go index a05edf275dd..429f158b3a7 100644 --- a/cli/azd/internal/tracing/fields/fields.go +++ b/cli/azd/internal/tracing/fields/fields.go @@ -1154,6 +1154,12 @@ var ( Classification: SystemMetadata, Purpose: FeatureInsight, } + // ExtensionSourceKind is the kind of --source argument: none, registered, or location. + ExtensionSourceKind = AttributeKey{ + Key: attribute.Key("extension.source.kind"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } // ExtensionSourceFrom is the registry source before a promotion. ExtensionSourceFrom = AttributeKey{ Key: attribute.Key("extension.source.from"), diff --git a/cli/azd/pkg/extensions/manager.go b/cli/azd/pkg/extensions/manager.go index ccab7fc3e41..2eb579e118d 100644 --- a/cli/azd/pkg/extensions/manager.go +++ b/cli/azd/pkg/extensions/manager.go @@ -30,7 +30,6 @@ import ( "github.com/azure/azure-dev/cli/azd/internal/tracing/events" "github.com/azure/azure-dev/cli/azd/internal/tracing/fields" "github.com/azure/azure-dev/cli/azd/pkg/config" - "github.com/azure/azure-dev/cli/azd/pkg/errorhandler" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/osutil" "github.com/azure/azure-dev/cli/azd/pkg/output" @@ -459,15 +458,7 @@ func (m *Manager) FindExtensions(ctx context.Context, options *FilterOptions) ([ source, err := m.sourceManager.CreateSource(ctx, options.SourceConfig) if err != nil { if schemaErr, ok := errors.AsType[*ErrUnsupportedRegistrySchema](err); ok { - return nil, &errorhandler.ErrorWithSuggestion{ - Err: schemaErr, - Message: schemaErr.Error(), - Suggestion: "Upgrade azd to the latest version to use this registry", - Links: []errorhandler.ErrorLink{{ - URL: "https://aka.ms/azd/install", - Title: "Install/upgrade azd", - }}, - } + return nil, NewUnsupportedRegistrySchemaError(schemaErr) } return nil, fmt.Errorf("failed initializing extension source: %w", err) } @@ -1303,16 +1294,7 @@ func (tm *Manager) createSourcesFromConfig( // Only hard-fail when every source had an incompatible schema and // no usable sources remain. if len(sources) == 0 && len(schemaErrors) > 0 { - return nil, &errorhandler.ErrorWithSuggestion{ - Err: schemaErrors[0], - Message: schemaErrors[0].Error(), - Suggestion: "Upgrade azd to the latest version " + - "to use this registry", - Links: []errorhandler.ErrorLink{{ - URL: "https://aka.ms/azd/install", - Title: "Install/upgrade azd", - }}, - } + return nil, NewUnsupportedRegistrySchemaError(schemaErrors[0]) } return sources, nil diff --git a/cli/azd/pkg/extensions/registry_version.go b/cli/azd/pkg/extensions/registry_version.go index 70167746c22..5b4a53603b5 100644 --- a/cli/azd/pkg/extensions/registry_version.go +++ b/cli/azd/pkg/extensions/registry_version.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/Masterminds/semver/v3" + "github.com/azure/azure-dev/cli/azd/pkg/errorhandler" ) // ErrUnsupportedRegistrySchema is returned when the registry schema version @@ -26,6 +27,20 @@ func (e *ErrUnsupportedRegistrySchema) Error() string { ) } +// NewUnsupportedRegistrySchemaError wraps an ErrUnsupportedRegistrySchema in an +// ErrorWithSuggestion that guides the user to upgrade azd. +func NewUnsupportedRegistrySchemaError(schemaErr *ErrUnsupportedRegistrySchema) error { + return &errorhandler.ErrorWithSuggestion{ + Err: schemaErr, + Message: schemaErr.Error(), + Suggestion: "Upgrade azd to the latest version to use this registry", + Links: []errorhandler.ErrorLink{{ + URL: "https://aka.ms/azd/install", + Title: "Install/upgrade azd", + }}, + } +} + // CheckRegistrySchemaVersion validates that the given schema version // is compatible with this version of azd. // diff --git a/docs/reference/telemetry-data.md b/docs/reference/telemetry-data.md index 3d66f6155af..d209e6c5338 100644 --- a/docs/reference/telemetry-data.md +++ b/docs/reference/telemetry-data.md @@ -441,6 +441,7 @@ Emitted on `azd provision` / `azd up` to measure adoption and safety of `infra.l | `extension.version.from` | string | Version before an upgrade or promotion (`ext.upgrade`, `ext.promote`) | | `extension.version.to` | string | Version after an upgrade or promotion (`ext.upgrade`, `ext.promote`) | | `extension.source` | string | Registry source used for an upgrade (`ext.upgrade`) | +| `extension.source.kind` | string | Kind of `--source` argument: `none`, `registered`, or `location` (`azd extension list`, `show`, `install`, `upgrade`) | | `extension.source.from` | string | Registry source before a promotion (`ext.promote`) | | `extension.source.to` | string | Registry source after a promotion (`ext.promote`) | | `extension.upgrade.duration_ms` | measurement | Duration (ms) of a single upgrade (`ext.upgrade`) | diff --git a/docs/specs/metrics-audit/feature-telemetry-matrix.md b/docs/specs/metrics-audit/feature-telemetry-matrix.md index 741622751ba..4b233b3bd5f 100644 --- a/docs/specs/metrics-audit/feature-telemetry-matrix.md +++ b/docs/specs/metrics-audit/feature-telemetry-matrix.md @@ -30,7 +30,7 @@ These commands emit attributes or events beyond the global middleware span. |---------|---------------------|-------| | `init` | `init.method` (template / app / project / environment / copilot), `appinit.detected.databases`, `appinit.detected.services`, `appinit.confirmed.databases`, `appinit.confirmed.services`, `appinit.modify_add.count`, `appinit.modify_remove.count`, `appinit.lastStep` | Comprehensive coverage via `SetUsageAttributes` and `repository/app_init.go` | | `update` | `update.installMethod`, `update.channel`, `update.fromVersion`, `update.toVersion`, `update.result` | Result codes cover success, failure, and skip reasons | -| Extensions (dynamic) | `extension.id`, `extension.version`, `extension.version.from`, `extension.version.to`, `extension.source`, `extension.source.from`, `extension.source.to`, `extension.dependency_of`, `extension.dependency_upgrade_count`, `extension.upgrade.outcome`, `extension.upgrade.duration_ms` + trace-context propagation to child process | Covers `ext.run`, `ext.install`, `ext.upgrade`, `ext.promote` events; upgrade/promote spans set source and dependency attributes | +| Extensions (dynamic) | `extension.id`, `extension.version`, `extension.version.from`, `extension.version.to`, `extension.source`, `extension.source.kind`, `extension.source.from`, `extension.source.to`, `extension.dependency_of`, `extension.dependency_upgrade_count`, `extension.upgrade.outcome`, `extension.upgrade.duration_ms` + trace-context propagation to child process | Covers `ext.run`, `ext.install`, `ext.upgrade`, `ext.promote` events; `extension.source.kind` distinguishes no source, registered source, and direct location usage for extension list/show/install/upgrade | | `mcp start` | Per-tool spans via `tracing.Start` with `mcp.client.name`, `mcp.client.version` | MCP event prefix `mcp.*` | | `tool install` / `tool upgrade` / `tool check` / `tool list` / `tool show` | `tool.id`, `tool.ids`, `tool.dry_run`, `tool.install.strategy`, `tool.install.success`, `tool.install.success_count`, `tool.install.failure_count`, `tool.install.failed_ids`, `tool.install.duration_ms`, `tool.upgrade.from_version`, `tool.upgrade.to_version`, `tool.check.updates_available` | Comprehensive coverage in `cli/azd/cmd/tool.go`; install/upgrade emit `tools.pack.build` spans for pack-based tools | | `copilot` (agent) | `copilot.initialize` event (model + reasoning config), `copilot.session` event (session create/resume) | Emitted from `internal/agent/copilot_agent.go`; covers the experimental copilot agent surface | @@ -87,7 +87,7 @@ These commands emit attributes or events beyond the global middleware span. | **Copilot Consent** | | | | | | | `copilot consent` | `list`, `revoke`, `grant` | ✅ | ❌ | ❌ | Low priority | | **Extension Management** | | | | | | -| `extension` | `list`, `show`, `install`, `uninstall`, `upgrade` | ✅ | ✅ | ✅ | Covered by `extension.*` fields and `ext.install`, `ext.upgrade`, `ext.promote` events | +| `extension` | `list`, `show`, `install`, `uninstall`, `upgrade` | ✅ | ✅ | ✅ | Covered by `extension.*` fields and `ext.install`, `ext.upgrade`, `ext.promote` events; `extension.source.kind` tracks `--source` argument kind for list/show/install/upgrade | | `extension source` | `list`, `add`, `remove`, `validate` | ✅ | ❌ | ❌ | Subcommand name in the global span captures the operation; `extension.source*` attributes are recorded by `extension upgrade` / `extension promote`, not by this subcommand | | **Init** | | | | | | | `init` | — | ✅ | ✅ | ✅ | Comprehensive coverage via `appinit.*` fields | diff --git a/docs/specs/metrics-audit/telemetry-schema.md b/docs/specs/metrics-audit/telemetry-schema.md index eb1983516bd..6f61eb2fd3b 100644 --- a/docs/specs/metrics-audit/telemetry-schema.md +++ b/docs/specs/metrics-audit/telemetry-schema.md @@ -207,6 +207,7 @@ not emitted by azd spans. | Extension version from | `extension.version.from` | SystemMetadata | FeatureInsight | Installed version before an upgrade | | Extension version to | `extension.version.to` | SystemMetadata | FeatureInsight | Target version after an upgrade | | Extension source | `extension.source` | SystemMetadata | FeatureInsight | Registry source used for the upgrade | +| Extension source kind | `extension.source.kind` | SystemMetadata | FeatureInsight | Allowed values: `none`, `registered`, `location` | | Extension source from | `extension.source.from` | SystemMetadata | FeatureInsight | Registry source before a promotion | | Extension source to | `extension.source.to` | SystemMetadata | FeatureInsight | Registry source after a promotion | | Upgrade duration | `extension.upgrade.duration_ms` | SystemMetadata | PerformanceAndHealth | **Measurement** — time in ms for one upgrade | From 8a3646c18e55c033132acec29aa05bb4fc9410fa Mon Sep 17 00:00:00 2001 From: Jeffrey Chen Date: Fri, 26 Jun 2026 20:47:32 +0000 Subject: [PATCH 8/8] Improve test coverage --- cli/azd/pkg/extensions/manager_test.go | 168 ++++++++++++++++++ .../pkg/extensions/registry_version_test.go | 18 ++ cli/azd/pkg/extensions/source_manager_test.go | 7 + 3 files changed, 193 insertions(+) diff --git a/cli/azd/pkg/extensions/manager_test.go b/cli/azd/pkg/extensions/manager_test.go index 1b978c475a7..468305d18dc 100644 --- a/cli/azd/pkg/extensions/manager_test.go +++ b/cli/azd/pkg/extensions/manager_test.go @@ -21,6 +21,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/errorhandler" "github.com/azure/azure-dev/cli/azd/pkg/exec" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/osutil" @@ -1015,6 +1016,173 @@ func Test_FindExtensions_MultipleMatches_ErrorHandling(t *testing.T) { require.True(t, sourceNames["source2"]) } +func Test_FindExtensions_SourceConfigDirectSource(t *testing.T) { + t.Parallel() + + registryPath := writeExtensionRegistryFile(t, Registry{ + SchemaVersion: CurrentRegistrySchemaVersion, + Extensions: []*ExtensionMetadata{ + { + Id: "direct.extension", + DisplayName: "Direct Extension", + Versions: []ExtensionVersion{ + {Version: "1.0.0"}, + }, + }, + }, + }) + + manager := newTestManager(t) + + extensions, err := manager.FindExtensions(t.Context(), &FilterOptions{ + Id: "direct.extension", + Source: "ignored-source-filter", + SourceConfig: &SourceConfig{ + Name: "direct", + Type: SourceKindFile, + Location: registryPath, + }, + }) + require.NoError(t, err) + require.Len(t, extensions, 1) + require.Equal(t, "direct.extension", extensions[0].Id) + require.Equal(t, "direct", extensions[0].Source) +} + +func Test_FindExtensions_SourceConfigMissingFileReturnsError(t *testing.T) { + t.Parallel() + + manager := newTestManager(t) + + _, err := manager.FindExtensions(t.Context(), &FilterOptions{ + SourceConfig: &SourceConfig{ + Name: "missing", + Type: SourceKindFile, + Location: filepath.Join(t.TempDir(), "missing-registry.json"), + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "failed initializing extension source") +} + +func Test_FindExtensions_SourceConfigUnsupportedSchemaReturnsSuggestion(t *testing.T) { + t.Parallel() + + registryPath := writeExtensionRegistryFile(t, Registry{ + SchemaVersion: "2.0", + Extensions: []*ExtensionMetadata{}, + }) + + manager := newTestManager(t) + + _, err := manager.FindExtensions(t.Context(), &FilterOptions{ + SourceConfig: &SourceConfig{ + Name: "future", + Type: SourceKindFile, + Location: registryPath, + }, + }) + require.Error(t, err) + require.ErrorAs(t, err, new(*ErrUnsupportedRegistrySchema)) + require.ErrorAs(t, err, new(*errorhandler.ErrorWithSuggestion)) +} + +func Test_UpdateInstalled_UpdatesConfigAndInvalidatesCache(t *testing.T) { + t.Parallel() + + manager := newTestManager(t) + require.NoError(t, manager.userConfig.Set(installedConfigKey, map[string]*Extension{ + "test.extension": { + Id: "test.extension", + Version: "1.0.0", + }, + })) + + manager.installed = map[string]*Extension{ + "test.extension": { + Id: "test.extension", + Version: "1.0.0", + }, + } + + err := manager.UpdateInstalled(&Extension{ + Id: "test.extension", + Version: "2.0.0", + }) + require.NoError(t, err) + require.Nil(t, manager.installed) + + updated, err := manager.GetInstalled(FilterOptions{Id: "test.extension"}) + require.NoError(t, err) + require.Equal(t, "2.0.0", updated.Version) +} + +func Test_UpdateInstalled_MissingExtension(t *testing.T) { + t.Parallel() + + manager := newTestManager(t) + + err := manager.UpdateInstalled(&Extension{Id: "missing"}) + require.ErrorIs(t, err, ErrInstalledExtensionNotFound) +} + +func Test_InvalidateSourceCache(t *testing.T) { + t.Parallel() + + manager := newTestManager(t) + manager.sources = []Source{&mockSource{name: "cached"}} + + manager.InvalidateSourceCache() + + require.Nil(t, manager.sources) +} + +func Test_HasMetadataCapability(t *testing.T) { + t.Parallel() + + manager := newTestManager(t) + require.NoError(t, manager.userConfig.Set(installedConfigKey, map[string]*Extension{ + "metadata.extension": { + Id: "metadata.extension", + Capabilities: []CapabilityType{MetadataCapability}, + }, + "plain.extension": { + Id: "plain.extension", + }, + })) + + require.True(t, manager.HasMetadataCapability("metadata.extension")) + require.False(t, manager.HasMetadataCapability("plain.extension")) + require.False(t, manager.HasMetadataCapability("missing.extension")) +} + +func newTestManager(t *testing.T) *Manager { + t.Helper() + + mockContext := mocks.NewMockContext(t.Context()) + userConfigManager := config.NewUserConfigManager(mockContext.ConfigManager) + sourceManager := NewSourceManager(mockContext.Container, userConfigManager, mockContext.HttpClient) + lazyRunner := lazy.NewLazy(func() (*Runner, error) { + return NewRunner(mockContext.CommandRunner), nil + }) + manager, err := NewManager(userConfigManager, sourceManager, lazyRunner, mockContext.HttpClient) + require.NoError(t, err) + + return manager +} + +func writeExtensionRegistryFile(t *testing.T, registry Registry) string { + t.Helper() + + data, err := json.Marshal(registry) + require.NoError(t, err) + + registryPath := filepath.Join(t.TempDir(), "registry.json") + require.NoError(t, os.WriteFile(registryPath, data, 0600)) + + return registryPath +} + // mockSource is a test implementation of the Source interface type mockSource struct { name string diff --git a/cli/azd/pkg/extensions/registry_version_test.go b/cli/azd/pkg/extensions/registry_version_test.go index 474c25aae9f..6bc61055c50 100644 --- a/cli/azd/pkg/extensions/registry_version_test.go +++ b/cli/azd/pkg/extensions/registry_version_test.go @@ -9,6 +9,7 @@ import ( "fmt" "testing" + "github.com/azure/azure-dev/cli/azd/pkg/errorhandler" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -321,6 +322,23 @@ func TestErrUnsupportedRegistrySchema_Formatting(t *testing.T) { assert.Equal(t, expected, err.Error()) } +func TestNewUnsupportedRegistrySchemaError(t *testing.T) { + schemaErr := &ErrUnsupportedRegistrySchema{ + SchemaVersion: "3.0", + MaxSupportedVersion: "1.0", + } + + err := NewUnsupportedRegistrySchemaError(schemaErr) + require.ErrorIs(t, err, schemaErr) + + suggestionErr, ok := errors.AsType[*errorhandler.ErrorWithSuggestion](err) + require.True(t, ok) + require.Equal(t, schemaErr.Error(), suggestionErr.Message) + require.Contains(t, suggestionErr.Suggestion, "Upgrade azd") + require.Len(t, suggestionErr.Links, 1) + require.Equal(t, "https://aka.ms/azd/install", suggestionErr.Links[0].URL) +} + func TestValidateRegistry_NilRegistry(t *testing.T) { result := ValidateRegistry(nil, false) require.NotNil(t, result) diff --git a/cli/azd/pkg/extensions/source_manager_test.go b/cli/azd/pkg/extensions/source_manager_test.go index 07b601226a0..46dd41bbc03 100644 --- a/cli/azd/pkg/extensions/source_manager_test.go +++ b/cli/azd/pkg/extensions/source_manager_test.go @@ -146,6 +146,13 @@ func TestSourceManager_List(t *testing.T) { require.Equal(t, expected, *sources[0]) } +func TestNormalizeSourceKey(t *testing.T) { + t.Parallel() + + require.Equal(t, "my-source", NormalizeSourceKey("My Source")) + require.Equal(t, "my.source", NormalizeSourceKey("My.Source")) +} + func TestSourceManager_CreateSource_Bundle(t *testing.T) { mockContext := mocks.NewMockContext(t.Context()) ctx := t.Context()