From 53fc1ddc6349677932119421e94b9017cdd8cef1 Mon Sep 17 00:00:00 2001 From: Dov Benyomin Sohacheski Date: Tue, 30 Jun 2026 11:47:27 +0300 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20Add=20unified=20seed=20engine,?= =?UTF-8?q?=20remove=20secrets=20vault=20subgroup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce `ws-cli seed` — a Go-native declarative projection engine that supersedes the q2-55 Ansible seed tier and the standalone secrets vault. internals/seed/: - version:v1 dest-keyed manifest (top-level `secrets:` map + `seeds:` SeedOp) - two tiers via a single resolved plan: bare FS-mirror verbatim copy + manifest behavior overlay (manifest wins per dest; one write per dest) - ops copy/merge/append/prepend; merge deep-merges json/yaml/toml (format inferred from dest extension, lists replace, scalar-vs-map = error) - closed-var + ${secrets.NAME} templating, unknown token fails loud - ownership-boundary allow-list (st_uid==geteuid() on nearest existing ancestor; no system gate) + non-blocking consumed-dir notice - os.Root TOCTOU-safe atomic writer with final-symlink refusal - secret-bearing -> 0600, lazy/per-entry fail-closed master key cmd/seed/: `seed apply [dest...] [--force]` (boot hook = no-arg) + `seed ls`. Remove `cmd/secrets/vault/` and orphaned `internals/secrets/vault.go` (+ test); keep crypto.go/key.go and `secrets {encrypt,decrypt,generate}`. Add github.com/pelletier/go-toml/v2 v2.4.2. Bump version 0.0.66 -> 0.0.67. --- cmd/info/version.go | 2 +- cmd/root.go | 2 + cmd/secrets/secrets.go | 5 +- cmd/secrets/vault/decrypt.go | 103 ---- cmd/secrets/vault/ls.go | 50 -- cmd/secrets/vault/rotate.go | 72 --- cmd/secrets/vault/vault.go | 14 - cmd/seed/apply.go | 49 ++ cmd/seed/ls.go | 54 ++ cmd/seed/seed.go | 17 + cmd/seed/seed_test.go | 77 +++ go.mod | 1 + go.sum | 2 + internals/secrets/vault.go | 407 --------------- internals/secrets/vault_test.go | 884 -------------------------------- internals/seed/apply.go | 237 +++++++++ internals/seed/manifest.go | 82 +++ internals/seed/manifest_test.go | 51 ++ internals/seed/merge.go | 97 ++++ internals/seed/merge_test.go | 70 +++ internals/seed/mirror.go | 45 ++ internals/seed/op.go | 23 + internals/seed/ownership.go | 79 +++ internals/seed/resolve.go | 185 +++++++ internals/seed/seed_test.go | 491 ++++++++++++++++++ internals/seed/template.go | 55 ++ internals/seed/write.go | 77 +++ 27 files changed, 1697 insertions(+), 1534 deletions(-) delete mode 100644 cmd/secrets/vault/decrypt.go delete mode 100644 cmd/secrets/vault/ls.go delete mode 100644 cmd/secrets/vault/rotate.go delete mode 100644 cmd/secrets/vault/vault.go create mode 100644 cmd/seed/apply.go create mode 100644 cmd/seed/ls.go create mode 100644 cmd/seed/seed.go create mode 100644 cmd/seed/seed_test.go delete mode 100644 internals/secrets/vault.go delete mode 100644 internals/secrets/vault_test.go create mode 100644 internals/seed/apply.go create mode 100644 internals/seed/manifest.go create mode 100644 internals/seed/manifest_test.go create mode 100644 internals/seed/merge.go create mode 100644 internals/seed/merge_test.go create mode 100644 internals/seed/mirror.go create mode 100644 internals/seed/op.go create mode 100644 internals/seed/ownership.go create mode 100644 internals/seed/resolve.go create mode 100644 internals/seed/seed_test.go create mode 100644 internals/seed/template.go create mode 100644 internals/seed/write.go diff --git a/cmd/info/version.go b/cmd/info/version.go index ca2ecd6..287a130 100644 --- a/cmd/info/version.go +++ b/cmd/info/version.go @@ -1,3 +1,3 @@ package info -var Version = "0.0.66" +var Version = "0.0.67" diff --git a/cmd/root.go b/cmd/root.go index 7a6d0a0..1027623 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -12,6 +12,7 @@ import ( "github.com/kloudkit/ws-cli/cmd/log" "github.com/kloudkit/ws-cli/cmd/logs" "github.com/kloudkit/ws-cli/cmd/secrets" + "github.com/kloudkit/ws-cli/cmd/seed" "github.com/kloudkit/ws-cli/cmd/serve" "github.com/kloudkit/ws-cli/cmd/show" "github.com/kloudkit/ws-cli/cmd/template" @@ -57,5 +58,6 @@ func init() { log.LogCmd, logs.LogsCmd, secrets.SecretsCmd, + seed.SeedCmd, ) } diff --git a/cmd/secrets/secrets.go b/cmd/secrets/secrets.go index f634f30..8db7c41 100644 --- a/cmd/secrets/secrets.go +++ b/cmd/secrets/secrets.go @@ -1,13 +1,12 @@ package secrets import ( - "github.com/kloudkit/ws-cli/cmd/secrets/vault" "github.com/spf13/cobra" ) var SecretsCmd = &cobra.Command{ Use: "secrets", - Short: "Manage encryption, decryption, and vaults for secrets", + Short: "Manage encryption and decryption of secrets", } func init() { @@ -17,5 +16,5 @@ func init() { SecretsCmd.PersistentFlags().Bool("force", false, "Overwrite existing files") SecretsCmd.PersistentFlags().Bool("raw", false, "Output without styling") - SecretsCmd.AddCommand(encryptCmd, decryptCmd, generateCmd, vault.VaultCmd) + SecretsCmd.AddCommand(encryptCmd, decryptCmd, generateCmd) } diff --git a/cmd/secrets/vault/decrypt.go b/cmd/secrets/vault/decrypt.go deleted file mode 100644 index 9d20def..0000000 --- a/cmd/secrets/vault/decrypt.go +++ /dev/null @@ -1,103 +0,0 @@ -package vault - -import ( - "fmt" - "slices" - "strings" - - internalSecrets "github.com/kloudkit/ws-cli/internals/secrets" - "github.com/kloudkit/ws-cli/internals/styles" - "github.com/spf13/cobra" -) - -var decryptCmd = &cobra.Command{ - Use: "decrypt", - Short: "Decrypt vault secrets and write to destinations", - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - inputFile, _ := cmd.Flags().GetString("input") - masterKeyFlag, _ := cmd.Flags().GetString("master") - keys, _ := cmd.Flags().GetStringArray("key") - force, _ := cmd.Flags().GetBool("force") - raw, _ := cmd.Flags().GetBool("raw") - stdout, _ := cmd.Flags().GetBool("stdout") - modeOverride, _ := cmd.Flags().GetString("mode") - - vaultPath, err := internalSecrets.ResolveVaultPath(inputFile) - if err != nil { - return err - } - - masterKey, err := internalSecrets.ResolveMasterKey(masterKeyFlag) - if err != nil { - return err - } - - vault, err := internalSecrets.LoadVault(vaultPath) - if err != nil { - return err - } - - opts := internalSecrets.ProcessOptions{ - MasterKey: masterKey, - Keys: keys, - Stdout: stdout, - Raw: raw, - Force: force, - ModeOverride: modeOverride, - } - - results, err := internalSecrets.ProcessVault(vault, opts) - if err != nil { - return err - } - - if stdout { - printStdoutResults(cmd, results, raw) - return nil - } - - if raw { - return nil - } - - printVaultSuccess(cmd, results) - return nil - }, -} - -func init() { - decryptCmd.Flags().StringArray("key", []string{}, "Decrypt only specified key") - decryptCmd.Flags().Bool("stdout", false, "Output decrypted values to stdout") -} - -func sortedKeys(m map[string]string) []string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - slices.Sort(keys) - return keys -} - -func printStdoutResults(cmd *cobra.Command, results map[string]string, raw bool) { - for _, key := range sortedKeys(results) { - value := results[key] - output := internalSecrets.FormatSecretForStdout(key, value, raw) - fmt.Fprint(cmd.OutOrStdout(), output) - } -} - -func printVaultSuccess(cmd *cobra.Command, results map[string]string) { - fmt.Fprintln(cmd.OutOrStdout(), styles.Success().Render("✓ Vault processed successfully")) - for _, key := range sortedKeys(results) { - dest := results[key] - displayDest := dest - if after, ok := strings.CutPrefix(dest, "env:"); ok { - displayDest = fmt.Sprintf("env:%s", after) - } - fmt.Fprintf(cmd.OutOrStdout(), " %s → %s\n", - styles.Code().Render(key), - styles.Muted().Render(displayDest)) - } -} diff --git a/cmd/secrets/vault/ls.go b/cmd/secrets/vault/ls.go deleted file mode 100644 index 43cda09..0000000 --- a/cmd/secrets/vault/ls.go +++ /dev/null @@ -1,50 +0,0 @@ -package vault - -import ( - "fmt" - - internalSecrets "github.com/kloudkit/ws-cli/internals/secrets" - "github.com/kloudkit/ws-cli/internals/styles" - "github.com/spf13/cobra" -) - -var lsCmd = &cobra.Command{ - Use: "ls", - Short: "List secrets in a vault", - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - inputFile, _ := cmd.Flags().GetString("input") - raw, _ := cmd.Flags().GetBool("raw") - - vaultPath, err := internalSecrets.ResolveVaultPath(inputFile) - if err != nil { - return err - } - - vault, err := internalSecrets.LoadRawVault(vaultPath) - if err != nil { - return err - } - - entries := internalSecrets.ListVault(vault) - if len(entries) == 0 { - fmt.Fprintln(cmd.OutOrStdout(), styles.Muted().Render("No secrets in vault")) - return nil - } - - if raw { - for _, e := range entries { - fmt.Fprintf(cmd.OutOrStdout(), "%s\t%s\t%s\n", e.Name, e.Type, e.Destination) - } - return nil - } - - t := styles.Table("Name", "Type", "Destination") - for _, e := range entries { - t.Row(e.Name, e.Type, e.Destination) - } - fmt.Fprintln(cmd.OutOrStdout(), t.Render()) - - return nil - }, -} diff --git a/cmd/secrets/vault/rotate.go b/cmd/secrets/vault/rotate.go deleted file mode 100644 index 97e269a..0000000 --- a/cmd/secrets/vault/rotate.go +++ /dev/null @@ -1,72 +0,0 @@ -package vault - -import ( - "fmt" - - internalSecrets "github.com/kloudkit/ws-cli/internals/secrets" - "github.com/kloudkit/ws-cli/internals/styles" - "github.com/spf13/cobra" -) - -var rotateCmd = &cobra.Command{ - Use: "rotate", - Short: "Re-encrypt vault secrets with a new master key", - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - inputFile, _ := cmd.Flags().GetString("input") - masterKeyFlag, _ := cmd.Flags().GetString("master") - newMasterFlag, _ := cmd.Flags().GetString("new-master") - raw, _ := cmd.Flags().GetBool("raw") - - vaultPath, err := internalSecrets.ResolveVaultPath(inputFile) - if err != nil { - return err - } - - oldKey, err := internalSecrets.ResolveMasterKey(masterKeyFlag) - if err != nil { - return err - } - - newKey, err := internalSecrets.ResolveMasterKey(newMasterFlag) - if err != nil { - return fmt.Errorf("new master key: %w", err) - } - - vault, err := internalSecrets.LoadRawVault(vaultPath) - if err != nil { - return err - } - - fileRefs, err := internalSecrets.RotateVault(vault, oldKey, newKey) - if err != nil { - return err - } - - if err := internalSecrets.SaveVault(vaultPath, vault); err != nil { - return err - } - - if raw { - fmt.Fprintf(cmd.OutOrStdout(), "%d\n", len(vault.Secrets)) - return nil - } - - if len(fileRefs) > 0 { - fmt.Fprintln(cmd.OutOrStdout(), styles.Warning().Render("⚠ The following secrets had file: references and are now inlined:")) - for _, name := range fileRefs { - fmt.Fprintf(cmd.OutOrStdout(), " %s\n", styles.Code().Render(name)) - } - } - - fmt.Fprintf(cmd.OutOrStdout(), "%s\n", - styles.Success().Render(fmt.Sprintf("✓ Rotated %d secret(s)", len(vault.Secrets)))) - - return nil - }, -} - -func init() { - rotateCmd.Flags().String("new-master", "", "New master key or path to key file") - rotateCmd.MarkFlagRequired("new-master") -} diff --git a/cmd/secrets/vault/vault.go b/cmd/secrets/vault/vault.go deleted file mode 100644 index 2058abf..0000000 --- a/cmd/secrets/vault/vault.go +++ /dev/null @@ -1,14 +0,0 @@ -package vault - -import "github.com/spf13/cobra" - -var VaultCmd = &cobra.Command{ - Use: "vault", - Short: "Manage vault secrets", -} - -func init() { - VaultCmd.PersistentFlags().String("input", "", "Path to vault file") - - VaultCmd.AddCommand(lsCmd, decryptCmd, rotateCmd) -} diff --git a/cmd/seed/apply.go b/cmd/seed/apply.go new file mode 100644 index 0000000..1fc1644 --- /dev/null +++ b/cmd/seed/apply.go @@ -0,0 +1,49 @@ +package seed + +import ( + "io" + "os" + + "github.com/kloudkit/ws-cli/internals/seed" + "github.com/spf13/cobra" + "golang.org/x/term" +) + +var applyCmd = &cobra.Command{ + Use: "apply [dest...]", + Short: "Project seed content onto the filesystem", + RunE: runApply, +} + +func runApply(cmd *cobra.Command, args []string) error { + source, _ := cmd.Flags().GetString("source") + force, _ := cmd.Flags().GetBool("force") + master, _ := cmd.Flags().GetString("master") + + resolved, err := seed.ResolveSource(source) + if err != nil { + return err + } + + return seed.Apply(seed.Options{ + Source: resolved, + Force: force, + Dests: args, + MasterKey: master, + Out: cmd.OutOrStdout(), + Styled: isTerminal(cmd.OutOrStdout()), + }) +} + +func isTerminal(out io.Writer) bool { + file, ok := out.(*os.File) + + return ok && term.IsTerminal(int(file.Fd())) +} + +func init() { + applyCmd.Flags().Bool("force", false, "Overwrite existing destinations") + applyCmd.Flags().String("master", "", "Master key or path to key file") + + SeedCmd.AddCommand(applyCmd) +} diff --git a/cmd/seed/ls.go b/cmd/seed/ls.go new file mode 100644 index 0000000..a42e123 --- /dev/null +++ b/cmd/seed/ls.go @@ -0,0 +1,54 @@ +package seed + +import ( + "strings" + + "github.com/kloudkit/ws-cli/internals/seed" + "github.com/kloudkit/ws-cli/internals/styles" + "github.com/spf13/cobra" +) + +var lsCmd = &cobra.Command{ + Use: "ls", + Short: "List seed destinations and their behaviors", + RunE: runLs, +} + +func runLs(cmd *cobra.Command, args []string) error { + source, _ := cmd.Flags().GetString("source") + + resolved, err := seed.ResolveSource(source) + if err != nil { + return err + } + + plan, err := seed.BuildPlan(resolved, false) + if err != nil { + return err + } + + out := cmd.OutOrStdout() + for _, op := range plan.Ops { + styles.PrintKeyValue(out, op.Dest, describe(op)) + } + + return nil +} + +func describe(op seed.ResolvedOp) string { + parts := []string{string(op.Op)} + + if op.Secret { + parts = append(parts, "secret") + } + + if op.Template { + parts = append(parts, "template") + } + + return strings.Join(parts, " ") +} + +func init() { + SeedCmd.AddCommand(lsCmd) +} diff --git a/cmd/seed/seed.go b/cmd/seed/seed.go new file mode 100644 index 0000000..b244032 --- /dev/null +++ b/cmd/seed/seed.go @@ -0,0 +1,17 @@ +package seed + +import ( + "github.com/kloudkit/ws-cli/internals/config" + "github.com/spf13/cobra" +) + +var SeedCmd = &cobra.Command{ + Use: "seed", + Short: "Project declarative content onto the filesystem", +} + +func init() { + source, _ := config.Resolve("seed", "source") + + SeedCmd.PersistentFlags().String("source", source, "Seed source directory") +} diff --git a/cmd/seed/seed_test.go b/cmd/seed/seed_test.go new file mode 100644 index 0000000..5418c7f --- /dev/null +++ b/cmd/seed/seed_test.go @@ -0,0 +1,77 @@ +package seed + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "gotest.tools/v3/assert" +) + +func resetCommandFlags(cmd *cobra.Command) { + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + flag.Value.Set(flag.DefValue) + flag.Changed = false + }) + + for _, c := range cmd.Commands() { + resetCommandFlags(c) + } +} + +func run(t *testing.T, args ...string) string { + t.Helper() + resetCommandFlags(SeedCmd) + + buffer := new(bytes.Buffer) + SeedCmd.SetOut(buffer) + SeedCmd.SetErr(buffer) + SeedCmd.SetArgs(args) + + assert.NilError(t, SeedCmd.Execute()) + + return buffer.String() +} + +func TestSeedCommand(t *testing.T) { + t.Run("Apply", func(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("WS__INTERNAL_ENV_REFERENCE", filepath.Join(t.TempDir(), "absent.yaml")) + + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "out.txt") + + manifest := fmt.Sprintf("version: v1\nseeds:\n %s:\n mode: \"0o644\"\n content: \"cli\\n\"\n", dest) + assert.NilError(t, os.WriteFile(filepath.Join(source, ".seed.yaml"), []byte(manifest), 0o644)) + + output := run(t, "apply", "--source", source) + + got, err := os.ReadFile(dest) + assert.NilError(t, err) + assert.Equal(t, string(got), "cli\n") + assert.Assert(t, strings.Contains(output, "Seeded ["+dest+"]")) + }) + + t.Run("List", func(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("WS__INTERNAL_ENV_REFERENCE", filepath.Join(t.TempDir(), "absent.yaml")) + + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "out.txt") + + manifest := fmt.Sprintf("version: v1\nseeds:\n %s:\n secret: true\n", dest) + assert.NilError(t, os.WriteFile(filepath.Join(source, ".seed.yaml"), []byte(manifest), 0o644)) + + output := run(t, "ls", "--source", source) + + assert.Assert(t, strings.Contains(output, dest)) + assert.Assert(t, strings.Contains(output, "secret")) + }) +} diff --git a/go.mod b/go.mod index 6a319e4..bb5fbd8 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/muesli/mango-pflag v0.2.0 // indirect github.com/muesli/roff v0.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pelletier/go-toml/v2 v2.4.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.20.1 // indirect diff --git a/go.sum b/go.sum index b3e8eb8..a7e91f3 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8= github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pelletier/go-toml/v2 v2.4.2 h1:M2fKKbmyvI+hGId/D0W64qDBMVhJnNR10O5gIbMc//Q= +github.com/pelletier/go-toml/v2 v2.4.2/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= diff --git a/internals/secrets/vault.go b/internals/secrets/vault.go deleted file mode 100644 index eb1ac5d..0000000 --- a/internals/secrets/vault.go +++ /dev/null @@ -1,407 +0,0 @@ -package secrets - -import ( - "fmt" - "os" - "path/filepath" - "slices" - "strings" - - "github.com/kloudkit/ws-cli/internals/config" - "github.com/kloudkit/ws-cli/internals/env" - internalIO "github.com/kloudkit/ws-cli/internals/io" - "github.com/kloudkit/ws-cli/internals/path" - "gopkg.in/yaml.v3" -) - -type VaultSecret struct { - Type string `yaml:"type,omitempty"` - Encrypted string `yaml:"encrypted"` - Destination string `yaml:"destination"` - Mode string `yaml:"mode,omitempty"` - Force bool `yaml:"force,omitempty"` -} - -type Vault struct { - Secrets map[string]VaultSecret `yaml:"secrets"` -} - -const ( - TypeGeneric = "generic" - TypeSSH = "ssh" - TypeEnv = "env" - TypeKubeconfig = "kubeconfig" - TypeDockerConfigJSON = "dockerconfigjson" -) - -type SecretTypeConfig struct { - DefaultMode string - DefaultDirectory string -} - -var SecretTypeConfigs = map[string]SecretTypeConfig{ - TypeGeneric: { - DefaultMode: "0o600", - DefaultDirectory: "", - }, - TypeSSH: { - DefaultMode: "0o600", - DefaultDirectory: "~/.ssh", - }, - TypeEnv: { - DefaultMode: "0o644", - DefaultDirectory: "", - }, - TypeKubeconfig: { - DefaultMode: "0o600", - DefaultDirectory: "~/.kube", - }, - TypeDockerConfigJSON: { - DefaultMode: "0o600", - DefaultDirectory: "~/.docker", - }, -} - -func LoadVault(path string) (*Vault, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read vault file %q: %w", path, err) - } - - var vault Vault - if err := yaml.Unmarshal(data, &vault); err != nil { - return nil, fmt.Errorf("failed to unmarshal vault yaml: %w", err) - } - - if vault.Secrets == nil { - vault.Secrets = make(map[string]VaultSecret) - } - - for name, secret := range vault.Secrets { - if secret.Type == "" { - secret.Type = TypeGeneric - } - - if secret.Mode == "" { - if config, ok := SecretTypeConfigs[secret.Type]; ok { - secret.Mode = config.DefaultMode - } - } - - resolvedDest, err := ResolveDestination(secret) - if err != nil { - return nil, fmt.Errorf("secret %q: %w", name, err) - } - secret.Destination = resolvedDest - - vault.Secrets[name] = secret - } - - return &vault, nil -} - -func ResolveDestination(secret VaultSecret) (string, error) { - if secret.Type == TypeEnv { - return secret.Destination, nil - } - - config, ok := SecretTypeConfigs[secret.Type] - if !ok { - return "", fmt.Errorf("unknown type %q", secret.Type) - } - - if filepath.IsAbs(secret.Destination) || strings.HasPrefix(secret.Destination, "~") { - resolved, err := path.Expand(secret.Destination) - if err != nil { - return "", fmt.Errorf("failed to expand path: %w", err) - } - return resolved, nil - } - - if config.DefaultDirectory == "" { - return "", fmt.Errorf("type %q requires an absolute path", secret.Type) - } - - fullPath := filepath.Join(config.DefaultDirectory, secret.Destination) - resolved, err := path.Expand(fullPath) - if err != nil { - return "", fmt.Errorf("failed to expand path: %w", err) - } - return resolved, nil -} - -const DefaultVaultPath = "~/.ws/vault/secrets.yaml" - -func ResolveVaultPath(inputFlag string) (string, error) { - if inputFlag != "" { - return inputFlag, nil - } - - expanded, err := path.Expand(DefaultVaultPath) - if err != nil { - return "", fmt.Errorf("failed to resolve default vault path: %w", err) - } - - if _, err := os.Stat(expanded); err == nil { - return expanded, nil - } - - return "", fmt.Errorf("vault file not specified (use --input or place at %s)", DefaultVaultPath) -} - -func ValidateSecret(name string, secret VaultSecret) error { - if secret.Encrypted == "" { - return fmt.Errorf("secret %q: encrypted value is required", name) - } - - if secret.Destination == "" { - return fmt.Errorf("secret %q: destination is required", name) - } - - validTypes := []string{ - TypeGeneric, - TypeSSH, - TypeEnv, - TypeKubeconfig, - TypeDockerConfigJSON, - } - - if !slices.Contains(validTypes, secret.Type) { - return fmt.Errorf("secret %q: invalid type %q", name, secret.Type) - } - - if secret.Type == TypeEnv { - if !env.IsValidName(secret.Destination) { - return fmt.Errorf("secret %q: invalid environment variable name %q (must start with letter/underscore and contain only alphanumerics and underscores)", name, secret.Destination) - } - } else if !filepath.IsAbs(secret.Destination) { - return fmt.Errorf("secret %q: invalid destination path", name) - } - - return nil -} - -func GetSecretKeys(vault *Vault, requestedKeys []string) []string { - if len(requestedKeys) > 0 { - return requestedKeys - } - - keys := make([]string, 0, len(vault.Secrets)) - for key := range vault.Secrets { - keys = append(keys, key) - } - slices.Sort(keys) - - return keys -} - -type ProcessOptions struct { - MasterKey []byte - Keys []string - Stdout bool - Raw bool - Force bool - ModeOverride string -} - -func ProcessVault(vault *Vault, opts ProcessOptions) (map[string]string, error) { - results := make(map[string]string) - keys := GetSecretKeys(vault, opts.Keys) - - for _, key := range keys { - secret, exists := vault.Secrets[key] - if !exists { - return nil, fmt.Errorf("secret %q not found in vault", key) - } - - if err := ValidateSecret(key, secret); err != nil { - return nil, err - } - - effectiveForce := opts.Force || secret.Force - - encryptedValue, err := ResolveEncryptedValue(secret.Encrypted) - if err != nil { - return nil, fmt.Errorf("failed to resolve encrypted value for %q: %w", key, err) - } - - decrypted, err := Decrypt(encryptedValue, opts.MasterKey) - if err != nil { - return nil, fmt.Errorf("failed to decrypt secret %q: %w", key, err) - } - - if opts.Stdout { - results[key] = string(decrypted) - continue - } - - mode := secret.Mode - if opts.ModeOverride != "" { - mode = opts.ModeOverride - } - - if secret.Type == TypeEnv { - if err := ProcessEnvSecret(secret.Destination, decrypted, effectiveForce); err != nil { - return nil, fmt.Errorf("failed to process env secret %q: %w", key, err) - } - results[key] = fmt.Sprintf("env:%s", secret.Destination) - } else { - if err := internalIO.WriteSecureFile(secret.Destination, decrypted, mode, effectiveForce); err != nil { - return nil, fmt.Errorf("failed to write secret %q: %w", key, err) - } - results[key] = secret.Destination - } - } - - return results, nil -} - -func findAndReplaceEnvVar(lines []string, envVarName, value string, force bool) ([]string, error) { - exportLine := fmt.Sprintf("export %s=%q", envVarName, value) - found := false - - for i, line := range lines { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "export "+envVarName+"=") || - strings.HasPrefix(trimmed, envVarName+"=") { - if !force { - return nil, fmt.Errorf("environment variable %q already exists, use --force to overwrite", envVarName) - } - lines[i] = exportLine - found = true - break - } - } - - if !found { - lines = append(lines, exportLine) - } - - return lines, nil -} - -func ProcessEnvSecret(envVarName string, value []byte, force bool) error { - envFilePath, err := path.Expand(config.DefaultEnvFilePath) - if err != nil { - return err - } - - var existingContent []byte - if internalIO.FileExists(envFilePath) { - data, err := os.ReadFile(envFilePath) - if err != nil { - return fmt.Errorf("failed to read env file: %w", err) - } - existingContent = data - } - - lines := strings.Split(string(existingContent), "\n") - - lines, err = findAndReplaceEnvVar(lines, envVarName, string(value), force) - if err != nil { - return fmt.Errorf("%w in %s", err, envFilePath) - } - - content := strings.Join(lines, "\n") - if !strings.HasSuffix(content, "\n") { - content += "\n" - } - - if err := os.WriteFile(envFilePath, []byte(content), 0o644); err != nil { - return fmt.Errorf("failed to write env file: %w", err) - } - - return nil -} - -func FormatSecretForStdout(key string, value string, raw bool) string { - if raw { - return value - } - - return fmt.Sprintf("[%s]\n%s\n", key, value) -} - -type VaultEntry struct { - Name string - Type string - Destination string -} - -func ListVault(vault *Vault) []VaultEntry { - entries := make([]VaultEntry, 0, len(vault.Secrets)) - for name, secret := range vault.Secrets { - entries = append(entries, VaultEntry{ - Name: name, - Type: secret.Type, - Destination: secret.Destination, - }) - } - slices.SortFunc(entries, func(a, b VaultEntry) int { - return strings.Compare(a.Name, b.Name) - }) - return entries -} - -func LoadRawVault(path string) (*Vault, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read vault file %q: %w", path, err) - } - - var vault Vault - if err := yaml.Unmarshal(data, &vault); err != nil { - return nil, fmt.Errorf("failed to unmarshal vault yaml: %w", err) - } - - if vault.Secrets == nil { - vault.Secrets = make(map[string]VaultSecret) - } - - return &vault, nil -} - -func RotateVault(vault *Vault, oldKey, newKey []byte) ([]string, error) { - var fileRefs []string - - for name, secret := range vault.Secrets { - encryptedValue := secret.Encrypted - if strings.HasPrefix(encryptedValue, "file:") { - fileRefs = append(fileRefs, name) - } - - resolved, err := ResolveEncryptedValue(encryptedValue) - if err != nil { - return nil, fmt.Errorf("secret %q: %w", name, err) - } - - decrypted, err := Decrypt(resolved, oldKey) - if err != nil { - return nil, fmt.Errorf("secret %q: failed to decrypt: %w", name, err) - } - - reEncrypted, err := Encrypt(decrypted, newKey) - if err != nil { - return nil, fmt.Errorf("secret %q: failed to re-encrypt: %w", name, err) - } - - secret.Encrypted = reEncrypted - vault.Secrets[name] = secret - } - - slices.Sort(fileRefs) - return fileRefs, nil -} - -func SaveVault(path string, vault *Vault) error { - data, err := yaml.Marshal(vault) - if err != nil { - return fmt.Errorf("failed to marshal vault yaml: %w", err) - } - - if err := os.WriteFile(path, data, 0o600); err != nil { - return fmt.Errorf("failed to write vault file %q: %w", path, err) - } - - return nil -} diff --git a/internals/secrets/vault_test.go b/internals/secrets/vault_test.go deleted file mode 100644 index 4d38e42..0000000 --- a/internals/secrets/vault_test.go +++ /dev/null @@ -1,884 +0,0 @@ -package secrets - -import ( - "os" - "path/filepath" - "strings" - "testing" - - "gotest.tools/v3/assert" -) - -func TestLoadVault(t *testing.T) { - t.Run("ValidVault", func(t *testing.T) { - vaultContent := ` -secrets: - db_password: - encrypted: "test$encrypted" - destination: "/etc/db/password" - ssh_key: - type: "ssh" - encrypted: "test$encrypted" - destination: "/home/user/.ssh/id_rsa" -` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 2, len(vault.Secrets)) - assert.Equal(t, TypeGeneric, vault.Secrets["db_password"].Type) - assert.Equal(t, TypeSSH, vault.Secrets["ssh_key"].Type) - assert.Equal(t, "0o600", vault.Secrets["db_password"].Mode) - }) - - t.Run("EmptyVault", func(t *testing.T) { - vaultContent := `secrets: {}` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 0, len(vault.Secrets)) - }) - - t.Run("InvalidYAML", func(t *testing.T) { - vaultContent := `invalid: yaml: content:` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - _, err = LoadVault(vaultFile) - assert.ErrorContains(t, err, "failed to unmarshal") - }) - - t.Run("FileNotFound", func(t *testing.T) { - _, err := LoadVault("/nonexistent/vault.yaml") - assert.ErrorContains(t, err, "failed to read vault file") - }) - - t.Run("RelativePathResolution", func(t *testing.T) { - homeDir, err := os.UserHomeDir() - assert.NilError(t, err) - - vaultContent := ` -secrets: - ssh_key: - type: "ssh" - encrypted: "test$encrypted" - destination: "github.com/id_ed25519" - kubeconfig: - type: "kubeconfig" - encrypted: "test$encrypted" - destination: "config" -` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err = os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 2, len(vault.Secrets)) - assert.Equal(t, filepath.Join(homeDir, ".ssh/github.com/id_ed25519"), vault.Secrets["ssh_key"].Destination) - assert.Equal(t, filepath.Join(homeDir, ".kube/config"), vault.Secrets["kubeconfig"].Destination) - }) - - t.Run("GenericTypeRequiresAbsolute", func(t *testing.T) { - vaultContent := ` -secrets: - generic_secret: - type: "generic" - encrypted: "test$encrypted" - destination: "relative/path" -` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - _, err = LoadVault(vaultFile) - assert.ErrorContains(t, err, "requires an absolute path") - }) -} - -func TestValidateSecret(t *testing.T) { - tests := []struct { - name string - secretName string - secret VaultSecret - errorContains string - }{ - { - name: "Valid", - secretName: "test", - secret: VaultSecret{ - Type: TypeGeneric, - Encrypted: "encrypted$value", - Destination: "/etc/test", - }, - errorContains: "", - }, - { - name: "MissingEncrypted", - secretName: "test", - secret: VaultSecret{ - Type: TypeGeneric, - Destination: "/etc/test", - }, - errorContains: "encrypted value is required", - }, - { - name: "MissingDestination", - secretName: "test", - secret: VaultSecret{ - Type: TypeGeneric, - Encrypted: "encrypted$value", - }, - errorContains: "destination is required", - }, - { - name: "InvalidType", - secretName: "test", - secret: VaultSecret{ - Type: "invalid", - Encrypted: "encrypted$value", - Destination: "/etc/test", - }, - errorContains: "invalid type", - }, - { - name: "RelativePathGenericType", - secretName: "test", - secret: VaultSecret{ - Type: TypeGeneric, - Encrypted: "encrypted$value", - Destination: "relative/path", - }, - errorContains: "invalid destination path", - }, - { - name: "EnvTypeValid", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "MY_VAR", - }, - errorContains: "", - }, - { - name: "EnvTypeValidUnderscore", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "_MY_VAR", - }, - errorContains: "", - }, - { - name: "EnvTypeValidWithNumbers", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "MY_VAR_123", - }, - errorContains: "", - }, - { - name: "EnvTypeInvalidStartsWithNumber", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "123_VAR", - }, - errorContains: "invalid environment variable name", - }, - { - name: "EnvTypeInvalidHyphen", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "MY-VAR", - }, - errorContains: "invalid environment variable name", - }, - { - name: "EnvTypeInvalidDot", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "MY.VAR", - }, - errorContains: "invalid environment variable name", - }, - { - name: "EnvTypeInvalidSpace", - secretName: "test", - secret: VaultSecret{ - Type: TypeEnv, - Encrypted: "encrypted$value", - Destination: "MY VAR", - }, - errorContains: "invalid environment variable name", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidateSecret(tt.secretName, tt.secret) - if tt.errorContains != "" { - assert.ErrorContains(t, err, tt.errorContains) - } else { - assert.NilError(t, err) - } - }) - } -} - -func TestGetSecretKeys(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "key1": {}, - "key2": {}, - "key3": {}, - }, - } - - t.Run("AllKeys", func(t *testing.T) { - keys := GetSecretKeys(vault, []string{}) - assert.Equal(t, 3, len(keys)) - for i := 1; i < len(keys); i++ { - assert.Assert(t, keys[i-1] < keys[i], "keys should be sorted alphabetically") - } - }) - - t.Run("SpecificKeys", func(t *testing.T) { - keys := GetSecretKeys(vault, []string{"key1", "key3"}) - assert.Equal(t, 2, len(keys)) - assert.Equal(t, "key1", keys[0]) - assert.Equal(t, "key3", keys[1]) - }) -} - -func TestResolveVaultPath(t *testing.T) { - t.Run("FromFlag", func(t *testing.T) { - path, err := ResolveVaultPath("/path/to/vault.yaml") - assert.NilError(t, err) - assert.Equal(t, "/path/to/vault.yaml", path) - }) - - t.Run("FromConventionDefault", func(t *testing.T) { - home := t.TempDir() - t.Setenv("HOME", home) - - defaultVault := filepath.Join(home, ".ws", "vault", "secrets.yaml") - assert.NilError(t, os.MkdirAll(filepath.Dir(defaultVault), 0o755)) - assert.NilError(t, os.WriteFile(defaultVault, []byte("secrets:\n"), 0o600)) - - path, err := ResolveVaultPath("") - assert.NilError(t, err) - assert.Equal(t, defaultVault, path) - }) - - t.Run("ConventionDefaultAbsent", func(t *testing.T) { - t.Setenv("HOME", t.TempDir()) - _, err := ResolveVaultPath("") - assert.ErrorContains(t, err, "vault file not specified") - assert.ErrorContains(t, err, "~/.ws/vault/secrets.yaml") - }) - - t.Run("EnvIsIgnored", func(t *testing.T) { - t.Setenv("HOME", t.TempDir()) - t.Setenv("WS_SECRETS_VAULT", "/env/ignored/vault.yaml") - _, err := ResolveVaultPath("") - assert.ErrorContains(t, err, "vault file not specified") - }) -} - -func TestFormatSecretForStdout(t *testing.T) { - t.Run("Raw", func(t *testing.T) { - output := FormatSecretForStdout("key", "value", true) - assert.Equal(t, "value", output) - }) - - t.Run("Formatted", func(t *testing.T) { - output := FormatSecretForStdout("key", "value", false) - assert.Equal(t, "[key]\nvalue\n", output) - }) -} - -func TestResolveDestination(t *testing.T) { - homeDir, err := os.UserHomeDir() - assert.NilError(t, err) - - tests := []struct { - name string - secret VaultSecret - expected string - errorContains string - }{ - { - name: "EnvType", - secret: VaultSecret{ - Type: TypeEnv, - Destination: "MY_VAR", - }, - expected: "MY_VAR", - }, - { - name: "SSHRelativePath", - secret: VaultSecret{ - Type: TypeSSH, - Destination: "github.com/id_ed25519", - }, - expected: filepath.Join(homeDir, ".ssh", "github.com/id_ed25519"), - }, - { - name: "SSHAbsolutePath", - secret: VaultSecret{ - Type: TypeSSH, - Destination: "/custom/path/key", - }, - expected: "/custom/path/key", - }, - { - name: "SSHTildePath", - secret: VaultSecret{ - Type: TypeSSH, - Destination: "~/.ssh/custom/key", - }, - expected: filepath.Join(homeDir, ".ssh/custom/key"), - }, - { - name: "KubeconfigRelativePath", - secret: VaultSecret{ - Type: TypeKubeconfig, - Destination: "config", - }, - expected: filepath.Join(homeDir, ".kube/config"), - }, - { - name: "DockerConfigJSONRelativePath", - secret: VaultSecret{ - Type: TypeDockerConfigJSON, - Destination: "config.json", - }, - expected: filepath.Join(homeDir, ".docker/config.json"), - }, - { - name: "GenericAbsolutePath", - secret: VaultSecret{ - Type: TypeGeneric, - Destination: "/etc/secret", - }, - expected: "/etc/secret", - }, - { - name: "GenericRelativePath", - secret: VaultSecret{ - Type: TypeGeneric, - Destination: "relative", - }, - errorContains: "requires an absolute path", - }, - { - name: "UnknownType", - secret: VaultSecret{ - Type: "unknown", - Destination: "/etc/secret", - }, - errorContains: "unknown type", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ResolveDestination(tt.secret) - if tt.errorContains != "" { - assert.ErrorContains(t, err, tt.errorContains) - } else { - assert.NilError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -func TestProcessEnvSecret(t *testing.T) { - t.Run("NewVariable", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - err := ProcessEnvSecret("NEW_VAR", []byte("secret_value"), false) - assert.NilError(t, err) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - assert.Assert(t, len(content) > 0) - }) - - t.Run("ExistingWithoutForce", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - initialContent := `export EXISTING_VAR="old_value" -` - err := os.WriteFile(envFile, []byte(initialContent), 0644) - assert.NilError(t, err) - - err = ProcessEnvSecret("EXISTING_VAR", []byte("new_value"), false) - assert.ErrorContains(t, err, "already exists") - }) - - t.Run("ExistingWithForce", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - initialContent := `export EXISTING_VAR="old_value" -` - err := os.WriteFile(envFile, []byte(initialContent), 0644) - assert.NilError(t, err) - - err = ProcessEnvSecret("EXISTING_VAR", []byte("new_value"), true) - assert.NilError(t, err) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - assert.Assert(t, len(content) > 0) - }) - - t.Run("MultipleCalls", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - err := ProcessEnvSecret("VAR1", []byte("value1"), false) - assert.NilError(t, err) - - err = ProcessEnvSecret("VAR2", []byte("value2"), false) - assert.NilError(t, err) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - - contentStr := string(content) - assert.Assert(t, strings.Contains(contentStr, `export VAR1="value1"`)) - assert.Assert(t, strings.Contains(contentStr, `export VAR2="value2"`)) - }) - - t.Run("DuplicateCallWithoutForce", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - err := ProcessEnvSecret("DUPLICATE_VAR", []byte("value1"), false) - assert.NilError(t, err) - - err = ProcessEnvSecret("DUPLICATE_VAR", []byte("value2"), false) - assert.ErrorContains(t, err, "already exists") - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - assert.Assert(t, strings.Contains(string(content), `export DUPLICATE_VAR="value1"`)) - assert.Assert(t, !strings.Contains(string(content), "value2")) - }) - - t.Run("DuplicateCallWithForce", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - - t.Setenv("HOME", tmpDir) - - err := ProcessEnvSecret("DUPLICATE_VAR", []byte("value1"), false) - assert.NilError(t, err) - - err = ProcessEnvSecret("DUPLICATE_VAR", []byte("value2"), true) - assert.NilError(t, err) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - - contentStr := string(content) - assert.Assert(t, strings.Contains(contentStr, `export DUPLICATE_VAR="value2"`)) - assert.Assert(t, !strings.Contains(contentStr, "value1")) - - lines := strings.Split(strings.TrimSpace(contentStr), "\n") - assert.Equal(t, 1, len(lines)) - }) -} - -func TestDeterministicOrdering(t *testing.T) { - t.Run("GetSecretKeysReturnsSorted", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "zebra": {}, - "alpha": {}, - "charlie": {}, - "bravo": {}, - }, - } - - keys := GetSecretKeys(vault, []string{}) - assert.Equal(t, 4, len(keys)) - assert.Equal(t, "alpha", keys[0]) - assert.Equal(t, "bravo", keys[1]) - assert.Equal(t, "charlie", keys[2]) - assert.Equal(t, "zebra", keys[3]) - }) - - t.Run("MultipleRunsProduceSameOrder", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "secret3": {}, - "secret1": {}, - "secret2": {}, - }, - } - - firstRun := GetSecretKeys(vault, []string{}) - secondRun := GetSecretKeys(vault, []string{}) - thirdRun := GetSecretKeys(vault, []string{}) - - assert.DeepEqual(t, firstRun, secondRun) - assert.DeepEqual(t, secondRun, thirdRun) - }) - - t.Run("RequestedKeysPreserveOrder", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "zebra": {}, - "alpha": {}, - "charlie": {}, - }, - } - - requested := []string{"zebra", "alpha"} - keys := GetSecretKeys(vault, requested) - assert.DeepEqual(t, requested, keys) - }) -} - -func TestPerSecretForce(t *testing.T) { - t.Run("SecretForceOverwritesEnv", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - t.Setenv("HOME", tmpDir) - - existingContent := `export TEST_VAR="old_value" -` - err := os.WriteFile(envFile, []byte(existingContent), 0644) - assert.NilError(t, err) - - err = ProcessEnvSecret("TEST_VAR", []byte("new_value"), true) - assert.NilError(t, err) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - - contentStr := string(content) - assert.Assert(t, strings.Contains(contentStr, `export TEST_VAR="new_value"`)) - assert.Assert(t, !strings.Contains(contentStr, "old_value")) - }) - - t.Run("SecretWithoutForceFailsOnExisting", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - t.Setenv("HOME", tmpDir) - - existingContent := `export TEST_VAR="old_value" -` - err := os.WriteFile(envFile, []byte(existingContent), 0644) - assert.NilError(t, err) - - err = ProcessEnvSecret("TEST_VAR", []byte("new_value"), false) - assert.ErrorContains(t, err, `environment variable "TEST_VAR" already exists, use --force to overwrite`) - }) - - t.Run("MixedForceInVault", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".zshenv") - t.Setenv("HOME", tmpDir) - - existingContent := `export VAR1="old1" -export VAR2="old2" -` - err := os.WriteFile(envFile, []byte(existingContent), 0644) - assert.NilError(t, err) - - err = ProcessEnvSecret("VAR1", []byte("new1"), true) - assert.NilError(t, err) - - err = ProcessEnvSecret("VAR2", []byte("new2"), false) - assert.ErrorContains(t, err, `environment variable "VAR2" already exists, use --force to overwrite`) - - content, err := os.ReadFile(envFile) - assert.NilError(t, err) - - contentStr := string(content) - assert.Assert(t, strings.Contains(contentStr, `export VAR1="new1"`)) - assert.Assert(t, strings.Contains(contentStr, `export VAR2="old2"`)) - }) -} - -func TestListVault(t *testing.T) { - t.Run("SortedOutput", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "zebra": {Type: TypeSSH, Destination: "/z", Encrypted: "enc"}, - "alpha": {Type: TypeGeneric, Destination: "/a", Encrypted: "enc"}, - "mike": {Type: TypeEnv, Destination: "MY_VAR", Encrypted: "enc"}, - }, - } - - entries := ListVault(vault) - assert.Equal(t, 3, len(entries)) - assert.Equal(t, "alpha", entries[0].Name) - assert.Equal(t, "mike", entries[1].Name) - assert.Equal(t, "zebra", entries[2].Name) - }) - - t.Run("EmptyVault", func(t *testing.T) { - vault := &Vault{Secrets: map[string]VaultSecret{}} - entries := ListVault(vault) - assert.Equal(t, 0, len(entries)) - }) - - t.Run("PreservesTypeAndDestination", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "ssh_key": {Type: TypeSSH, Destination: "/home/.ssh/id_rsa", Encrypted: "enc"}, - }, - } - - entries := ListVault(vault) - assert.Equal(t, 1, len(entries)) - assert.Equal(t, TypeSSH, entries[0].Type) - assert.Equal(t, "/home/.ssh/id_rsa", entries[0].Destination) - }) -} - -func TestLoadRawVault(t *testing.T) { - t.Run("DoesNotResolveDestinations", func(t *testing.T) { - vaultContent := ` -secrets: - ssh_key: - type: "ssh" - encrypted: "test$encrypted" - destination: "github.com/id_ed25519" -` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadRawVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, "github.com/id_ed25519", vault.Secrets["ssh_key"].Destination) - }) - - t.Run("DoesNotFillDefaults", func(t *testing.T) { - vaultContent := ` -secrets: - db_password: - encrypted: "test$encrypted" - destination: "/etc/db/password" -` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadRawVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, "", vault.Secrets["db_password"].Type) - assert.Equal(t, "", vault.Secrets["db_password"].Mode) - }) - - t.Run("EmptyVault", func(t *testing.T) { - vaultContent := `secrets: {}` - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := os.WriteFile(vaultFile, []byte(vaultContent), 0600) - assert.NilError(t, err) - - vault, err := LoadRawVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 0, len(vault.Secrets)) - }) - - t.Run("FileNotFound", func(t *testing.T) { - _, err := LoadRawVault("/nonexistent/vault.yaml") - assert.ErrorContains(t, err, "failed to read vault file") - }) -} - -func TestRotateVault(t *testing.T) { - t.Run("RotateDecryptsWithNewKey", func(t *testing.T) { - oldKey := []byte("old-master-key-for-testing") - newKey := []byte("new-master-key-for-testing") - - encrypted, err := Encrypt([]byte("my-secret-value"), oldKey) - assert.NilError(t, err) - - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "test_secret": { - Type: TypeGeneric, - Encrypted: encrypted, - Destination: "/etc/test", - }, - }, - } - - fileRefs, err := RotateVault(vault, oldKey, newKey) - assert.NilError(t, err) - assert.Equal(t, 0, len(fileRefs)) - - decrypted, err := Decrypt(vault.Secrets["test_secret"].Encrypted, newKey) - assert.NilError(t, err) - assert.Equal(t, "my-secret-value", string(decrypted)) - }) - - t.Run("OldKeyNoLongerWorks", func(t *testing.T) { - oldKey := []byte("old-master-key-for-testing") - newKey := []byte("new-master-key-for-testing") - - encrypted, err := Encrypt([]byte("my-secret-value"), oldKey) - assert.NilError(t, err) - - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "test_secret": { - Type: TypeGeneric, - Encrypted: encrypted, - Destination: "/etc/test", - }, - }, - } - - _, err = RotateVault(vault, oldKey, newKey) - assert.NilError(t, err) - - _, err = Decrypt(vault.Secrets["test_secret"].Encrypted, oldKey) - assert.ErrorContains(t, err, "cipher: message authentication failed") - }) - - t.Run("FileRefsReported", func(t *testing.T) { - oldKey := []byte("old-master-key-for-testing") - newKey := []byte("new-master-key-for-testing") - - encrypted, err := Encrypt([]byte("file-secret"), oldKey) - assert.NilError(t, err) - - encFile := filepath.Join(t.TempDir(), "secret.enc") - err = os.WriteFile(encFile, []byte(encrypted), 0600) - assert.NilError(t, err) - - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "file_secret": { - Type: TypeGeneric, - Encrypted: "file:" + encFile, - Destination: "/etc/test", - }, - }, - } - - fileRefs, err := RotateVault(vault, oldKey, newKey) - assert.NilError(t, err) - assert.Equal(t, 1, len(fileRefs)) - assert.Equal(t, "file_secret", fileRefs[0]) - - assert.Assert(t, !strings.HasPrefix(vault.Secrets["file_secret"].Encrypted, "file:")) - }) - - t.Run("MultipleSecrets", func(t *testing.T) { - oldKey := []byte("old-master-key-for-testing") - newKey := []byte("new-master-key-for-testing") - - enc1, err := Encrypt([]byte("secret-1"), oldKey) - assert.NilError(t, err) - enc2, err := Encrypt([]byte("secret-2"), oldKey) - assert.NilError(t, err) - - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "first": {Encrypted: enc1, Destination: "/etc/first"}, - "second": {Encrypted: enc2, Destination: "/etc/second"}, - }, - } - - _, err = RotateVault(vault, oldKey, newKey) - assert.NilError(t, err) - - dec1, err := Decrypt(vault.Secrets["first"].Encrypted, newKey) - assert.NilError(t, err) - assert.Equal(t, "secret-1", string(dec1)) - - dec2, err := Decrypt(vault.Secrets["second"].Encrypted, newKey) - assert.NilError(t, err) - assert.Equal(t, "secret-2", string(dec2)) - }) -} - -func TestSaveVault(t *testing.T) { - t.Run("RoundTrip", func(t *testing.T) { - vault := &Vault{ - Secrets: map[string]VaultSecret{ - "db_password": { - Type: TypeGeneric, - Encrypted: "salt$cipher", - Destination: "/etc/db/password", - Mode: "0o600", - }, - }, - } - - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - err := SaveVault(vaultFile, vault) - assert.NilError(t, err) - - loaded, err := LoadRawVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 1, len(loaded.Secrets)) - assert.Equal(t, "salt$cipher", loaded.Secrets["db_password"].Encrypted) - assert.Equal(t, "/etc/db/password", loaded.Secrets["db_password"].Destination) - assert.Equal(t, TypeGeneric, loaded.Secrets["db_password"].Type) - assert.Equal(t, "0o600", loaded.Secrets["db_password"].Mode) - }) - - t.Run("FilePermissions", func(t *testing.T) { - vault := &Vault{Secrets: map[string]VaultSecret{}} - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - - err := SaveVault(vaultFile, vault) - assert.NilError(t, err) - - info, err := os.Stat(vaultFile) - assert.NilError(t, err) - assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) - }) - - t.Run("EmptyVault", func(t *testing.T) { - vault := &Vault{Secrets: map[string]VaultSecret{}} - vaultFile := filepath.Join(t.TempDir(), "vault.yaml") - - err := SaveVault(vaultFile, vault) - assert.NilError(t, err) - - loaded, err := LoadRawVault(vaultFile) - assert.NilError(t, err) - assert.Equal(t, 0, len(loaded.Secrets)) - }) -} diff --git a/internals/seed/apply.go b/internals/seed/apply.go new file mode 100644 index 0000000..a14770f --- /dev/null +++ b/internals/seed/apply.go @@ -0,0 +1,237 @@ +package seed + +import ( + "fmt" + "io" + "io/fs" + "os" + "slices" + + internalIO "github.com/kloudkit/ws-cli/internals/io" + "github.com/kloudkit/ws-cli/internals/secrets" + "github.com/kloudkit/ws-cli/internals/styles" +) + +type Options struct { + Source string + Force bool + Dests []string + MasterKey string + Out io.Writer + Styled bool +} + +type reporter struct { + out io.Writer + styled bool +} + +func (r reporter) seeded(dest string) { + if r.styled { + styles.PrintSuccess(r.out, fmt.Sprintf("Seeded [%s]", dest)) + return + } + + fmt.Fprintf(r.out, "Seeded [%s]\n", dest) +} + +func (r reporter) skip(dest, reason string) { + if r.styled { + styles.PrintWarning(r.out, fmt.Sprintf("Skipping [%s] (%s)", dest, reason)) + return + } + + fmt.Fprintf(r.out, "Skipping [%s] (%s)\n", dest, reason) +} + +func (r reporter) notice(dest string) { + message := fmt.Sprintf("[%s] runs next boot; ensure +x if executable", dest) + + if r.styled { + styles.PrintKeyValue(r.out, "Notice", message) + return + } + + fmt.Fprintf(r.out, "Notice %s\n", message) +} + +type keyResolver struct { + flag string + secrets map[string]string + key []byte + loaded bool + err error +} + +func (k *keyResolver) master() ([]byte, error) { + if !k.loaded { + k.key, k.err = secrets.ResolveMasterKey(k.flag) + k.loaded = true + } + + return k.key, k.err +} + +func (k *keyResolver) resolveNamed(name string) ([]byte, error) { + value, ok := k.secrets[name] + if !ok { + return nil, fmt.Errorf("secret %q not declared", name) + } + + master, err := k.master() + if err != nil { + return nil, err + } + + resolved, err := secrets.ResolveEncryptedValue(value) + if err != nil { + return nil, err + } + + return secrets.Decrypt(secrets.NormalizeEncrypted(resolved), master) +} + +func Apply(opts Options) error { + plan, err := BuildPlan(opts.Source, opts.Force) + if err != nil { + return err + } + + ops := plan.Ops + if len(opts.Dests) > 0 { + ops, err = plan.filterDests(opts.Dests) + if err != nil { + return err + } + } + + declared := map[string]string{} + if plan.Manifest != nil { + declared = plan.Manifest.Secrets + } + + keys := &keyResolver{flag: opts.MasterKey, secrets: declared} + rep := reporter{out: opts.Out, styled: opts.Styled} + + for _, op := range ops { + plan.applyOne(op, keys, rep) + } + + return nil +} + +func (p *Plan) applyOne(op ResolvedOp, keys *keyResolver, rep reporter) { + ancestor := nearestExistingAncestor(op.Dest) + if !ownsPath(ancestor) { + rep.skip(op.Dest, "destination not owned") + return + } + + if consumedNotice(op.Dest, p.Vars.Home) { + rep.notice(op.Dest) + } + + if !internalIO.CanOverride(op.Dest, op.Force) { + return + } + + content, mode, err := p.materialize(op, keys) + if err != nil { + rep.skip(op.Dest, err.Error()) + return + } + + anchor := chooseAnchor(op.Dest, p.Vars, ancestor) + if err := writeAtomic(anchor, op.Dest, content, mode); err != nil { + rep.skip(op.Dest, err.Error()) + return + } + + rep.seeded(op.Dest) +} + +func (p *Plan) materialize(op ResolvedOp, keys *keyResolver) ([]byte, fs.FileMode, error) { + raw, err := p.sourceBytes(op) + if err != nil { + return nil, 0, fmt.Errorf("no source available") + } + + mode, err := resolveMode(op, op.Secret || (op.Template && referencesSecrets(raw))) + if err != nil { + return nil, 0, err + } + + content, err := p.transform(op, raw, keys) + if err != nil { + return nil, 0, err + } + + switch op.Op { + case OpMerge: + if content, err = mergeContent(readExisting(op.Dest), content, op.Dest); err != nil { + return nil, 0, err + } + case OpAppend: + content = slices.Concat(readExisting(op.Dest), content) + case OpPrepend: + content = slices.Concat(content, readExisting(op.Dest)) + } + + return content, mode, nil +} + +func (p *Plan) transform(op ResolvedOp, raw []byte, keys *keyResolver) ([]byte, error) { + if op.Secret { + resolved, err := secrets.ResolveEncryptedValue(string(raw)) + if err != nil { + return nil, fmt.Errorf("secret source unresolved") + } + + master, err := keys.master() + if err != nil { + return nil, fmt.Errorf("master key unavailable") + } + + plain, err := secrets.Decrypt(secrets.NormalizeEncrypted(resolved), master) + if err != nil { + return nil, fmt.Errorf("decrypt failed") + } + + return plain, nil + } + + if op.Template { + return renderTemplate(raw, p.Vars, keys.resolveNamed) + } + + return raw, nil +} + +func (p *Plan) sourceBytes(op ResolvedOp) ([]byte, error) { + if op.Content != nil { + return []byte(*op.Content), nil + } + + return os.ReadFile(op.Source) +} + +func resolveMode(op ResolvedOp, secretBearing bool) (fs.FileMode, error) { + if secretBearing { + return 0o600, nil + } + + if op.Mode == "" { + return 0o644, nil + } + + return internalIO.ParseFileMode(op.Mode) +} + +func readExisting(dest string) []byte { + data, err := os.ReadFile(dest) + if err != nil { + return nil + } + + return data +} diff --git a/internals/seed/manifest.go b/internals/seed/manifest.go new file mode 100644 index 0000000..e39b2bd --- /dev/null +++ b/internals/seed/manifest.go @@ -0,0 +1,82 @@ +package seed + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +const ManifestName = ".seed.yaml" + +type Manifest struct { + Version string `yaml:"version"` + Secrets map[string]string `yaml:"secrets"` + Seeds map[string]SeedOp `yaml:"seeds"` +} + +func ManifestPath(source string) string { + return filepath.Join(source, ManifestName) +} + +func LoadManifest(path string) (*Manifest, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read manifest %q: %w", path, err) + } + + return ParseManifest(data) +} + +func ParseManifest(data []byte) (*Manifest, error) { + var manifest Manifest + if err := yaml.Unmarshal(data, &manifest); err != nil { + return nil, fmt.Errorf("failed to parse manifest: %w", err) + } + + if manifest.Version != "v1" { + return nil, fmt.Errorf("unsupported manifest version %q (expected \"v1\")", manifest.Version) + } + + for name, value := range manifest.Secrets { + if err := validateSecretValue(name, value); err != nil { + return nil, err + } + } + + for dest, op := range manifest.Seeds { + if !op.hasBehavior() { + return nil, fmt.Errorf("seed %q: a copy-only entry is not allowed (use the mirror tier)", dest) + } + + if op.Op == "" { + op.Op = OpCopy + manifest.Seeds[dest] = op + } + + if err := validateOp(dest, op); err != nil { + return nil, err + } + } + + return &manifest, nil +} + +func validateSecretValue(name, value string) error { + if value != "" && (strings.HasPrefix(value, "file:") || strings.Contains(value, "$")) { + return nil + } + + return fmt.Errorf("secret %q: expected ciphertext or file: ref", name) +} + +func validateOp(dest string, op SeedOp) error { + switch op.Op { + case OpCopy, OpMerge, OpAppend, OpPrepend: + return nil + } + + return fmt.Errorf("seed %q: unknown op %q", dest, op.Op) +} diff --git a/internals/seed/manifest_test.go b/internals/seed/manifest_test.go new file mode 100644 index 0000000..940ce46 --- /dev/null +++ b/internals/seed/manifest_test.go @@ -0,0 +1,51 @@ +package seed + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestParseManifest(t *testing.T) { + t.Run("UnknownVersionRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v2\n")) + assert.ErrorContains(t, err, "unsupported manifest version") + }) + + t.Run("MissingVersionRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("seeds: {}\n")) + assert.ErrorContains(t, err, "unsupported manifest version") + }) + + t.Run("CopyOnlyEntryRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n op: copy\n")) + assert.ErrorContains(t, err, "copy-only entry is not allowed") + }) + + t.Run("EmptyEntryRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x: {}\n")) + assert.ErrorContains(t, err, "copy-only entry is not allowed") + }) + + t.Run("SecretValueInvalidRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v1\nsecrets:\n TOKEN: plainnodollar\n")) + assert.ErrorContains(t, err, `secret "TOKEN": expected ciphertext or file: ref`) + }) + + t.Run("SecretValueFileRefAccepted", func(t *testing.T) { + manifest, err := ParseManifest([]byte("version: v1\nsecrets:\n TOKEN: file:/run/secrets/token\n")) + assert.NilError(t, err) + assert.Equal(t, manifest.Secrets["TOKEN"], "file:/run/secrets/token") + }) + + t.Run("BehaviorEntryAccepted", func(t *testing.T) { + manifest, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n secret: true\n")) + assert.NilError(t, err) + assert.Equal(t, manifest.Seeds["/tmp/x"].Op, OpCopy) + }) + + t.Run("UnknownOpRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n op: smash\n")) + assert.ErrorContains(t, err, `unknown op "smash"`) + }) +} diff --git a/internals/seed/merge.go b/internals/seed/merge.go new file mode 100644 index 0000000..df6b66d --- /dev/null +++ b/internals/seed/merge.go @@ -0,0 +1,97 @@ +package seed + +import ( + "bytes" + "encoding/json" + "fmt" + "path/filepath" + "strings" + + toml "github.com/pelletier/go-toml/v2" + "gopkg.in/yaml.v3" +) + +type codec struct { + unmarshal func([]byte, any) error + marshal func(any) ([]byte, error) +} + +func codecFor(dest string) (codec, error) { + switch strings.ToLower(filepath.Ext(dest)) { + case ".json": + return codec{json.Unmarshal, marshalJSON}, nil + case ".yaml", ".yml": + return codec{yaml.Unmarshal, yaml.Marshal}, nil + case ".toml": + return codec{toml.Unmarshal, toml.Marshal}, nil + } + + return codec{}, fmt.Errorf("cannot infer merge format from %q", dest) +} + +func marshalJSON(v any) ([]byte, error) { + var buffer bytes.Buffer + encoder := json.NewEncoder(&buffer) + encoder.SetIndent("", " ") + + if err := encoder.Encode(v); err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func mergeContent(existing, fragment []byte, dest string) ([]byte, error) { + c, err := codecFor(dest) + if err != nil { + return nil, err + } + + dst := map[string]any{} + if len(bytes.TrimSpace(existing)) > 0 { + if err := c.unmarshal(existing, &dst); err != nil { + return nil, fmt.Errorf("failed to decode existing %q", dest) + } + } + + src := map[string]any{} + if len(bytes.TrimSpace(fragment)) > 0 { + if err := c.unmarshal(fragment, &src); err != nil { + return nil, fmt.Errorf("failed to decode fragment for %q", dest) + } + } + + if err := deepMerge(dst, src); err != nil { + return nil, err + } + + return c.marshal(dst) +} + +func deepMerge(dst, src map[string]any) error { + for key, srcVal := range src { + dstVal, exists := dst[key] + if !exists { + dst[key] = srcVal + continue + } + + srcMap, srcIsMap := srcVal.(map[string]any) + dstMap, dstIsMap := dstVal.(map[string]any) + + if srcIsMap != dstIsMap { + return fmt.Errorf("merge conflict at key %q: type mismatch", key) + } + + if srcIsMap { + if err := deepMerge(dstMap, srcMap); err != nil { + return err + } + continue + } + + dst[key] = srcVal + } + + return nil +} diff --git a/internals/seed/merge_test.go b/internals/seed/merge_test.go new file mode 100644 index 0000000..15d94f1 --- /dev/null +++ b/internals/seed/merge_test.go @@ -0,0 +1,70 @@ +package seed + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func decodeBack(t *testing.T, data []byte, dest string) map[string]any { + t.Helper() + + c, err := codecFor(dest) + assert.NilError(t, err) + + out := map[string]any{} + assert.NilError(t, c.unmarshal(data, &out)) + + return out +} + +func TestMergeContent(t *testing.T) { + tests := []struct { + name string + dest string + existing string + fragment string + }{ + {"json", "config.json", `{"keep":1,"list":[1,2,3]}`, `{"list":[9],"add":true}`}, + {"yaml", "config.yaml", "keep: 1\nlist: [1, 2, 3]\n", "list: [9]\nadd: true\n"}, + {"toml", "config.toml", "keep = 1\nlist = [1, 2, 3]\n", "list = [9]\nadd = true\n"}, + } + + for _, tt := range tests { + t.Run("ListReplace/"+tt.name, func(t *testing.T) { + merged, err := mergeContent([]byte(tt.existing), []byte(tt.fragment), tt.dest) + assert.NilError(t, err) + + out := decodeBack(t, merged, tt.dest) + + list, ok := out["list"].([]any) + assert.Assert(t, ok) + assert.Equal(t, len(list), 1) + assert.Assert(t, out["keep"] != nil) + assert.Assert(t, out["add"] != nil) + }) + } + + t.Run("ScalarVsMapConflict", func(t *testing.T) { + _, err := mergeContent([]byte(`{"k":"scalar"}`), []byte(`{"k":{"nested":1}}`), "config.json") + assert.ErrorContains(t, err, "merge conflict at key") + }) + + t.Run("MapVsScalarConflict", func(t *testing.T) { + _, err := mergeContent([]byte(`{"k":{"nested":1}}`), []byte(`{"k":"scalar"}`), "config.json") + assert.ErrorContains(t, err, "merge conflict at key") + }) + + t.Run("JSONFloatNormalization", func(t *testing.T) { + merged, err := mergeContent([]byte(`{"n":1}`), []byte(`{"n":2}`), "config.json") + assert.NilError(t, err) + + out := decodeBack(t, merged, "config.json") + assert.Equal(t, out["n"], float64(2)) + }) + + t.Run("UnknownExtensionRejected", func(t *testing.T) { + _, err := mergeContent([]byte("a"), []byte("b"), "config.ini") + assert.ErrorContains(t, err, "cannot infer merge format") + }) +} diff --git a/internals/seed/mirror.go b/internals/seed/mirror.go new file mode 100644 index 0000000..25c74d1 --- /dev/null +++ b/internals/seed/mirror.go @@ -0,0 +1,45 @@ +package seed + +import ( + "io/fs" + "os" + "path/filepath" +) + +func walkMirror(source string) (map[string]string, error) { + mirror := map[string]string{} + + info, err := os.Stat(source) + if err != nil || !info.IsDir() { + return mirror, nil + } + + err = filepath.WalkDir(source, func(p string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + rel, err := filepath.Rel(source, p) + if err != nil { + return err + } + + if rel == ManifestName { + return nil + } + + mirror[string(os.PathSeparator)+rel] = p + + return nil + }) + + if err != nil { + return nil, err + } + + return mirror, nil +} diff --git a/internals/seed/op.go b/internals/seed/op.go new file mode 100644 index 0000000..46a9670 --- /dev/null +++ b/internals/seed/op.go @@ -0,0 +1,23 @@ +package seed + +type Op string + +const ( + OpCopy Op = "copy" + OpMerge Op = "merge" + OpAppend Op = "append" + OpPrepend Op = "prepend" +) + +type SeedOp struct { + Mode string `yaml:"mode"` + Content *string `yaml:"content"` + Secret bool `yaml:"secret"` + Op Op `yaml:"op"` + Template bool `yaml:"template"` + Force bool `yaml:"force"` +} + +func (o SeedOp) hasBehavior() bool { + return o.Secret || o.Mode != "" || (o.Op != "" && o.Op != OpCopy) || o.Template +} diff --git a/internals/seed/ownership.go b/internals/seed/ownership.go new file mode 100644 index 0000000..b97bc3d --- /dev/null +++ b/internals/seed/ownership.go @@ -0,0 +1,79 @@ +package seed + +import ( + "os" + "path/filepath" + "strings" + "syscall" +) + +var consumedDirs = []string{ + ".ws/startup.d", + ".ws/ca.d", + ".ws/session.d", + ".ws/features.d", +} + +func nearestExistingAncestor(dest string) string { + current := filepath.Dir(dest) + + for { + if _, err := os.Lstat(current); err == nil { + return current + } + + parent := filepath.Dir(current) + if parent == current { + return current + } + + current = parent + } +} + +func ownsPath(p string) bool { + info, err := os.Stat(p) + if err != nil { + return false + } + + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return false + } + + return int(stat.Uid) == os.Geteuid() +} + +func isUnder(p, root string) bool { + if root == "" { + return false + } + + root = filepath.Clean(root) + p = filepath.Clean(p) + + return p == root || strings.HasPrefix(p, root+string(os.PathSeparator)) +} + +func chooseAnchor(dest string, vars Vars, ancestor string) string { + if isUnder(dest, vars.Home) { + return vars.Home + } + + if isUnder(dest, vars.ServerRoot) { + return vars.ServerRoot + } + + return ancestor +} + +func consumedNotice(dest, home string) bool { + for _, dir := range consumedDirs { + if isUnder(dest, filepath.Join(home, dir)) { + return true + } + } + + return false +} diff --git a/internals/seed/resolve.go b/internals/seed/resolve.go new file mode 100644 index 0000000..1c6e703 --- /dev/null +++ b/internals/seed/resolve.go @@ -0,0 +1,185 @@ +package seed + +import ( + "fmt" + "os" + "os/user" + "path/filepath" + "sort" + "strings" + + "github.com/kloudkit/ws-cli/internals/config" + "github.com/kloudkit/ws-cli/internals/env" + internalIO "github.com/kloudkit/ws-cli/internals/io" + "github.com/kloudkit/ws-cli/internals/path" +) + +type Vars struct { + Home string + User string + ServerRoot string +} + +type ResolvedOp struct { + Dest string + Source string + Content *string + Mode string + Secret bool + Op Op + Template bool + Force bool +} + +type Plan struct { + Source string + Vars Vars + Manifest *Manifest + Ops []ResolvedOp +} + +func resolveVars() Vars { + username := "" + if u, err := user.Current(); err == nil { + username = u.Username + } + + serverRoot, _ := config.Resolve("server", "root") + + return Vars{ + Home: env.Home(), + User: username, + ServerRoot: serverRoot, + } +} + +func (v Vars) expand(value string) (string, error) { + replaced := strings.NewReplacer( + "${ws_home}", v.Home, + "${ws_user}", v.User, + "${ws_server_root}", v.ServerRoot, + ).Replace(value) + + return path.Expand(replaced) +} + +func ResolveSource(flag string) (string, error) { + if flag == "" { + resolved, err := config.Resolve("seed", "source") + if err != nil { + return "", err + } + + flag = resolved + } + + return path.Expand(flag) +} + +func BuildPlan(source string, force bool) (*Plan, error) { + vars := resolveVars() + + var manifest *Manifest + if manifestPath := ManifestPath(source); internalIO.FileExists(manifestPath) { + loaded, err := LoadManifest(manifestPath) + if err != nil { + return nil, err + } + + manifest = loaded + } + + ops, err := buildPlan(source, manifest, vars, force) + if err != nil { + return nil, err + } + + return &Plan{Source: source, Vars: vars, Manifest: manifest, Ops: ops}, nil +} + +func buildPlan(source string, manifest *Manifest, vars Vars, force bool) ([]ResolvedOp, error) { + plan := map[string]ResolvedOp{} + + mirror, err := walkMirror(source) + if err != nil { + return nil, err + } + + for dest, src := range mirror { + plan[dest] = ResolvedOp{Dest: dest, Source: src, Op: OpCopy, Force: force} + } + + if manifest != nil { + for rawDest, op := range manifest.Seeds { + dest, err := vars.expand(rawDest) + if err != nil { + return nil, fmt.Errorf("seed %q: %w", rawDest, err) + } + + resolved := ResolvedOp{ + Dest: dest, + Content: op.Content, + Mode: op.Mode, + Secret: op.Secret, + Op: op.Op, + Template: op.Template, + Force: force || op.Force, + } + + if op.Content == nil { + resolved.Source = rhymingSource(source, dest) + } + + plan[dest] = resolved + } + } + + dests := make([]string, 0, len(plan)) + for dest := range plan { + dests = append(dests, dest) + } + sort.Strings(dests) + + ops := make([]ResolvedOp, 0, len(plan)) + for _, dest := range dests { + ops = append(ops, plan[dest]) + } + + return ops, nil +} + +func rhymingSource(source, dest string) string { + return filepath.Join(source, strings.TrimPrefix(dest, string(os.PathSeparator))) +} + +func (p *Plan) filterDests(args []string) ([]ResolvedOp, error) { + wanted := map[string]bool{} + for _, arg := range args { + resolved, err := p.Vars.expand(arg) + if err != nil { + return nil, err + } + + wanted[resolved] = true + } + + filtered := make([]ResolvedOp, 0, len(args)) + for _, op := range p.Ops { + if wanted[op.Dest] { + filtered = append(filtered, op) + delete(wanted, op.Dest) + } + } + + if len(wanted) > 0 { + missing := make([]string, 0, len(wanted)) + for dest := range wanted { + missing = append(missing, dest) + } + sort.Strings(missing) + + return nil, fmt.Errorf("no seed entry for: %s", strings.Join(missing, ", ")) + } + + return filtered, nil +} diff --git a/internals/seed/seed_test.go b/internals/seed/seed_test.go new file mode 100644 index 0000000..001c8c5 --- /dev/null +++ b/internals/seed/seed_test.go @@ -0,0 +1,491 @@ +package seed + +import ( + "bytes" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + "testing" + + "github.com/kloudkit/ws-cli/internals/secrets" + "gotest.tools/v3/assert" +) + +const testMaster = "ws-seed-test-master-key-0123456789" + +func setEnv(t *testing.T, home string) { + t.Setenv("HOME", home) + t.Setenv("WS__INTERNAL_ENV_REFERENCE", filepath.Join(t.TempDir(), "absent.yaml")) +} + +func write(t *testing.T, path, content string) { + t.Helper() + assert.NilError(t, os.MkdirAll(filepath.Dir(path), 0o755)) + assert.NilError(t, os.WriteFile(path, []byte(content), 0o644)) +} + +func writeManifest(t *testing.T, source, body string) { + t.Helper() + write(t, filepath.Join(source, ManifestName), "version: v1\n"+body) +} + +func rhyming(source, dest string) string { + return filepath.Join(source, strings.TrimPrefix(dest, "/")) +} + +func apply(t *testing.T, opts Options) string { + t.Helper() + var buffer bytes.Buffer + opts.Out = &buffer + assert.NilError(t, Apply(opts)) + return buffer.String() +} + +func encrypt(t *testing.T, plaintext, key string) string { + t.Helper() + master, err := secrets.ResolveMasterKey(key) + assert.NilError(t, err) + encrypted, err := secrets.Encrypt([]byte(plaintext), master) + assert.NilError(t, err) + return encrypted +} + +func mode(t *testing.T, path string) os.FileMode { + t.Helper() + info, err := os.Stat(path) + assert.NilError(t, err) + return info.Mode().Perm() +} + +func TestApplyMirror(t *testing.T) { + t.Run("Verbatim", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "file.txt") + + write(t, rhyming(source, dest), "verbatim\n") + + apply(t, Options{Source: source}) + + got, err := os.ReadFile(dest) + assert.NilError(t, err) + assert.Equal(t, string(got), "verbatim\n") + }) + + t.Run("ManifestNotProjected", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + writeManifest(t, source, "") + + plan, err := BuildPlan(source, false) + assert.NilError(t, err) + assert.Equal(t, len(plan.Ops), 0) + }) +} + +func TestApplyInline(t *testing.T) { + t.Run("Literal", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "out.txt") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o640\"\n content: \"inline\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + got, err := os.ReadFile(dest) + assert.NilError(t, err) + assert.Equal(t, string(got), "inline\n") + assert.Equal(t, mode(t, dest), os.FileMode(0o640)) + }) + + t.Run("CopyOnlyMissingSourceWarns", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "ghost.txt") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n", dest)) + + output := apply(t, Options{Source: source}) + + assert.Assert(t, !fileExists(dest)) + assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (no source available)")) + }) +} + +func TestApplyOps(t *testing.T) { + t.Run("Append", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "list.txt") + write(t, dest, "base\n") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: append\n force: true\n content: \"added\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + got, err := os.ReadFile(dest) + assert.NilError(t, err) + assert.Equal(t, string(got), "base\nadded\n") + }) + + t.Run("Prepend", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "list.txt") + write(t, dest, "base\n") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: prepend\n force: true\n content: \"added\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + got, err := os.ReadFile(dest) + assert.NilError(t, err) + assert.Equal(t, string(got), "added\nbase\n") + }) + + t.Run("MergeJSON", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "config.json") + write(t, dest, `{"a":1,"list":[1,2,3]}`) + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: merge\n force: true\n content: '{\"list\":[9],\"b\":2}'\n", dest)) + + apply(t, Options{Source: source}) + + out := decodeBack(t, []byte(readFile(t, dest)), dest) + assert.Equal(t, out["a"], float64(1)) + assert.Equal(t, out["b"], float64(2)) + list, ok := out["list"].([]any) + assert.Assert(t, ok) + assert.Equal(t, len(list), 1) + }) + + t.Run("MergeScalarVsMapConflictLeavesDestUnchanged", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "config.json") + write(t, dest, `{"k":"scalar"}`) + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: merge\n force: true\n content: '{\"k\":{\"nested\":1}}'\n", dest)) + + output := apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), `{"k":"scalar"}`) + assert.Assert(t, strings.Contains(output, "merge conflict at key")) + }) +} + +func TestApplyTemplate(t *testing.T) { + t.Run("WsTokensAndSecret", func(t *testing.T) { + home := t.TempDir() + setEnv(t, home) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "rendered.txt") + + ciphertext := encrypt(t, "S3CR3T", testMaster) + writeManifest(t, source, fmt.Sprintf( + "secrets:\n TOK: %s\nseeds:\n %s:\n template: true\n content: \"${ws_home}|${ws_user}|${secrets.TOK}\\n\"\n", + ciphertext, dest, + )) + + apply(t, Options{Source: source, MasterKey: testMaster}) + + current, err := user.Current() + assert.NilError(t, err) + assert.Equal(t, readFile(t, dest), fmt.Sprintf("%s|%s|S3CR3T\n", home, current.Username)) + }) + + t.Run("UnknownTokenFailsLoud", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "rendered.txt") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n template: true\n content: \"${bogus}\\n\"\n", dest)) + + output := apply(t, Options{Source: source}) + + assert.Assert(t, !fileExists(dest)) + assert.Assert(t, strings.Contains(output, "unknown template token ${bogus}")) + }) + + t.Run("SecretBearingFloor", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "secret.txt") + + ciphertext := encrypt(t, "TOP-SECRET-VALUE", testMaster) + writeManifest(t, source, fmt.Sprintf( + "secrets:\n TOK: %s\nseeds:\n %s:\n template: true\n content: \"${secrets.TOK}\\n\"\n", + ciphertext, dest, + )) + + output := apply(t, Options{Source: source, MasterKey: testMaster}) + + assert.Equal(t, readFile(t, dest), "TOP-SECRET-VALUE\n") + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + assert.Assert(t, !strings.Contains(output, "TOP-SECRET-VALUE")) + }) + + t.Run("SecretBearingFailureOutputScrubbed", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "conflict.json") + write(t, dest, `{"key":"scalar"}`) + + ciphertext := encrypt(t, "TOP-SECRET-VALUE", testMaster) + writeManifest(t, source, fmt.Sprintf( + "secrets:\n TOK: %s\nseeds:\n %s:\n template: true\n op: merge\n force: true\n content: '{\"key\":{\"n\":\"${secrets.TOK}\"}}'\n", + ciphertext, dest, + )) + + output := apply(t, Options{Source: source, MasterKey: testMaster}) + + assert.Equal(t, readFile(t, dest), `{"key":"scalar"}`) + assert.Assert(t, strings.Contains(output, "merge conflict at key")) + assert.Assert(t, !strings.Contains(output, "TOP-SECRET-VALUE")) + }) +} + +func TestApplySecrets(t *testing.T) { + t.Run("WholeFile", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "id_key") + + write(t, rhyming(source, dest), encrypt(t, "PRIVATE-KEY-BODY\n", testMaster)) + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n secret: true\n", dest)) + + apply(t, Options{Source: source, MasterKey: testMaster}) + + assert.Equal(t, readFile(t, dest), "PRIVATE-KEY-BODY\n") + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + }) + + t.Run("FailClosedOnBadDecrypt", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "id_key") + + write(t, rhyming(source, dest), encrypt(t, "PRIVATE", "a-totally-different-master-key-99")) + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n secret: true\n", dest)) + + output := apply(t, Options{Source: source, MasterKey: testMaster}) + + assert.Assert(t, !fileExists(dest)) + assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (decrypt failed)")) + assert.Assert(t, !strings.Contains(output, "PRIVATE")) + }) + + t.Run("SecretFreeManifestNeedsNoKey", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "plain.txt") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n content: \"plain\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "plain\n") + }) + + t.Run("MissingKeyFailsClosedButNonSecretApplies", func(t *testing.T) { + setEnv(t, t.TempDir()) + t.Setenv("WS_SECRETS_MASTER_KEY", "") + source := t.TempDir() + target := t.TempDir() + secretDest := filepath.Join(target, "secret.txt") + plainDest := filepath.Join(target, "plain.txt") + + write(t, rhyming(source, secretDest), encrypt(t, "X", testMaster)) + writeManifest(t, source, fmt.Sprintf( + "seeds:\n %s:\n secret: true\n %s:\n mode: \"0o644\"\n content: \"plain\\n\"\n", + secretDest, plainDest, + )) + + output := apply(t, Options{Source: source}) + + assert.Assert(t, !fileExists(secretDest)) + assert.Equal(t, readFile(t, plainDest), "plain\n") + assert.Assert(t, strings.Contains(output, "master key unavailable")) + }) +} + +func TestApplyOwnership(t *testing.T) { + t.Run("OwnedAncestorAllows", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "nested", "deep", "out.txt") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n content: \"ok\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "ok\n") + }) + + t.Run("NonOwnedAncestorSkips", func(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("running as root: every ancestor is writable") + } + + setEnv(t, t.TempDir()) + source := t.TempDir() + dest := "/etc/ws-seed-test-should-not-write/out.txt" + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n content: \"x\\n\"\n", dest)) + + output := apply(t, Options{Source: source}) + + assert.Assert(t, !fileExists(dest)) + assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (destination not owned)")) + }) +} + +func TestApplyPrecedence(t *testing.T) { + t.Run("ManifestWinsOverBare", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "shared.txt") + + write(t, rhyming(source, dest), "BARE\n") + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o600\"\n content: \"MANIFEST\\n\"\n", dest)) + + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "MANIFEST\n") + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + }) + + t.Run("SecretSuppressesRawProjection", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "key") + + ciphertext := encrypt(t, "PLAINTEXT\n", testMaster) + write(t, rhyming(source, dest), ciphertext) + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n secret: true\n", dest)) + + apply(t, Options{Source: source, MasterKey: testMaster}) + + got := readFile(t, dest) + assert.Equal(t, got, "PLAINTEXT\n") + assert.Assert(t, got != ciphertext) + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + }) +} + +func TestApplyForceMatrix(t *testing.T) { + manifest := func(dest, op, content string) string { + return fmt.Sprintf("seeds:\n %s:\n op: %s\n mode: \"0o644\"\n content: %s\n", dest, op, content) + } + + t.Run("AbsentWrites", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "f.txt") + + writeManifest(t, source, manifest(dest, "copy", "\"v1\\n\"")) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "v1\n") + }) + + t.Run("ExistsForceFalseSkips", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "f.txt") + write(t, dest, "orig\n") + + writeManifest(t, source, manifest(dest, "copy", "\"v2\\n\"")) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "orig\n") + }) + + t.Run("ExistsForceTrueOverwrites", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "f.txt") + write(t, dest, "orig\n") + + writeManifest(t, source, manifest(dest, "copy", "\"v2\\n\"")) + apply(t, Options{Source: source, Force: true}) + + assert.Equal(t, readFile(t, dest), "v2\n") + }) + + t.Run("MergeForceFalseExistsSkips", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "f.json") + write(t, dest, `{"a":1}`) + + writeManifest(t, source, manifest(dest, "merge", "'{\"b\":2}'")) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), `{"a":1}`) + }) +} + +func TestWriteAtomicSymlink(t *testing.T) { + t.Run("EscapingComponentRefused", func(t *testing.T) { + anchor := t.TempDir() + outside := t.TempDir() + assert.NilError(t, os.Symlink(outside, filepath.Join(anchor, "evil"))) + + dest := filepath.Join(anchor, "evil", "file.txt") + err := writeAtomic(anchor, dest, []byte("x"), 0o644) + + assert.Assert(t, err != nil) + assert.Assert(t, !fileExists(filepath.Join(outside, "file.txt"))) + }) + + t.Run("FinalComponentSymlinkRefused", func(t *testing.T) { + anchor := t.TempDir() + outside := filepath.Join(t.TempDir(), "target.txt") + assert.NilError(t, os.Symlink(outside, filepath.Join(anchor, "link"))) + + dest := filepath.Join(anchor, "link") + err := writeAtomic(anchor, dest, []byte("x"), 0o644) + + assert.ErrorContains(t, err, "refusing to write through symlink") + assert.Assert(t, !fileExists(outside)) + }) +} + +func fileExists(path string) bool { + _, err := os.Lstat(path) + return err == nil +} + +func readFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + assert.NilError(t, err) + return string(data) +} diff --git a/internals/seed/template.go b/internals/seed/template.go new file mode 100644 index 0000000..5f827ce --- /dev/null +++ b/internals/seed/template.go @@ -0,0 +1,55 @@ +package seed + +import ( + "fmt" + "regexp" + "strings" +) + +const secretsPrefix = "secrets." + +var tokenRe = regexp.MustCompile(`\$\{([^}]*)\}`) + +func referencesSecrets(content []byte) bool { + return strings.Contains(string(content), "${"+secretsPrefix) +} + +func renderTemplate(content []byte, vars Vars, secret func(string) ([]byte, error)) ([]byte, error) { + var failure error + + rendered := tokenRe.ReplaceAllFunc(content, func(match []byte) []byte { + if failure != nil { + return nil + } + + token := string(match[2 : len(match)-1]) + + switch token { + case "ws_home": + return []byte(vars.Home) + case "ws_user": + return []byte(vars.User) + case "ws_server_root": + return []byte(vars.ServerRoot) + } + + if name, ok := strings.CutPrefix(token, secretsPrefix); ok { + value, err := secret(name) + if err != nil { + failure = err + return nil + } + + return value + } + + failure = fmt.Errorf("unknown template token ${%s}", token) + return nil + }) + + if failure != nil { + return nil, failure + } + + return rendered, nil +} diff --git a/internals/seed/write.go b/internals/seed/write.go new file mode 100644 index 0000000..da9e90f --- /dev/null +++ b/internals/seed/write.go @@ -0,0 +1,77 @@ +package seed + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" +) + +const tempSuffix = ".ws-seed.tmp" + +func writeAtomic(anchor, dest string, content []byte, mode fs.FileMode) error { + root, err := os.OpenRoot(anchor) + if err != nil { + return fmt.Errorf("failed to open root %q: %w", anchor, err) + } + defer root.Close() + + rel, err := filepath.Rel(anchor, dest) + if err != nil { + return fmt.Errorf("failed to resolve relative path: %w", err) + } + + if info, err := root.Lstat(rel); err == nil && info.Mode()&fs.ModeSymlink != 0 { + return fmt.Errorf("refusing to write through symlink %q", dest) + } + + dirMode := mode&0o077 | 0o700 + if relDir := filepath.Dir(rel); relDir != "." { + if err := root.MkdirAll(relDir, dirMode); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + } + + tmp := rel + tempSuffix + if err := writeTemp(root, tmp, content, mode); err != nil { + return err + } + + if err := root.Rename(tmp, rel); err != nil { + root.Remove(tmp) + return fmt.Errorf("failed to rename into place: %w", err) + } + + return nil +} + +func writeTemp(root *os.Root, tmp string, content []byte, mode fs.FileMode) (err error) { + file, err := root.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_EXCL, mode) + if errors.Is(err, fs.ErrExist) { + root.Remove(tmp) + file, err = root.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_EXCL, mode) + } + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + + defer func() { + if cerr := file.Close(); cerr != nil && err == nil { + err = fmt.Errorf("failed to close temp file: %w", cerr) + } + if err != nil { + root.Remove(tmp) + } + }() + + if _, err = file.Write(content); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + + if err = file.Chmod(mode); err != nil { + return fmt.Errorf("failed to set mode: %w", err) + } + + return nil +} From c8584fde9ce231d0f6e23fb72fd9937993c0586d Mon Sep 17 00:00:00 2001 From: Dov Benyomin Sohacheski Date: Tue, 30 Jun 2026 16:03:57 +0300 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20Add=20seed=20`op:=20block`=20an?= =?UTF-8?q?d=20apply=20review=20hardening?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `blockinfile`-style managed marker block op (markers auto-inserted, body replaced idempotently, force-gate bypass, fixed `>>> ws-seed >>>` text with a tunable `comment:` prefix, fail-closed on malformed markers). Fold in code-review fixes: preserve an existing dest's mode on in-place ops, count inline `content` as a behavior, return nonzero on failure (`SilenceUsage` on apply), zero the cached master key at run end, propagate the real source-read error, fix the consumed-dir notice ordering and report `force:false` skips, and disable JSON HTML escaping on merge. --- cmd/seed/apply.go | 7 +- internals/seed/apply.go | 72 +++++++++--- internals/seed/block.go | 97 ++++++++++++++++ internals/seed/block_test.go | 197 ++++++++++++++++++++++++++++++++ internals/seed/manifest.go | 11 +- internals/seed/manifest_test.go | 18 +++ internals/seed/merge.go | 1 + internals/seed/op.go | 8 +- internals/seed/resolve.go | 2 + internals/seed/seed_test.go | 51 +++++++-- 10 files changed, 435 insertions(+), 29 deletions(-) create mode 100644 internals/seed/block.go create mode 100644 internals/seed/block_test.go diff --git a/cmd/seed/apply.go b/cmd/seed/apply.go index 1fc1644..c51c300 100644 --- a/cmd/seed/apply.go +++ b/cmd/seed/apply.go @@ -10,9 +10,10 @@ import ( ) var applyCmd = &cobra.Command{ - Use: "apply [dest...]", - Short: "Project seed content onto the filesystem", - RunE: runApply, + Use: "apply [dest...]", + Short: "Project seed content onto the filesystem", + SilenceUsage: true, + RunE: runApply, } func runApply(cmd *cobra.Command, args []string) error { diff --git a/internals/seed/apply.go b/internals/seed/apply.go index a14770f..6ab3175 100644 --- a/internals/seed/apply.go +++ b/internals/seed/apply.go @@ -1,6 +1,8 @@ package seed import ( + "bytes" + "errors" "fmt" "io" "io/fs" @@ -72,6 +74,12 @@ func (k *keyResolver) master() ([]byte, error) { return k.key, k.err } +func (k *keyResolver) zero() { + for i := range k.key { + k.key[i] = 0 + } +} + func (k *keyResolver) resolveNamed(name string) ([]byte, error) { value, ok := k.secrets[name] if !ok { @@ -111,49 +119,73 @@ func Apply(opts Options) error { } keys := &keyResolver{flag: opts.MasterKey, secrets: declared} + defer keys.zero() rep := reporter{out: opts.Out, styled: opts.Styled} + failures := 0 for _, op := range ops { - plan.applyOne(op, keys, rep) + if err := plan.applyOne(op, keys, rep); err != nil { + failures++ + } + } + + if failures > 0 { + noun := "entries" + if failures == 1 { + noun = "entry" + } + + return fmt.Errorf("%d seed %s failed to apply", failures, noun) } return nil } -func (p *Plan) applyOne(op ResolvedOp, keys *keyResolver, rep reporter) { +func (p *Plan) applyOne(op ResolvedOp, keys *keyResolver, rep reporter) error { ancestor := nearestExistingAncestor(op.Dest) if !ownsPath(ancestor) { rep.skip(op.Dest, "destination not owned") - return + return fmt.Errorf("destination not owned") } - if consumedNotice(op.Dest, p.Vars.Home) { - rep.notice(op.Dest) - } - - if !internalIO.CanOverride(op.Dest, op.Force) { - return + if op.Op != OpBlock && !internalIO.CanOverride(op.Dest, op.Force) { + rep.skip(op.Dest, "exists") + return nil } content, mode, err := p.materialize(op, keys) if err != nil { rep.skip(op.Dest, err.Error()) - return + return err + } + + if op.Op == OpBlock && bytes.Equal(content, readExisting(op.Dest)) { + rep.seeded(op.Dest) + return nil } anchor := chooseAnchor(op.Dest, p.Vars, ancestor) if err := writeAtomic(anchor, op.Dest, content, mode); err != nil { rep.skip(op.Dest, err.Error()) - return + return err + } + + if consumedNotice(op.Dest, p.Vars.Home) { + rep.notice(op.Dest) } rep.seeded(op.Dest) + return nil } func (p *Plan) materialize(op ResolvedOp, keys *keyResolver) ([]byte, fs.FileMode, error) { raw, err := p.sourceBytes(op) if err != nil { - return nil, 0, fmt.Errorf("no source available") + if errors.Is(err, fs.ErrNotExist) { + return nil, 0, fmt.Errorf("no source available") + } + + return nil, 0, fmt.Errorf("source unreadable: %w", err) } mode, err := resolveMode(op, op.Secret || (op.Template && referencesSecrets(raw))) @@ -175,6 +207,10 @@ func (p *Plan) materialize(op ResolvedOp, keys *keyResolver) ([]byte, fs.FileMod content = slices.Concat(readExisting(op.Dest), content) case OpPrepend: content = slices.Concat(content, readExisting(op.Dest)) + case OpBlock: + if content, err = ensureBlock(readExisting(op.Dest), content, op.Comment); err != nil { + return nil, 0, err + } } return content, mode, nil @@ -220,11 +256,17 @@ func resolveMode(op ResolvedOp, secretBearing bool) (fs.FileMode, error) { return 0o600, nil } - if op.Mode == "" { - return 0o644, nil + if op.Mode != "" { + return internalIO.ParseFileMode(op.Mode) + } + + if op.Op.inPlace() { + if info, err := os.Stat(op.Dest); err == nil { + return info.Mode().Perm(), nil + } } - return internalIO.ParseFileMode(op.Mode) + return 0o644, nil } func readExisting(dest string) []byte { diff --git a/internals/seed/block.go b/internals/seed/block.go new file mode 100644 index 0000000..b757002 --- /dev/null +++ b/internals/seed/block.go @@ -0,0 +1,97 @@ +package seed + +import ( + "bytes" + "fmt" + "strings" +) + +const ( + blockBeginBody = ">>> ws-seed >>>" + blockEndBody = "<<< ws-seed <<<" + defaultComment = "#" + blockBegin = defaultComment + " " + blockBeginBody + blockEnd = defaultComment + " " + blockEndBody +) + +func blockMarkers(comment string) (string, string) { + if comment == "" { + comment = defaultComment + } + + return comment + " " + blockBeginBody, comment + " " + blockEndBody +} + +func renderBlock(body []byte, begin, end string) []byte { + var buffer bytes.Buffer + buffer.WriteString(begin) + buffer.WriteByte('\n') + + if len(body) > 0 { + buffer.Write(body) + if body[len(body)-1] != '\n' { + buffer.WriteByte('\n') + } + } + + buffer.WriteString(end) + buffer.WriteByte('\n') + + return buffer.Bytes() +} + +func appendBlock(existing, block []byte) []byte { + if len(existing) == 0 { + return block + } + + var buffer bytes.Buffer + buffer.Write(existing) + if existing[len(existing)-1] != '\n' { + buffer.WriteByte('\n') + } + buffer.Write(block) + + return buffer.Bytes() +} + +func ensureBlock(existing, body []byte, comment string) ([]byte, error) { + beginMarker, endMarker := blockMarkers(comment) + block := renderBlock(body, beginMarker, endMarker) + lines := bytes.SplitAfter(existing, []byte("\n")) + + begin, end := -1, -1 + for i, line := range lines { + switch strings.TrimRight(string(line), "\n") { + case beginMarker: + if begin >= 0 { + return nil, fmt.Errorf("malformed managed block: duplicate begin marker") + } + begin = i + case endMarker: + if end >= 0 { + return nil, fmt.Errorf("malformed managed block: duplicate end marker") + } + end = i + } + } + + if begin < 0 && end < 0 { + return appendBlock(existing, block), nil + } + + if begin < 0 || end < 0 || end < begin { + return nil, fmt.Errorf("malformed managed block: markers out of order") + } + + var buffer bytes.Buffer + for _, line := range lines[:begin] { + buffer.Write(line) + } + buffer.Write(block) + for _, line := range lines[end+1:] { + buffer.Write(line) + } + + return buffer.Bytes(), nil +} diff --git a/internals/seed/block_test.go b/internals/seed/block_test.go new file mode 100644 index 0000000..5030bc3 --- /dev/null +++ b/internals/seed/block_test.go @@ -0,0 +1,197 @@ +package seed + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "gotest.tools/v3/assert" +) + +func blockManifest(dest, content string) string { + return fmt.Sprintf("seeds:\n %s:\n op: block\n content: %s\n", dest, content) +} + +func TestEnsureBlock(t *testing.T) { + tests := []struct { + name string + existing string + body string + comment string + want string + err string + }{ + { + name: "AppendsWhenAbsent", + body: "export FOO=1\n", + want: "# >>> ws-seed >>>\nexport FOO=1\n# <<< ws-seed <<<\n", + }, + { + name: "CustomCommentPrefix", + body: "const x = 1;\n", + comment: "//", + want: "// >>> ws-seed >>>\nconst x = 1;\n// <<< ws-seed <<<\n", + }, + { + name: "CustomCommentReplacesBody", + existing: "// >>> ws-seed >>>\nold\n// <<< ws-seed <<<\n", + body: "new\n", + comment: "//", + want: "// >>> ws-seed >>>\nnew\n// <<< ws-seed <<<\n", + }, + { + name: "AppendsAfterExisting", + existing: "base\n", + body: "line\n", + want: "base\n# >>> ws-seed >>>\nline\n# <<< ws-seed <<<\n", + }, + { + name: "AddsNewlineBeforeMarker", + existing: "base", + body: "line\n", + want: "base\n# >>> ws-seed >>>\nline\n# <<< ws-seed <<<\n", + }, + { + name: "ReplacesBodyPreservingSurround", + existing: "head\n# >>> ws-seed >>>\nold\n# <<< ws-seed <<<\ntail\n", + body: "new\n", + want: "head\n# >>> ws-seed >>>\nnew\n# <<< ws-seed <<<\ntail\n", + }, + { + name: "DuplicateBeginRejected", + existing: "# >>> ws-seed >>>\n# >>> ws-seed >>>\n# <<< ws-seed <<<\n", + body: "x\n", + err: "duplicate begin marker", + }, + { + name: "EndBeforeBeginRejected", + existing: "# <<< ws-seed <<<\nbody\n# >>> ws-seed >>>\n", + body: "x\n", + err: "markers out of order", + }, + { + name: "BeginWithoutEndRejected", + existing: "# >>> ws-seed >>>\norphan\n", + body: "x\n", + err: "markers out of order", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ensureBlock([]byte(tt.existing), []byte(tt.body), tt.comment) + + if tt.err != "" { + assert.ErrorContains(t, err, tt.err) + return + } + + assert.NilError(t, err) + assert.Equal(t, string(got), tt.want) + }) + } +} + +func TestApplyBlock(t *testing.T) { + t.Run("CreatesWhenAbsent", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + + writeManifest(t, source, blockManifest(dest, "\"export FOO=1\\n\"")) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "# >>> ws-seed >>>\nexport FOO=1\n# <<< ws-seed <<<\n") + }) + + t.Run("AppendsToExistingWithoutForce", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + write(t, dest, "base\n") + + writeManifest(t, source, blockManifest(dest, "\"line\\n\"")) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "base\n# >>> ws-seed >>>\nline\n# <<< ws-seed <<<\n") + }) + + t.Run("Idempotent", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + write(t, dest, "base\n") + + writeManifest(t, source, blockManifest(dest, "\"line\\n\"")) + apply(t, Options{Source: source}) + first := readFile(t, dest) + apply(t, Options{Source: source}) + second := readFile(t, dest) + + assert.Equal(t, first, second) + assert.Equal(t, strings.Count(second, blockBegin), 1) + }) + + t.Run("ReplacesBodyOnChange", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + write(t, dest, "head\n") + + writeManifest(t, source, blockManifest(dest, "\"v1\\n\"")) + apply(t, Options{Source: source}) + + writeManifest(t, source, blockManifest(dest, "\"v2\\n\"")) + apply(t, Options{Source: source}) + + got := readFile(t, dest) + assert.Equal(t, got, "head\n# >>> ws-seed >>>\nv2\n# <<< ws-seed <<<\n") + assert.Equal(t, strings.Count(got, blockBegin), 1) + }) + + t.Run("PreservesExistingMode", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + write(t, dest, "base\n") + assert.NilError(t, os.Chmod(dest, 0o600)) + + writeManifest(t, source, blockManifest(dest, "\"line\\n\"")) + apply(t, Options{Source: source}) + + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + }) + + t.Run("CustomCommentMarker", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "app.js") + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: block\n comment: \"//\"\n content: \"const x = 1;\\n\"\n", dest)) + apply(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "// >>> ws-seed >>>\nconst x = 1;\n// <<< ws-seed <<<\n") + }) + + t.Run("MalformedMarkersFailClosed", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "zshenv") + write(t, dest, "# >>> ws-seed >>>\norphan\n") + + writeManifest(t, source, blockManifest(dest, "\"line\\n\"")) + output := applyErr(t, Options{Source: source}) + + assert.Equal(t, readFile(t, dest), "# >>> ws-seed >>>\norphan\n") + assert.Assert(t, strings.Contains(output, "malformed managed block")) + }) +} diff --git a/internals/seed/manifest.go b/internals/seed/manifest.go index e39b2bd..3cb174e 100644 --- a/internals/seed/manifest.go +++ b/internals/seed/manifest.go @@ -74,9 +74,14 @@ func validateSecretValue(name, value string) error { func validateOp(dest string, op SeedOp) error { switch op.Op { - case OpCopy, OpMerge, OpAppend, OpPrepend: - return nil + case OpCopy, OpMerge, OpAppend, OpPrepend, OpBlock: + default: + return fmt.Errorf("seed %q: unknown op %q", dest, op.Op) + } + + if op.Comment != "" && op.Op != OpBlock { + return fmt.Errorf("seed %q: comment is only valid with op: block", dest) } - return fmt.Errorf("seed %q: unknown op %q", dest, op.Op) + return nil } diff --git a/internals/seed/manifest_test.go b/internals/seed/manifest_test.go index 940ce46..1b92105 100644 --- a/internals/seed/manifest_test.go +++ b/internals/seed/manifest_test.go @@ -44,8 +44,26 @@ func TestParseManifest(t *testing.T) { assert.Equal(t, manifest.Seeds["/tmp/x"].Op, OpCopy) }) + t.Run("InlineContentEntryAccepted", func(t *testing.T) { + manifest, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n content: \"hi\\n\"\n")) + assert.NilError(t, err) + assert.Equal(t, *manifest.Seeds["/tmp/x"].Content, "hi\n") + assert.Equal(t, manifest.Seeds["/tmp/x"].Op, OpCopy) + }) + + t.Run("BlockOpAccepted", func(t *testing.T) { + manifest, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n op: block\n content: \"hi\\n\"\n")) + assert.NilError(t, err) + assert.Equal(t, manifest.Seeds["/tmp/x"].Op, OpBlock) + }) + t.Run("UnknownOpRejected", func(t *testing.T) { _, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n op: smash\n")) assert.ErrorContains(t, err, `unknown op "smash"`) }) + + t.Run("CommentOnNonBlockRejected", func(t *testing.T) { + _, err := ParseManifest([]byte("version: v1\nseeds:\n /tmp/x:\n op: append\n comment: \"//\"\n content: \"x\\n\"\n")) + assert.ErrorContains(t, err, "comment is only valid with op: block") + }) } diff --git a/internals/seed/merge.go b/internals/seed/merge.go index df6b66d..2e4f566 100644 --- a/internals/seed/merge.go +++ b/internals/seed/merge.go @@ -32,6 +32,7 @@ func codecFor(dest string) (codec, error) { func marshalJSON(v any) ([]byte, error) { var buffer bytes.Buffer encoder := json.NewEncoder(&buffer) + encoder.SetEscapeHTML(false) encoder.SetIndent("", " ") if err := encoder.Encode(v); err != nil { diff --git a/internals/seed/op.go b/internals/seed/op.go index 46a9670..07a0804 100644 --- a/internals/seed/op.go +++ b/internals/seed/op.go @@ -7,6 +7,7 @@ const ( OpMerge Op = "merge" OpAppend Op = "append" OpPrepend Op = "prepend" + OpBlock Op = "block" ) type SeedOp struct { @@ -16,8 +17,13 @@ type SeedOp struct { Op Op `yaml:"op"` Template bool `yaml:"template"` Force bool `yaml:"force"` + Comment string `yaml:"comment"` } func (o SeedOp) hasBehavior() bool { - return o.Secret || o.Mode != "" || (o.Op != "" && o.Op != OpCopy) || o.Template + return o.Secret || o.Mode != "" || (o.Op != "" && o.Op != OpCopy) || o.Template || o.Content != nil +} + +func (o Op) inPlace() bool { + return o == OpMerge || o == OpAppend || o == OpPrepend || o == OpBlock } diff --git a/internals/seed/resolve.go b/internals/seed/resolve.go index 1c6e703..57dc3f0 100644 --- a/internals/seed/resolve.go +++ b/internals/seed/resolve.go @@ -29,6 +29,7 @@ type ResolvedOp struct { Op Op Template bool Force bool + Comment string } type Plan struct { @@ -124,6 +125,7 @@ func buildPlan(source string, manifest *Manifest, vars Vars, force bool) ([]Reso Op: op.Op, Template: op.Template, Force: force || op.Force, + Comment: op.Comment, } if op.Content == nil { diff --git a/internals/seed/seed_test.go b/internals/seed/seed_test.go index 001c8c5..a3575f7 100644 --- a/internals/seed/seed_test.go +++ b/internals/seed/seed_test.go @@ -43,6 +43,14 @@ func apply(t *testing.T, opts Options) string { return buffer.String() } +func applyErr(t *testing.T, opts Options) string { + t.Helper() + var buffer bytes.Buffer + opts.Out = &buffer + assert.Assert(t, Apply(opts) != nil) + return buffer.String() +} + func encrypt(t *testing.T, plaintext, key string) string { t.Helper() master, err := secrets.ResolveMasterKey(key) @@ -111,11 +119,25 @@ func TestApplyInline(t *testing.T) { writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n", dest)) - output := apply(t, Options{Source: source}) + output := applyErr(t, Options{Source: source}) assert.Assert(t, !fileExists(dest)) assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (no source available)")) }) + + t.Run("SourceUnreadableReported", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "out.txt") + + assert.NilError(t, os.MkdirAll(rhyming(source, dest), 0o755)) + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n", dest)) + + output := applyErr(t, Options{Source: source}) + + assert.Assert(t, strings.Contains(output, "source unreadable")) + }) } func TestApplyOps(t *testing.T) { @@ -170,6 +192,21 @@ func TestApplyOps(t *testing.T) { assert.Equal(t, len(list), 1) }) + t.Run("MergePreservesExistingMode", func(t *testing.T) { + setEnv(t, t.TempDir()) + source := t.TempDir() + target := t.TempDir() + dest := filepath.Join(target, "config.json") + write(t, dest, `{"a":1}`) + assert.NilError(t, os.Chmod(dest, 0o600)) + + writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: merge\n force: true\n content: '{\"b\":2}'\n", dest)) + + apply(t, Options{Source: source}) + + assert.Equal(t, mode(t, dest), os.FileMode(0o600)) + }) + t.Run("MergeScalarVsMapConflictLeavesDestUnchanged", func(t *testing.T) { setEnv(t, t.TempDir()) source := t.TempDir() @@ -179,7 +216,7 @@ func TestApplyOps(t *testing.T) { writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n op: merge\n force: true\n content: '{\"k\":{\"nested\":1}}'\n", dest)) - output := apply(t, Options{Source: source}) + output := applyErr(t, Options{Source: source}) assert.Equal(t, readFile(t, dest), `{"k":"scalar"}`) assert.Assert(t, strings.Contains(output, "merge conflict at key")) @@ -215,7 +252,7 @@ func TestApplyTemplate(t *testing.T) { writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n template: true\n content: \"${bogus}\\n\"\n", dest)) - output := apply(t, Options{Source: source}) + output := applyErr(t, Options{Source: source}) assert.Assert(t, !fileExists(dest)) assert.Assert(t, strings.Contains(output, "unknown template token ${bogus}")) @@ -253,7 +290,7 @@ func TestApplyTemplate(t *testing.T) { ciphertext, dest, )) - output := apply(t, Options{Source: source, MasterKey: testMaster}) + output := applyErr(t, Options{Source: source, MasterKey: testMaster}) assert.Equal(t, readFile(t, dest), `{"key":"scalar"}`) assert.Assert(t, strings.Contains(output, "merge conflict at key")) @@ -286,7 +323,7 @@ func TestApplySecrets(t *testing.T) { write(t, rhyming(source, dest), encrypt(t, "PRIVATE", "a-totally-different-master-key-99")) writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n secret: true\n", dest)) - output := apply(t, Options{Source: source, MasterKey: testMaster}) + output := applyErr(t, Options{Source: source, MasterKey: testMaster}) assert.Assert(t, !fileExists(dest)) assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (decrypt failed)")) @@ -320,7 +357,7 @@ func TestApplySecrets(t *testing.T) { secretDest, plainDest, )) - output := apply(t, Options{Source: source}) + output := applyErr(t, Options{Source: source}) assert.Assert(t, !fileExists(secretDest)) assert.Equal(t, readFile(t, plainDest), "plain\n") @@ -353,7 +390,7 @@ func TestApplyOwnership(t *testing.T) { writeManifest(t, source, fmt.Sprintf("seeds:\n %s:\n mode: \"0o644\"\n content: \"x\\n\"\n", dest)) - output := apply(t, Options{Source: source}) + output := applyErr(t, Options{Source: source}) assert.Assert(t, !fileExists(dest)) assert.Assert(t, strings.Contains(output, "Skipping ["+dest+"] (destination not owned)"))