From d452149b01bbab02bbb1beac25ec451e59858622 Mon Sep 17 00:00:00 2001 From: Dmitry Smirnov Date: Tue, 30 Jun 2026 10:57:23 -0400 Subject: [PATCH 1/5] PPSC-1037 implement SSO through oauth2 device flow --- internal/auth/auth.go | 187 ++++++++++++- internal/auth/browser.go | 64 +++++ internal/auth/browser_test.go | 31 +++ internal/auth/client.go | 7 +- internal/auth/device.go | 385 +++++++++++++++++++++++++++ internal/auth/device_test.go | 211 +++++++++++++++ internal/auth/oauth_provider_test.go | 115 ++++++++ internal/auth/tokenstore.go | 266 ++++++++++++++++++ internal/auth/tokenstore_test.go | 193 ++++++++++++++ internal/cmd/auth.go | 59 ++-- internal/cmd/auth_flow_test.go | 271 +++++++++++++++++++ internal/cmd/auth_login.go | 141 ++++++++++ internal/cmd/auth_logout.go | 64 +++++ internal/cmd/auth_test.go | 10 +- internal/cmd/auth_whoami.go | 91 +++++++ internal/cmd/root.go | 78 +++++- 16 files changed, 2140 insertions(+), 33 deletions(-) create mode 100644 internal/auth/browser.go create mode 100644 internal/auth/browser_test.go create mode 100644 internal/auth/device.go create mode 100644 internal/auth/device_test.go create mode 100644 internal/auth/oauth_provider_test.go create mode 100644 internal/auth/tokenstore.go create mode 100644 internal/auth/tokenstore_test.go create mode 100644 internal/cmd/auth_flow_test.go create mode 100644 internal/cmd/auth_login.go create mode 100644 internal/cmd/auth_logout.go create mode 100644 internal/cmd/auth_whoami.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index ad495df9..0694b40f 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net/http" + "os" "strings" "sync" "time" @@ -40,8 +41,14 @@ type JWTCredentials struct { } // AuthProvider manages authentication tokens with automatic refresh. -// It supports both JWT authentication and legacy Basic authentication. -// For JWT auth, tokens are automatically refreshed when within 5 minutes of expiry. +// It supports three modes: +// - SSO / device flow (OAuth2): tokens obtained via `armis-cli auth login`, +// persisted in the token store, refreshed via the refresh-token grant. +// - JWT client credentials: client_id/client_secret exchanged for a JWT. +// - Legacy Basic auth: a static --token. +// +// For JWT and SSO auth, tokens are automatically refreshed when within 5 minutes +// of expiry. type AuthProvider struct { config AuthConfig credentials *JWTCredentials @@ -50,6 +57,83 @@ type AuthProvider struct { isLegacy bool // true if using Basic auth (--token) cachedRegion string // memoized region from disk cache (loaded once) regionLoaded bool // true if cachedRegion has been loaded from disk + + // SSO / device-flow mode (mutually exclusive with isLegacy and JWT mode). + isOAuth bool + stored *StoredToken + deviceClient *DeviceClient + tokenStore *TokenStore + env string // environment key (API base URL) the token is stored under +} + +// AuthMethod identifies which credential path the provider is using, for display +// by `armis-cli auth whoami`. +type AuthMethod string + +const ( + // AuthMethodSSO is the browser-based device-flow (OAuth2) login. + AuthMethodSSO AuthMethod = "sso" + // AuthMethodClientCredentials is the client_id/client_secret JWT exchange. + AuthMethodClientCredentials AuthMethod = "client-credentials" + // AuthMethodBasic is the legacy static-token Basic auth. + AuthMethodBasic AuthMethod = "basic" +) + +// NewProviderFromStored builds an AuthProvider backed by an existing device-flow +// token for the given environment (API base URL). The provider transparently +// refreshes the access token via the refresh token when it nears expiry, +// persisting the rotated pair back to the store under that same environment. +func NewProviderFromStored(store *TokenStore, deviceClient *DeviceClient, env string, stored *StoredToken) (*AuthProvider, error) { + if store == nil || deviceClient == nil || stored == nil { + return nil, fmt.Errorf("token store, device client, and stored token are required") + } + if env == "" { + return nil, fmt.Errorf("env is required") + } + return &AuthProvider{ + isOAuth: true, + stored: stored, + deviceClient: deviceClient, + tokenStore: store, + env: env, + }, nil +} + +// AuthMethod returns the credential path in use. +func (p *AuthProvider) AuthMethod() AuthMethod { + switch { + case p.isOAuth: + return AuthMethodSSO + case p.isLegacy: + return AuthMethodBasic + default: + return AuthMethodClientCredentials + } +} + +// Identity returns the subject (user/service identifier) for the current +// session. It is populated for SSO auth and empty otherwise. +func (p *AuthProvider) Identity() string { + p.mu.RLock() + defer p.mu.RUnlock() + if p.isOAuth && p.stored != nil { + return p.stored.Subject + } + return "" +} + +// Expiry returns the access-token expiry for SSO/JWT auth, or the zero time when +// not applicable (Basic auth). +func (p *AuthProvider) Expiry() time.Time { + p.mu.RLock() + defer p.mu.RUnlock() + if p.isOAuth && p.stored != nil { + return p.stored.ExpiresAt + } + if p.credentials != nil { + return p.credentials.ExpiresAt + } + return time.Time{} } // NewAuthProvider creates an AuthProvider from configuration. @@ -92,7 +176,8 @@ func NewAuthProvider(config AuthConfig) (*AuthProvider, error) { return nil, fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") } } else { - return nil, fmt.Errorf("authentication required: set ARMIS_CLIENT_ID / ARMIS_CLIENT_SECRET (or use --client-id / --client-secret) for JWT auth, or ARMIS_API_TOKEN (--token) for legacy auth") + return nil, fmt.Errorf("authentication required: set ARMIS_CLIENT_ID / ARMIS_CLIENT_SECRET " + + "(or use --client-id / --client-secret) for JWT auth, or ARMIS_API_TOKEN (--token) for legacy auth") } return p, nil @@ -107,6 +192,15 @@ func (p *AuthProvider) GetAuthorizationHeader(ctx context.Context) (string, erro return "Basic " + p.config.Token, nil } + if p.isOAuth { + if err := p.refreshOAuthIfNeeded(ctx); err != nil { + return "", err + } + p.mu.RLock() + defer p.mu.RUnlock() + return "Bearer " + p.stored.AccessToken, nil + } + // Refresh JWT if needed if err := p.refreshIfNeeded(ctx); err != nil { return "", fmt.Errorf("failed to refresh token: %w", err) @@ -126,6 +220,15 @@ func (p *AuthProvider) GetTenantID(ctx context.Context) (string, error) { return p.config.TenantID, nil } + if p.isOAuth { + if err := p.refreshOAuthIfNeeded(ctx); err != nil { + return "", err + } + p.mu.RLock() + defer p.mu.RUnlock() + return p.stored.TenantID, nil + } + if err := p.refreshIfNeeded(ctx); err != nil { return "", fmt.Errorf("failed to refresh token: %w", err) } @@ -143,6 +246,15 @@ func (p *AuthProvider) GetRegion(ctx context.Context) (string, error) { return "", nil // Legacy auth doesn't have region } + if p.isOAuth { + if err := p.refreshOAuthIfNeeded(ctx); err != nil { + return "", err + } + p.mu.RLock() + defer p.mu.RUnlock() + return p.stored.Region, nil + } + if err := p.refreshIfNeeded(ctx); err != nil { return "", fmt.Errorf("failed to refresh token: %w", err) } @@ -171,6 +283,16 @@ func (p *AuthProvider) GetRawToken(ctx context.Context) (string, error) { return p.config.Token, nil } + if p.isOAuth { + if err := p.refreshOAuthIfNeeded(ctx); err != nil { + return "", err + } + p.mu.RLock() + defer p.mu.RUnlock() + // armis:ignore cwe:522 reason:returning token to caller is the API contract; token is used for authenticated API calls + return p.stored.AccessToken, nil + } + // Refresh JWT if needed if err := p.refreshIfNeeded(ctx); err != nil { return "", fmt.Errorf("failed to refresh token: %w", err) @@ -300,6 +422,65 @@ func (p *AuthProvider) refreshIfNeeded(ctx context.Context) error { return p.exchangeCredentials(ctx) } +// refreshOAuthIfNeeded refreshes the device-flow access token via the refresh +// token when it is within 5 minutes of expiry, persisting the rotated pair back +// to the token store. Uses double-checked locking to avoid concurrent refreshes. +func (p *AuthProvider) refreshOAuthIfNeeded(ctx context.Context) error { + p.mu.RLock() + needsRefresh := p.stored == nil || + time.Until(p.stored.ExpiresAt) < 5*time.Minute + p.mu.RUnlock() + if !needsRefresh { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + // Double-check: another goroutine may have refreshed while we waited. + if p.stored != nil && time.Until(p.stored.ExpiresAt) >= 5*time.Minute { + return nil + } + if p.stored == nil || p.stored.RefreshToken == "" { + return fmt.Errorf("your session has expired; run 'armis-cli auth login' to sign in again") + } + + refreshed, err := p.deviceClient.Refresh(ctx, p.stored.RefreshToken, p.stored.ClientID) + if err != nil { + var oerr *OAuthError + if asOAuthError(err, &oerr) && (oerr.Code == errInvalidGrant || oerr.Code == errExpiredToken) { + return fmt.Errorf("your session has expired; run 'armis-cli auth login' to sign in again") + } + return fmt.Errorf("failed to refresh session: %w", err) + } + + // Carry forward identity fields the refresh response may not echo. + if refreshed.TenantID == "" { + refreshed.TenantID = p.stored.TenantID + } + if refreshed.Subject == "" { + refreshed.Subject = p.stored.Subject + } + if refreshed.Role == "" { + refreshed.Role = p.stored.Role + } + if refreshed.Region == "" { + refreshed.Region = p.stored.Region + } + if refreshed.ClientID == "" { + refreshed.ClientID = p.stored.ClientID + } + + p.stored = refreshed + if err := p.tokenStore.Save(p.env, refreshed); err != nil { + // Non-fatal: the in-memory token is valid for this process even if we + // could not persist it. A later invocation will refresh again. + if p.config.Debug { + fmt.Fprintf(os.Stderr, "[DEBUG] failed to persist refreshed token: %v\n", err) + } + } + return nil +} + // jwtClaims represents the relevant claims from a JWT. type jwtClaims struct { CustomerID string // maps to tenant_id diff --git a/internal/auth/browser.go b/internal/auth/browser.go new file mode 100644 index 00000000..eb810cd0 --- /dev/null +++ b/internal/auth/browser.go @@ -0,0 +1,64 @@ +// Package auth provides authentication for the Armis API. +// This file opens the system browser for the device-flow verification page. +package auth + +import ( + "fmt" + "net/url" + "os/exec" + "runtime" +) + +// browserOpener is overridable so the device-login flow can be exercised +// without spawning a real browser. See SetBrowserOpener. +var browserOpener = openBrowserCmd + +// SetBrowserOpener replaces the function used to launch the browser and returns +// a function that restores the previous opener. It is intended for tests (which +// must not spawn a real browser); production code uses the default opener. +func SetBrowserOpener(fn func(string) error) (restore func()) { + prev := browserOpener + browserOpener = fn + return func() { browserOpener = prev } +} + +// OpenBrowser attempts to open the given URL in the user's default browser. +// It returns an error when no opener is available (headless server, SSH, locked +// down terminal); callers fall back to printing the URL and user_code. +// +// Only http(s) URLs are accepted, so a malformed verification URI cannot be +// turned into the execution of an arbitrary local handler. +func OpenBrowser(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if parsed.Scheme != schemeHTTP && parsed.Scheme != schemeHTTPS { + return fmt.Errorf("refusing to open non-http(s) URL") + } + return browserOpener(rawURL) +} + +// openBrowserCmd launches the platform-specific browser opener. +// +// armis:ignore cwe:78 reason:URL is validated as http(s) by OpenBrowser and passed as a single argv element (no shell), not interpolated into a command string +func openBrowserCmd(rawURL string) error { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + // #nosec G204 -- rawURL is validated as http(s) by OpenBrowser and passed as a separate argv element (no shell) + cmd = exec.Command("open", rawURL) + case "windows": + // #nosec G204 -- rawURL is validated as http(s) by OpenBrowser and passed as a separate argv element (no shell) + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", rawURL) + default: // linux, *bsd, etc. + // #nosec G204 -- rawURL is validated as http(s) by OpenBrowser and passed as a separate argv element (no shell) + cmd = exec.Command("xdg-open", rawURL) + } + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to open browser: %w", err) + } + // Reap the child so it does not become a zombie; the browser detaches itself. + go func() { _ = cmd.Wait() }() //nolint:errcheck // fire-and-forget + return nil +} diff --git a/internal/auth/browser_test.go b/internal/auth/browser_test.go new file mode 100644 index 00000000..f73fc825 --- /dev/null +++ b/internal/auth/browser_test.go @@ -0,0 +1,31 @@ +package auth + +import "testing" + +func TestOpenBrowserRejectsNonHTTP(t *testing.T) { + restore := SetBrowserOpener(func(string) error { + t.Error("opener should not be called for a rejected scheme") + return nil + }) + defer restore() + + for _, bad := range []string{"file:///etc/passwd", "javascript:alert(1)", "ftp://host/x", "not a url"} { + if err := OpenBrowser(bad); err == nil { + t.Errorf("OpenBrowser(%q) = nil, want error", bad) + } + } +} + +func TestOpenBrowserAllowsHTTPS(t *testing.T) { + var got string + restore := SetBrowserOpener(func(u string) error { got = u; return nil }) + defer restore() + + const url = "https://moose.armis.com/oauth2/device/verify?user_code=ABCD-EFGH" + if err := OpenBrowser(url); err != nil { + t.Fatalf("OpenBrowser: %v", err) + } + if got != url { + t.Errorf("opener got %q, want %q", got, url) + } +} diff --git a/internal/auth/client.go b/internal/auth/client.go index ea9aa4ee..6153cb05 100644 --- a/internal/auth/client.go +++ b/internal/auth/client.go @@ -22,6 +22,11 @@ const ( // ProductionBaseURL is the default Armis API endpoint (US region / primary). ProductionBaseURL = "https://moose.armis.com" + + // schemeHTTPS / schemeHTTP are URL scheme literals shared across the auth + // package's HTTPS-enforcement checks. + schemeHTTPS = "https" + schemeHTTP = "http" ) // RegionalBaseURL returns the Armis API base URL for the given region code. @@ -84,7 +89,7 @@ func NewAuthClient(baseURL string, debug bool) (*AuthClient, error) { // armis:ignore cwe:522 reason:this code IS the credential protection check (HTTPS enforcement for non-localhost) // armis:ignore cwe:918 reason:baseURL is operator-controlled (ARMIS_API_URL) or the hardcoded RegionalBaseURL allowlist, never attacker-reachable input; this block IS the SSRF guard (rejects non-HTTPS non-localhost hosts) - if parsedURL.Scheme != "https" { + if parsedURL.Scheme != schemeHTTPS { host := parsedURL.Hostname() if host != "localhost" && host != "127.0.0.1" { return nil, fmt.Errorf("HTTPS required for non-localhost URLs") diff --git a/internal/auth/device.go b/internal/auth/device.go new file mode 100644 index 00000000..8297ca29 --- /dev/null +++ b/internal/auth/device.go @@ -0,0 +1,385 @@ +// Package auth provides authentication for the Armis API. +// This file implements the OAuth2 Device Authorization Grant (RFC 8628) client +// used by `armis-cli auth login`. The server side is the Moose OAuth2 +// authorization server (PPSC-1033), mounted at the issuer root. +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/httpclient" +) + +const ( + // DefaultDeviceClientID is the public client_id armis-cli identifies as in + // the device flow. The CLI is a public client (no secret); security comes + // from the device_code and refresh-token rotation, not this identifier. + DefaultDeviceClientID = "armis-cli" + + // Grant types (RFC 8628 §3.4 / RFC 6749 §6). + grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" + grantTypeRefreshToken = "refresh_token" + + // deviceEndpointPath / tokenEndpointPath are root-mounted on the issuer per + // RFC 8628 / the backend router (api_controller/oauth2/router.py). + deviceEndpointPath = "/oauth2/device" + tokenEndpointPath = "/oauth2/token" // #nosec G101 -- URL path, not a credential + + // Polling guardrails so a misbehaving server cannot make us hammer it. + minPollInterval = 1 * time.Second + defaultPollInterval = 5 * time.Second + maxPollInterval = 60 * time.Second +) + +// DeviceAuthorization is the RFC 8628 §3.2 device authorization response. +type DeviceAuthorization struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// tokenResponse mirrors the backend TokenResponse (RFC 6749 §5.1). +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope,omitempty"` +} + +// oauthErrorResponse is the RFC 6749 §5.2 / RFC 8628 §3.5 error body. +type oauthErrorResponse struct { + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// OAuthError is a typed OAuth2 protocol error so callers can branch on the code +// (e.g. authorization_pending vs. expired_token). +type OAuthError struct { + Code string + Description string + StatusCode int +} + +func (e *OAuthError) Error() string { + if e.Description != "" { + return fmt.Sprintf("%s: %s", e.Code, e.Description) + } + return e.Code +} + +// OAuth2 error codes we branch on (RFC 8628 §3.5). +const ( + errAuthorizationPending = "authorization_pending" + errSlowDown = "slow_down" + errExpiredToken = "expired_token" + errAccessDenied = "access_denied" + errInvalidGrant = "invalid_grant" +) + +// DeviceClient talks to the OAuth2 device + token endpoints on the issuer. +type DeviceClient struct { + baseURL string + httpClient *http.Client + debug bool +} + +// NewDeviceClient creates a device-flow client for the given issuer base URL. +// HTTPS is enforced for non-localhost hosts and redirects are disabled, matching +// the hardening of the client-credentials AuthClient. +func NewDeviceClient(baseURL string, debug bool) (*DeviceClient, error) { + if baseURL == "" { + return nil, fmt.Errorf("API base URL is required for device authentication") + } + + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + // armis:ignore cwe:918 reason:baseURL is operator-controlled (ARMIS_API_URL) or the hardcoded RegionalBaseURL allowlist; this block IS the SSRF guard (rejects non-HTTPS non-localhost hosts) + if parsedURL.Scheme != schemeHTTPS { + host := parsedURL.Hostname() + if host != "localhost" && host != "127.0.0.1" { + return nil, fmt.Errorf("HTTPS required for non-localhost URLs") + } + } + + return &DeviceClient{ + baseURL: strings.TrimSuffix(baseURL, "/"), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + // Honor OS proxy config (WinINET/PAC), matching AuthClient. + Transport: httpclient.ProxyAwareTransport(), + // Never follow redirects: the token endpoint carries the device_code + // and refresh_token, which must not be replayed to a redirect target. + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + debug: debug, + }, nil +} + +// RequestDeviceCode performs the RFC 8628 §3.1 device authorization request. +// tenantID identifies which Armis tenant to authenticate against and is required +// by the authorization server. +func (c *DeviceClient) RequestDeviceCode(ctx context.Context, clientID, tenantID, scope string) (*DeviceAuthorization, error) { + if tenantID == "" { + return nil, fmt.Errorf("tenant_id is required to start the device authorization") + } + form := url.Values{} + form.Set("client_id", clientID) + form.Set("tenant_id", tenantID) + if scope != "" { + form.Set("scope", scope) + } + + body, status, err := c.postForm(ctx, deviceEndpointPath, form) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, c.parseError(body, status) + } + + var da DeviceAuthorization + if err := json.Unmarshal(body, &da); err != nil { + return nil, fmt.Errorf("failed to parse device authorization response: %w", err) + } + if da.DeviceCode == "" || da.UserCode == "" { + return nil, fmt.Errorf("device authorization response missing required fields") + } + return &da, nil +} + +// PollToken polls the token endpoint until the user approves, the device code +// expires, or the request is denied (RFC 8628 §3.4/§3.5). It honors the server's +// interval and backs off on slow_down. The provided context bounds the total +// wait (callers should set a deadline ~ the device code's expires_in). +func (c *DeviceClient) PollToken(ctx context.Context, deviceCode, clientID string, intervalSeconds int) (*StoredToken, error) { + interval := time.Duration(intervalSeconds) * time.Second + if interval < minPollInterval { + interval = defaultPollInterval + } + if interval > maxPollInterval { + interval = maxPollInterval + } + + for { + // Wait first: the spec requires waiting `interval` between polls, and the + // authorization is never approved instantly anyway. + select { + case <-ctx.Done(): + return nil, fmt.Errorf("timed out waiting for authorization: %w", ctx.Err()) + case <-time.After(interval): + } + + tok, err := c.exchangeDeviceCode(ctx, deviceCode, clientID) + if err == nil { + return tok, nil + } + + var oerr *OAuthError + if !asOAuthError(err, &oerr) { + return nil, err // transport / parse error — give up + } + switch oerr.Code { + case errAuthorizationPending: + continue + case errSlowDown: + interval += 5 * time.Second + if interval > maxPollInterval { + interval = maxPollInterval + } + continue + case errExpiredToken: + return nil, fmt.Errorf("the login request expired before it was approved; run 'armis-cli auth login' again") + case errAccessDenied: + return nil, fmt.Errorf("the login request was denied") + default: + return nil, oerr + } + } +} + +// exchangeDeviceCode does a single device_code token exchange. +func (c *DeviceClient) exchangeDeviceCode(ctx context.Context, deviceCode, clientID string) (*StoredToken, error) { + form := url.Values{} + form.Set("grant_type", grantTypeDeviceCode) + form.Set("device_code", deviceCode) + form.Set("client_id", clientID) + return c.tokenRequest(ctx, form, clientID) +} + +// Refresh exchanges a refresh token for a fresh access/refresh token pair +// (RFC 6749 §6). The backend rotates the refresh token, so the returned +// StoredToken carries a new RefreshToken that callers must persist. +func (c *DeviceClient) Refresh(ctx context.Context, refreshToken, clientID string) (*StoredToken, error) { + form := url.Values{} + form.Set("grant_type", grantTypeRefreshToken) + form.Set("refresh_token", refreshToken) + if clientID != "" { + form.Set("client_id", clientID) + } + return c.tokenRequest(ctx, form, clientID) +} + +// tokenRequest posts to the token endpoint and converts a success response into +// a StoredToken, deriving identity fields from the access-token claims. +func (c *DeviceClient) tokenRequest(ctx context.Context, form url.Values, clientID string) (*StoredToken, error) { + body, status, err := c.postForm(ctx, tokenEndpointPath, form) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, c.parseError(body, status) + } + + var tr tokenResponse + if err := json.Unmarshal(body, &tr); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + if tr.AccessToken == "" { + return nil, fmt.Errorf("token response missing access_token") + } + + claims, err := parseAccessTokenClaims(tr.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed to parse access token: %w", err) + } + + // Prefer the server-provided expires_in; fall back to the token's exp claim. + expiresAt := claims.ExpiresAt + if tr.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tr.ExpiresIn) * time.Second) + } + + return &StoredToken{ + AccessToken: tr.AccessToken, + RefreshToken: tr.RefreshToken, + ExpiresAt: expiresAt, + TenantID: claims.TenantID, + Subject: claims.Subject, + Role: claims.Role, + Issuer: claims.Issuer, + Region: claims.Region, + ClientID: clientID, + }, nil +} + +// postForm issues a form-encoded POST and returns the body and status code. +func (c *DeviceClient) postForm(ctx context.Context, path string, form url.Values) ([]byte, int, error) { + endpoint := c.baseURL + path + // armis:ignore cwe:918 reason:baseURL validated by NewDeviceClient (HTTPS enforced for non-localhost); path is a hardcoded constant + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) //nolint:gosec // endpoint built from validated config, not user input + if err != nil { + return nil, 0, fmt.Errorf("request failed: %w", annotateTransportError(err)) + } + defer resp.Body.Close() //nolint:errcheck // response body read-only + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, 0, fmt.Errorf("failed to read response: %w", err) + } + return body, resp.StatusCode, nil +} + +// parseError converts an OAuth2 error body into a typed *OAuthError. When the +// body is not the expected JSON shape it falls back to a status-based message. +func (c *DeviceClient) parseError(body []byte, status int) error { + var oe oauthErrorResponse + if err := json.Unmarshal(body, &oe); err == nil && oe.ErrorCode != "" { + return &OAuthError{Code: oe.ErrorCode, Description: oe.ErrorDescription, StatusCode: status} + } + return &OAuthError{Code: "server_error", Description: fmt.Sprintf("unexpected response (status %d)", status), StatusCode: status} +} + +// asOAuthError is errors.As specialized for *OAuthError. +func asOAuthError(err error, target **OAuthError) bool { + for err != nil { + if oe, ok := err.(*OAuthError); ok { //nolint:errorlint // direct type assert is intentional here + *target = oe + return true + } + type unwrapper interface{ Unwrap() error } + u, ok := err.(unwrapper) + if !ok { + return false + } + err = u.Unwrap() + } + return false +} + +// accessTokenClaims are the Moose RS256 access-token claims (token_issuer.py). +// This is distinct from jwtClaims (client-credentials path), which reads the +// VIPR customer_id claim; the device-flow token uses tenant_id/sub/role. +type accessTokenClaims struct { + TenantID string + Subject string + Role string + Issuer string + Region string + ExpiresAt time.Time +} + +// parseAccessTokenClaims decodes (without verifying) the JWT payload. Signature +// verification is delegated to the backend, which validates every API request; +// the CLI only needs the claims for local display and refresh scheduling. +// +// armis:ignore cwe:287 reason:JWT signature verification delegated to server; CLI only extracts claims for caching/display +// armis:ignore cwe:327 reason:no cryptographic operations; base64-decodes JWT payload for claim extraction only +func parseAccessTokenClaims(token string) (*accessTokenClaims, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode JWT payload: %w", err) + } + + var data struct { + TenantID string `json:"tenant_id"` + Sub string `json:"sub"` + Role string `json:"role"` + Iss string `json:"iss"` + Region string `json:"region"` + Exp float64 `json:"exp"` // float64 tolerates fractional timestamps + } + if err := json.Unmarshal(payload, &data); err != nil { + return nil, fmt.Errorf("failed to parse JWT payload: %w", err) + } + + var expiresAt time.Time + if data.Exp > 0 { + expiresAt = time.Unix(int64(data.Exp), 0) + } + return &accessTokenClaims{ + TenantID: data.TenantID, + Subject: data.Sub, + Role: data.Role, + Issuer: data.Iss, + Region: data.Region, + ExpiresAt: expiresAt, + }, nil +} diff --git a/internal/auth/device_test.go b/internal/auth/device_test.go new file mode 100644 index 00000000..54cfb54b --- /dev/null +++ b/internal/auth/device_test.go @@ -0,0 +1,211 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// makeDeviceJWT builds an unsigned JWT carrying the device-flow access-token +// claims (tenant_id/sub/role), distinct from the client-credentials customer_id. +func makeDeviceJWT(tenantID, sub, role string, exp int64) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + claims := map[string]any{ + "tenant_id": tenantID, + "sub": sub, + "role": role, + "iss": "https://moose.armis.com", + "exp": exp, + } + cj, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(cj) + sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) + return header + "." + payload + "." + sig +} + +func TestRequestDeviceCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth2/device" { + w.WriteHeader(http.StatusNotFound) + return + } + _ = r.ParseForm() //nolint:gosec // G120: test server; request body is a tiny fixed form + if r.Form.Get("client_id") != "armis-cli" { + t.Errorf("client_id = %q", r.Form.Get("client_id")) + } + if r.Form.Get("tenant_id") != "tenant-1" { + t.Errorf("tenant_id = %q, want tenant-1", r.Form.Get("tenant_id")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "WDJB-MJHT", + "verification_uri": "https://moose.armis.com/oauth2/device/verify", + "verification_uri_complete": "https://moose.armis.com/oauth2/device/verify?user_code=WDJB-MJHT", + "expires_in": 900, + "interval": 5, + }) + })) + defer srv.Close() + + c, err := NewDeviceClient(srv.URL, false) + if err != nil { + t.Fatal(err) + } + da, err := c.RequestDeviceCode(context.Background(), "armis-cli", "tenant-1", "") + if err != nil { + t.Fatalf("RequestDeviceCode: %v", err) + } + if da.DeviceCode != "dev-code" || da.UserCode != "WDJB-MJHT" || da.Interval != 5 { + t.Errorf("unexpected device authorization: %+v", da) + } +} + +func TestRequestDeviceCodeRequiresTenant(t *testing.T) { + c, err := NewDeviceClient("https://moose.armis.com", false) + if err != nil { + t.Fatal(err) + } + if _, err := c.RequestDeviceCode(context.Background(), "armis-cli", "", ""); err == nil { + t.Fatal("expected error when tenant_id is empty") + } +} + +func TestPollTokenPendingThenSuccess(t *testing.T) { + var mu sync.Mutex + calls := 0 + jwt := makeDeviceJWT("tenant-1", "user@example.com", "admin", time.Now().Add(time.Hour).Unix()) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + calls++ + n := calls + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + if n < 2 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": errAuthorizationPending}) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": jwt, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refresh-1", + }) + })) + defer srv.Close() + + c, err := NewDeviceClient(srv.URL, false) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // interval=0 is clamped to the default; override to keep the test fast by + // using a 1s minimum via a tiny interval value. + tok, err := c.PollToken(ctx, "dev-code", "armis-cli", 1) + if err != nil { + t.Fatalf("PollToken: %v", err) + } + if tok.AccessToken != jwt || tok.RefreshToken != "refresh-1" { + t.Errorf("unexpected tokens: %+v", tok) + } + if tok.TenantID != "tenant-1" || tok.Subject != "user@example.com" || tok.Role != "admin" { + t.Errorf("claims not parsed: %+v", tok) + } +} + +func TestPollTokenAccessDenied(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": errAccessDenied}) + })) + defer srv.Close() + + c, _ := NewDeviceClient(srv.URL, false) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.PollToken(ctx, "dev-code", "armis-cli", 1) + if err == nil { + t.Fatal("expected denial error") + } +} + +func TestPollTokenExpired(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": errExpiredToken}) + })) + defer srv.Close() + + c, _ := NewDeviceClient(srv.URL, false) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := c.PollToken(ctx, "dev-code", "armis-cli", 1) + if err == nil { + t.Fatal("expected expiry error") + } +} + +func TestRefresh(t *testing.T) { + jwt := makeDeviceJWT("tenant-9", "svc", "viewer", time.Now().Add(time.Hour).Unix()) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() //nolint:gosec // G120: test server; request body is a tiny fixed form + if r.Form.Get("grant_type") != grantTypeRefreshToken { + t.Errorf("grant_type = %q", r.Form.Get("grant_type")) + } + if r.Form.Get("refresh_token") != "old-refresh" { + t.Errorf("refresh_token = %q", r.Form.Get("refresh_token")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": jwt, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh", + }) + })) + defer srv.Close() + + c, _ := NewDeviceClient(srv.URL, false) + tok, err := c.Refresh(context.Background(), "old-refresh", "armis-cli") + if err != nil { + t.Fatalf("Refresh: %v", err) + } + if tok.RefreshToken != "new-refresh" || tok.TenantID != "tenant-9" { + t.Errorf("unexpected refreshed token: %+v", tok) + } +} + +func TestNewDeviceClientRejectsHTTP(t *testing.T) { + if _, err := NewDeviceClient("http://moose.armis.com", false); err == nil { + t.Fatal("expected HTTPS enforcement error") + } + // localhost http is allowed (tests / dev). + if _, err := NewDeviceClient("http://localhost:8080", false); err != nil { + t.Errorf("localhost http should be allowed: %v", err) + } +} + +func TestParseAccessTokenClaims(t *testing.T) { + jwt := makeDeviceJWT("t", "s", "r", 1700000000) + claims, err := parseAccessTokenClaims(jwt) + if err != nil { + t.Fatalf("parse: %v", err) + } + if claims.TenantID != "t" || claims.Subject != "s" || claims.Role != "r" { + t.Errorf("unexpected claims: %+v", claims) + } + if claims.ExpiresAt.Unix() != 1700000000 { + t.Errorf("exp = %v", claims.ExpiresAt.Unix()) + } +} diff --git a/internal/auth/oauth_provider_test.go b/internal/auth/oauth_provider_test.go new file mode 100644 index 00000000..aaccec01 --- /dev/null +++ b/internal/auth/oauth_provider_test.go @@ -0,0 +1,115 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// newOAuthTestProvider wires an SSO-mode provider against a mock token endpoint +// and an in-memory token store. +func newOAuthTestProvider(t *testing.T, srvURL string, stored *StoredToken) (*AuthProvider, *TokenStore) { + t.Helper() + store := &TokenStore{dir: t.TempDir()} + if err := store.Save(srvURL, stored); err != nil { + t.Fatalf("seed store: %v", err) + } + dc, err := NewDeviceClient(srvURL, false) + if err != nil { + t.Fatalf("device client: %v", err) + } + p, err := NewProviderFromStored(store, dc, srvURL, stored) + if err != nil { + t.Fatalf("provider: %v", err) + } + return p, store +} + +func TestOAuthProviderUsesStoredTokenWithoutRefresh(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("token endpoint should not be called when token is fresh") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + stored := sampleToken() + stored.AccessToken = makeDeviceJWT("tenant-1", "u", "admin", time.Now().Add(time.Hour).Unix()) + stored.ExpiresAt = time.Now().Add(time.Hour) + + p, _ := newOAuthTestProvider(t, srv.URL, stored) + + hdr, err := p.GetAuthorizationHeader(context.Background()) + if err != nil { + t.Fatalf("GetAuthorizationHeader: %v", err) + } + if hdr != "Bearer "+stored.AccessToken { + t.Errorf("unexpected header: %q", hdr) + } + if p.AuthMethod() != AuthMethodSSO { + t.Errorf("AuthMethod = %q, want sso", p.AuthMethod()) + } +} + +func TestOAuthProviderRefreshesNearExpiry(t *testing.T) { + newJWT := makeDeviceJWT("tenant-1", "u", "admin", time.Now().Add(time.Hour).Unix()) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() //nolint:gosec // G120: test server; request body is a tiny fixed form + if r.Form.Get("grant_type") != grantTypeRefreshToken { + t.Errorf("expected refresh grant, got %q", r.Form.Get("grant_type")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": newJWT, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "rotated-refresh", + }) + })) + defer srv.Close() + + stored := sampleToken() + stored.AccessToken = "stale" + stored.ExpiresAt = time.Now().Add(1 * time.Minute) // within the 5-min window + + p, store := newOAuthTestProvider(t, srv.URL, stored) + + hdr, err := p.GetAuthorizationHeader(context.Background()) + if err != nil { + t.Fatalf("GetAuthorizationHeader: %v", err) + } + if hdr != "Bearer "+newJWT { + t.Errorf("expected refreshed token in header, got %q", hdr) + } + + // The rotated refresh token must be persisted for the next process. + reloaded, _ := store.Load(srv.URL) + if reloaded == nil || reloaded.RefreshToken != "rotated-refresh" { + t.Errorf("rotated refresh token not persisted: %+v", reloaded) + } +} + +func TestOAuthProviderRefreshFailureSurfacesReloginHint(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": errInvalidGrant}) + })) + defer srv.Close() + + stored := sampleToken() + stored.ExpiresAt = time.Now().Add(1 * time.Minute) + + p, _ := newOAuthTestProvider(t, srv.URL, stored) + + _, err := p.GetTenantID(context.Background()) + if err == nil { + t.Fatal("expected error on failed refresh") + } + if want := "auth login"; !strings.Contains(err.Error(), want) { + t.Errorf("error %q should mention %q", err.Error(), want) + } +} diff --git a/internal/auth/tokenstore.go b/internal/auth/tokenstore.go new file mode 100644 index 00000000..15af5942 --- /dev/null +++ b/internal/auth/tokenstore.go @@ -0,0 +1,266 @@ +// Package auth provides authentication for the Armis API. +// This file persists OAuth2 (device-flow) tokens so they survive across +// invocations and can be shared with other Armis tools (the MCP plugins). +package auth + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +// --------------------------------------------------------------------------- +// CROSS-PROCESS CONTRACT — DO NOT CHANGE THE PATH OR JSON SCHEMA CASUALLY. +// +// The token file and its JSON schema are a wire contract shared with other +// Armis developer tools (the armis-appsec / armis-knowledge MCP plugins, per +// epic PPSC-1032). Those tools read and write the SAME file so a single +// `armis-cli auth login` keeps every tool authenticated. +// +// A plain file (not the OS keychain) is the deliberate choice: the MCP plugins +// are Python, and the refresh-token rotation + reuse-detection on the backend +// requires a SINGLE source of truth (a divergent second store would replay a +// rotated token and get the whole token family revoked). The file is 0600 in a +// 0700 directory; protection at rest relies on the OS account + full-disk +// encryption (FileVault/BitLocker/LUKS), matching the AWS/gcloud/kubectl model. +// +// FILE SHAPE — a JSON array of per-environment entries, so one dev machine can +// hold tokens for several Armis environments at once (prod, dev, a local stack): +// +// [ +// {"env": "https://moose.armis.com", "token": { ...StoredToken... }}, +// {"env": "http://localhost:8001", "token": { ...StoredToken... }} +// ] +// +// `env` is the API base URL the token was obtained from (the lookup key). +// Python equivalent of the path: Path.home() / ".armis" / ".sessions". +// --------------------------------------------------------------------------- +const ( + // tokenStoreDirName is the per-user Armis config directory (~/.armis). + tokenStoreDirName = ".armis" + // tokenStoreFileName is the token file within that directory. + tokenStoreFileName = ".sessions" // #nosec G101 -- filename, not a credential + // tokenSchemaVersion versions the StoredToken JSON so future changes can be + // detected by older readers rather than mis-parsed. + tokenSchemaVersion = 1 + // maxTokenFileSize bounds reads to guard against a corrupted or maliciously + // large file exhausting memory. Generous to accommodate many environments. + maxTokenFileSize = 1 << 20 // 1MB +) + +// StoredToken is the persisted result of a device-flow login. Its JSON shape is +// the cross-process contract described above; add fields rather than renaming. +type StoredToken struct { + SchemaVersion int `json:"schema_version"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresAt time.Time `json:"expires_at"` + TenantID string `json:"tenant_id"` + Subject string `json:"subject"` + Role string `json:"role"` + Issuer string `json:"issuer,omitempty"` + Region string `json:"region,omitempty"` + ClientID string `json:"client_id,omitempty"` +} + +// tokenEntry is one environment's token within the file array. +type tokenEntry struct { + Env string `json:"env"` + Token *StoredToken `json:"token"` +} + +// TokenStore persists OAuth tokens to a 0600 file under ~/.armis, keyed by the +// environment (API base URL) each token belongs to. +type TokenStore struct { + // dir overrides the directory holding the token file (tests only). Empty + // means ~/.armis. + dir string +} + +// NewTokenStore returns a TokenStore backed by the per-user ~/.armis directory. +func NewTokenStore() *TokenStore { + return &TokenStore{} +} + +// normalizeEnv canonicalizes an environment key so trivially different spellings +// (a trailing slash, surrounding whitespace) resolve to the same entry. +func normalizeEnv(env string) string { + return strings.TrimRight(strings.TrimSpace(env), "/") +} + +// Save inserts or replaces the token for the given environment. +func (s *TokenStore) Save(env string, tok *StoredToken) error { + if tok == nil { + return errors.New("nil token") + } + if env == "" { + return errors.New("env is required to store a token") + } + tok.SchemaVersion = tokenSchemaVersion + env = normalizeEnv(env) + + entries, err := s.read() + if err != nil { + return err + } + + replaced := false + for i := range entries { + if normalizeEnv(entries[i].Env) == env { + entries[i].Token = tok + replaced = true + break + } + } + if !replaced { + entries = append(entries, tokenEntry{Env: env, Token: tok}) + } + return s.write(entries) +} + +// Load returns the stored token for the given environment, or (nil, nil) when +// none is present. A corrupted or oversized file is treated as "no token" so a +// bad file never breaks credential resolution — callers fall through to env vars. +func (s *TokenStore) Load(env string) (*StoredToken, error) { + env = normalizeEnv(env) + entries, err := s.read() + if err != nil { + return nil, nil //nolint:nilerr // unreadable/corrupted file treated as absent + } + for i := range entries { + if normalizeEnv(entries[i].Env) == env { + tok := entries[i].Token + if tok == nil || (tok.AccessToken == "" && tok.RefreshToken == "") { + return nil, nil + } + return tok, nil + } + } + return nil, nil +} + +// Clear removes the token for the given environment. It is idempotent. When the +// last entry is removed the file itself is deleted. +func (s *TokenStore) Clear(env string) error { + env = normalizeEnv(env) + entries, err := s.read() + if err != nil { + return nil //nolint:nilerr // nothing usable to clear + } + kept := entries[:0] + for _, e := range entries { + if normalizeEnv(e.Env) != env { + kept = append(kept, e) + } + } + if len(kept) == 0 { + return s.remove() + } + return s.write(kept) +} + +// ClearAll removes every stored token by deleting the file. +func (s *TokenStore) ClearAll() error { + return s.remove() +} + +// Environments lists the environments that currently have a stored token. +func (s *TokenStore) Environments() []string { + entries, err := s.read() + if err != nil { + return nil + } + envs := make([]string, 0, len(entries)) + for _, e := range entries { + envs = append(envs, e.Env) + } + return envs +} + +// Path returns the resolved token-file path (for diagnostics / logout output). +func (s *TokenStore) Path() string { + path, _ := s.filePath() + return path +} + +// read loads and parses the token file. A missing file yields an empty slice; +// a corrupted/oversized file yields an error so callers can decide how to react +// (Load/Clear treat it as absent rather than failing the CLI). +func (s *TokenStore) read() ([]tokenEntry, error) { + path, err := s.filePath() + if err != nil { + return nil, err + } + // armis:ignore cwe:367 reason:stat-then-read race is benign; worst case reads a stale token, no security impact + info, statErr := os.Stat(path) + if statErr != nil { + if os.IsNotExist(statErr) { + return nil, nil + } + return nil, statErr + } + if info.Size() > maxTokenFileSize { + return nil, fmt.Errorf("token file %s exceeds %d bytes", path, maxTokenFileSize) + } + data, err := os.ReadFile(path) //nolint:gosec // path derived from os.UserHomeDir + hardcoded segments + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, nil + } + var entries []tokenEntry + if err := json.Unmarshal(data, &entries); err != nil { + return nil, fmt.Errorf("token file is not valid JSON: %w", err) + } + return entries, nil +} + +// write persists the entries to the 0600 file, creating ~/.armis (0700) if needed. +func (s *TokenStore) write(entries []tokenEntry) error { + path, err := s.filePath() + if err != nil { + return err + } + data, err := json.MarshalIndent(entries, "", " ") //nolint:gosec // G117: persisting the token blob to its file IS the purpose of this store + if err != nil { + return fmt.Errorf("failed to marshal tokens: %w", err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("failed to create token directory: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { //nolint:gosec // path derived from os.UserHomeDir + hardcoded segments + return fmt.Errorf("failed to write token file: %w", err) + } + return nil +} + +func (s *TokenStore) remove() error { + path, err := s.filePath() + if err != nil { + return nil //nolint:nilerr // nothing to remove if no path resolves + } + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +// filePath resolves the token file path: /.sessions, where dir is +// the test override or ~/.armis. +func (s *TokenStore) filePath() (string, error) { + dir := s.dir + if dir == "" { + // armis:ignore cwe:22 reason:os.UserHomeDir is a trusted OS source; joined with hardcoded path segments + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("cannot determine home directory: %w", err) + } + dir = filepath.Join(home, tokenStoreDirName) + } + return filepath.Join(dir, tokenStoreFileName), nil +} diff --git a/internal/auth/tokenstore_test.go b/internal/auth/tokenstore_test.go new file mode 100644 index 00000000..61050246 --- /dev/null +++ b/internal/auth/tokenstore_test.go @@ -0,0 +1,193 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func sampleToken() *StoredToken { + return &StoredToken{ + AccessToken: "access-abc", + RefreshToken: "refresh-xyz", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + TenantID: "tenant-1", + Subject: "user@example.com", + Role: "admin", + Issuer: "https://moose.armis.com", + Region: "us1", + ClientID: "armis-cli", + } +} + +const ( + envProd = "https://moose.armis.com" + envDev = "http://localhost:8001" +) + +func TestTokenStoreRoundTrip(t *testing.T) { + dir := t.TempDir() + store := &TokenStore{dir: dir} + + want := sampleToken() + if err := store.Save(envProd, want); err != nil { + t.Fatalf("Save: %v", err) + } + + // File should exist with 0600 perms in a 0700 dir. + path := filepath.Join(dir, tokenStoreFileName) + info, err := os.Stat(path) + if err != nil { + t.Fatalf("expected token file: %v", err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("token file perm = %o, want 600", perm) + } + + got, err := store.Load(envProd) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got == nil { + t.Fatal("Load returned nil token") + } + if got.AccessToken != want.AccessToken || got.RefreshToken != want.RefreshToken || + got.TenantID != want.TenantID || got.Subject != want.Subject { + t.Errorf("round-trip mismatch: got %+v want %+v", got, want) + } + if got.SchemaVersion != tokenSchemaVersion { + t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, tokenSchemaVersion) + } +} + +// TestTokenStoreMultipleEnvironments is the core of the env-scoped design: two +// environments coexist, are read back independently, and clearing one leaves +// the other intact. +func TestTokenStoreMultipleEnvironments(t *testing.T) { + store := &TokenStore{dir: t.TempDir()} + + prod := sampleToken() + prod.TenantID = "tenant-prod" + dev := sampleToken() + dev.TenantID = "tenant-dev" + dev.AccessToken = "dev-access" + + if err := store.Save(envProd, prod); err != nil { + t.Fatalf("Save prod: %v", err) + } + if err := store.Save(envDev, dev); err != nil { + t.Fatalf("Save dev: %v", err) + } + + gotProd, _ := store.Load(envProd) + gotDev, _ := store.Load(envDev) + if gotProd == nil || gotProd.TenantID != "tenant-prod" { + t.Errorf("prod token wrong: %+v", gotProd) + } + if gotDev == nil || gotDev.TenantID != "tenant-dev" || gotDev.AccessToken != "dev-access" { + t.Errorf("dev token wrong: %+v", gotDev) + } + + if envs := store.Environments(); len(envs) != 2 { + t.Errorf("Environments() = %v, want 2 entries", envs) + } + + // Clearing dev must not disturb prod. + if err := store.Clear(envDev); err != nil { + t.Fatalf("Clear dev: %v", err) + } + if got, _ := store.Load(envDev); got != nil { + t.Errorf("dev should be cleared, got %+v", got) + } + if got, _ := store.Load(envProd); got == nil { + t.Error("prod should survive clearing dev") + } +} + +// TestTokenStoreEnvNormalization: a trailing slash must resolve to the same entry. +func TestTokenStoreEnvNormalization(t *testing.T) { + store := &TokenStore{dir: t.TempDir()} + if err := store.Save(envProd+"/", sampleToken()); err != nil { + t.Fatalf("Save: %v", err) + } + got, _ := store.Load(envProd) + if got == nil { + t.Error("trailing-slash env should resolve to the same entry") + } + // Saving the same env (no slash) replaces, not duplicates. + if err := store.Save(envProd, sampleToken()); err != nil { + t.Fatalf("Save again: %v", err) + } + if envs := store.Environments(); len(envs) != 1 { + t.Errorf("expected 1 entry after re-save, got %v", envs) + } +} + +func TestTokenStoreSaveReplaces(t *testing.T) { + store := &TokenStore{dir: t.TempDir()} + first := sampleToken() + first.AccessToken = "first" + second := sampleToken() + second.AccessToken = "second" + + _ = store.Save(envProd, first) + _ = store.Save(envProd, second) + + got, _ := store.Load(envProd) + if got == nil || got.AccessToken != "second" { + t.Errorf("expected replacement, got %+v", got) + } + if envs := store.Environments(); len(envs) != 1 { + t.Errorf("expected 1 entry, got %v", envs) + } +} + +func TestTokenStoreLoadEmpty(t *testing.T) { + store := &TokenStore{dir: t.TempDir()} + got, err := store.Load(envProd) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got != nil { + t.Errorf("expected nil token, got %+v", got) + } +} + +func TestTokenStoreClearLastRemovesFile(t *testing.T) { + dir := t.TempDir() + store := &TokenStore{dir: dir} + if err := store.Save(envProd, sampleToken()); err != nil { + t.Fatalf("Save: %v", err) + } + if err := store.Clear(envProd); err != nil { + t.Fatalf("Clear: %v", err) + } + got, _ := store.Load(envProd) + if got != nil { + t.Errorf("expected nil after Clear, got %+v", got) + } + // The file should be gone once the last entry is cleared. + if _, err := os.Stat(filepath.Join(dir, tokenStoreFileName)); !os.IsNotExist(err) { + t.Errorf("expected file removed after clearing last entry, stat err = %v", err) + } + // Clear is idempotent. + if err := store.Clear(envProd); err != nil { + t.Errorf("second Clear errored: %v", err) + } +} + +func TestTokenStoreCorruptedFileTreatedAsAbsent(t *testing.T) { + dir := t.TempDir() + store := &TokenStore{dir: dir} + if err := os.WriteFile(filepath.Join(dir, tokenStoreFileName), []byte("{not json"), 0o600); err != nil { + t.Fatal(err) + } + got, err := store.Load(envProd) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got != nil { + t.Errorf("expected nil for corrupted file, got %+v", got) + } +} diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index a90c8a54..4c19dbc0 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -8,43 +8,54 @@ import ( "github.com/spf13/cobra" ) +// authCmd is the parent group for authentication commands. var authCmd = &cobra.Command{ Use: "auth", - Short: "Authenticate and print JWT token", - Long: `Exchange client credentials for a JWT token and print it to stdout. + Short: "Manage authentication with Armis Cloud", + Long: `Authenticate the CLI with Armis Cloud. + +The recommended path is browser-based SSO: + + armis-cli auth login Sign in via your browser (OAuth2 Device Authorization) + armis-cli auth whoami Show the current identity, tenant, and token expiry + armis-cli auth logout Remove stored credentials + +For CI/CD and service accounts, set ARMIS_CLIENT_ID / ARMIS_CLIENT_SECRET +(client-credentials) instead of logging in interactively.`, +} + +// authTokenCmd preserves the original `armis-cli auth` behavior (print a raw +// JWT obtained via client credentials) as a hidden `auth token` subcommand, for +// testing auth configuration and piping tokens to other tools. +var authTokenCmd = &cobra.Command{ + Use: "token", + Short: "Print a JWT access token to stdout", + Long: `Exchange credentials for an access token and print it to stdout. + +Uses the active credentials in resolution order: a stored SSO session +(from 'armis-cli auth login'), then client credentials (--client-id / +--client-secret or ARMIS_CLIENT_ID / ARMIS_CLIENT_SECRET). This command is useful for: - Testing authentication configuration - Obtaining tokens for use with other tools -- Debugging JWT-related issues - -Requires --client-id and --client-secret flags or their corresponding -environment variables (ARMIS_CLIENT_ID, ARMIS_CLIENT_SECRET).`, - Example: ` # Obtain JWT token using flags - armis-cli auth --client-id MY_ID --client-secret MY_SECRET - - # Obtain token using environment variables +- Debugging token-related issues`, + Example: ` # Print a token using environment variables export ARMIS_CLIENT_ID=MY_ID export ARMIS_CLIENT_SECRET=MY_SECRET - armis-cli auth`, + armis-cli auth token`, RunE: runAuth, } func init() { - // Hide auth command until backend JWT support is available - authCmd.Hidden = true + // `auth token` stays hidden: it's a debug/scripting helper that prints a raw + // token, not part of the user-facing login/logout/whoami surface. + authTokenCmd.Hidden = true + authCmd.AddCommand(authTokenCmd) rootCmd.AddCommand(authCmd) } -func runAuth(cmd *cobra.Command, args []string) error { - // Validate required flags for JWT auth - if clientID == "" { - return fmt.Errorf("--client-id is required (or set ARMIS_CLIENT_ID)") - } - if clientSecret == "" { - return fmt.Errorf("--client-secret is required (or set ARMIS_CLIENT_SECRET)") - } - +func runAuth(cmd *cobra.Command, _ []string) error { provider, err := getAuthProvider() if err != nil { return fmt.Errorf("authentication failed: %w", err) @@ -58,8 +69,8 @@ func runAuth(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to get token: %w", err) } - // Print the raw token without any prefix (useful for piping to other tools) - // armis:ignore cwe:522 reason:auth command's purpose is to output the token for piping to other tools + // Print the raw token without any prefix (useful for piping to other tools). + // armis:ignore cwe:522 reason:auth token command's purpose is to output the token for piping to other tools fmt.Println(token) return nil } diff --git a/internal/cmd/auth_flow_test.go b/internal/cmd/auth_flow_test.go new file mode 100644 index 00000000..ba36ae9b --- /dev/null +++ b/internal/cmd/auth_flow_test.go @@ -0,0 +1,271 @@ +package cmd + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/auth" + "github.com/spf13/cobra" +) + +// tenant7 is the tenant ID used across the login-flow tests. +const tenant7 = "tenant-7" + +// deviceJWT builds an unsigned access token with the device-flow claims. +func deviceJWT(tenant, sub, role string, exp int64) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + claims, _ := json.Marshal(map[string]any{ + "tenant_id": tenant, "sub": sub, "role": role, + "iss": "https://moose.armis.com", "exp": exp, + }) + return header + "." + base64.RawURLEncoding.EncodeToString(claims) + "." + + base64.RawURLEncoding.EncodeToString([]byte("sig")) +} + +// setupAuthTest isolates global state: it points ARMIS_API_URL at the mock +// server, redirects HOME to a temp dir so the token store writes there instead +// of the real ~/.armis, and clears credential globals. The browser opener is +// stubbed once in TestMain. Everything is restored via t.Cleanup / t.Setenv. +func setupAuthTest(t *testing.T, serverURL string) { + t.Helper() + t.Setenv("HOME", t.TempDir()) // token store resolves ~/.armis from HOME + + origGlobals := struct{ clientID, clientSecret, token, tenantID string }{ + clientID, clientSecret, token, tenantID, + } + t.Cleanup(func() { + clientID = origGlobals.clientID + clientSecret = origGlobals.clientSecret + token = origGlobals.token + tenantID = origGlobals.tenantID + credFlagsExplicit = false + noProgress = false + loginOrg = "" + loginClientID = auth.DefaultDeviceClientID + logoutAll = false + }) + clientID, clientSecret, token, tenantID = "", "", "", "" + credFlagsExplicit = false + noProgress = true + loginClientID = auth.DefaultDeviceClientID + + t.Setenv("ARMIS_API_URL", serverURL) +} + +func newCmdWithCtx() *cobra.Command { + c := &cobra.Command{} + c.SetContext(context.Background()) + return c +} + +func TestAuthLoginStoresTokens(t *testing.T) { + jwt := deviceJWT(tenant7, "alice@example.com", "admin", time.Now().Add(time.Hour).Unix()) + var mu sync.Mutex + tokenCalls := 0 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/oauth2/device": + _ = r.ParseForm() //nolint:gosec // G120: test server; request body is a tiny fixed form + if r.Form.Get("tenant_id") != tenant7 { + t.Errorf("device request tenant_id = %q, want %s", r.Form.Get("tenant_id"), tenant7) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "WDJB-MJHT", + "verification_uri": "https://moose.armis.com/oauth2/device/verify", + "verification_uri_complete": "https://moose.armis.com/oauth2/device/verify?user_code=WDJB-MJHT", + "expires_in": 900, + "interval": 1, + }) + case "/oauth2/token": + mu.Lock() + tokenCalls++ + n := tokenCalls + mu.Unlock() + if n < 2 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": jwt, "token_type": "Bearer", + "expires_in": 3600, "refresh_token": "refresh-7", + }) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + setupAuthTest(t, srv.URL) + tenantID = tenant7 // device authorization requires a tenant + + if err := runAuthLogin(newCmdWithCtx(), nil); err != nil { + t.Fatalf("runAuthLogin: %v", err) + } + + stored, err := auth.NewTokenStore().Load(srv.URL) + if err != nil || stored == nil { + t.Fatalf("expected stored token, err=%v stored=%v", err, stored) + } + if stored.RefreshToken != "refresh-7" || stored.TenantID != tenant7 { + t.Errorf("unexpected stored token: %+v", stored) + } + if stored.Subject != "alice@example.com" { + t.Errorf("subject = %q", stored.Subject) + } +} + +func TestAuthLoginRequiresTenant(t *testing.T) { + setupAuthTest(t, "https://moose.armis.com") + tenantID = "" // no --tenant-id / ARMIS_TENANT_ID + + err := runAuthLogin(newCmdWithCtx(), nil) + if err == nil { + t.Fatal("expected error when tenant ID is missing") + } + if !strings.Contains(err.Error(), "tenant ID required") { + t.Errorf("error %q should mention 'tenant ID required'", err.Error()) + } +} + +func TestAuthWhoamiAfterLogin(t *testing.T) { + jwt := deviceJWT(tenant7, "alice@example.com", "admin", time.Now().Add(2*time.Hour).Unix()) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) // no refresh expected; token is fresh + })) + defer srv.Close() + setupAuthTest(t, srv.URL) + + // Seed a stored token directly. + store := auth.NewTokenStore() + if err := store.Save(srv.URL, &auth.StoredToken{ + AccessToken: jwt, RefreshToken: "r", ExpiresAt: time.Now().Add(2 * time.Hour), + TenantID: tenant7, Subject: "alice@example.com", Role: "admin", + Issuer: srv.URL, ClientID: "armis-cli", + }); err != nil { + t.Fatal(err) + } + + if err := runAuthWhoami(newCmdWithCtx(), nil); err != nil { + t.Fatalf("runAuthWhoami: %v", err) + } +} + +func TestAuthLogout(t *testing.T) { + const env = "https://moose.armis.com" + setupAuthTest(t, env) + store := auth.NewTokenStore() + if err := store.Save(env, &auth.StoredToken{AccessToken: "a", RefreshToken: "r", TenantID: "t"}); err != nil { + t.Fatal(err) + } + + if err := runAuthLogout(newCmdWithCtx(), nil); err != nil { + t.Fatalf("runAuthLogout: %v", err) + } + got, _ := store.Load(env) + if got != nil { + t.Errorf("expected token cleared, got %+v", got) + } + + // Idempotent: logging out again is not an error. + if err := runAuthLogout(newCmdWithCtx(), nil); err != nil { + t.Errorf("second logout errored: %v", err) + } +} + +// TestAuthLogoutScoping: a plain logout removes only the current environment's +// token; --all removes them all. +func TestAuthLogoutScoping(t *testing.T) { + const curEnv = "https://moose.armis.com" + const otherEnv = "http://localhost:8001" + setupAuthTest(t, curEnv) // current env resolves to curEnv via ARMIS_API_URL + + store := auth.NewTokenStore() + for _, e := range []string{curEnv, otherEnv} { + if err := store.Save(e, &auth.StoredToken{AccessToken: "a", RefreshToken: "r", TenantID: "t"}); err != nil { + t.Fatal(err) + } + } + + // Plain logout clears only curEnv. + logoutAll = false + if err := runAuthLogout(newCmdWithCtx(), nil); err != nil { + t.Fatalf("logout: %v", err) + } + if got, _ := store.Load(curEnv); got != nil { + t.Error("current env token should be cleared") + } + if got, _ := store.Load(otherEnv); got == nil { + t.Error("other env token should survive a scoped logout") + } + + // --all clears everything. + logoutAll = true + if err := runAuthLogout(newCmdWithCtx(), nil); err != nil { + t.Fatalf("logout --all: %v", err) + } + if got, _ := store.Load(otherEnv); got != nil { + t.Error("--all should clear remaining tokens") + } +} + +// TestStoredTokenTakesPrecedence verifies the resolution order: a stored SSO +// token is used over env-var client credentials when no credential flags are set. +func TestStoredTokenTakesPrecedence(t *testing.T) { + jwt := deviceJWT(tenant7, "alice@example.com", "admin", time.Now().Add(time.Hour).Unix()) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + setupAuthTest(t, srv.URL) + + store := auth.NewTokenStore() + if err := store.Save(srv.URL, &auth.StoredToken{ + AccessToken: jwt, RefreshToken: "r", ExpiresAt: time.Now().Add(time.Hour), + TenantID: tenant7, Subject: "alice@example.com", Issuer: srv.URL, + }); err != nil { + t.Fatal(err) + } + + // Env client credentials are present but should be ignored in favor of SSO. + clientID = "env-client" + clientSecret = "env-secret" + + provider, err := getAuthProvider() + if err != nil { + t.Fatalf("getAuthProvider: %v", err) + } + if provider.AuthMethod() != auth.AuthMethodSSO { + t.Errorf("AuthMethod = %q, want sso", provider.AuthMethod()) + } +} + +func TestNoCredentialsErrorMentionsLogin(t *testing.T) { + setupAuthTest(t, "https://moose.armis.com") + // No stored token, no credentials. + _, err := getAuthProvider() + if err == nil { + t.Fatal("expected error with no credentials") + } + if !strings.Contains(err.Error(), "auth login") { + t.Errorf("error %q should mention 'auth login'", err.Error()) + } +} + +func TestMain(m *testing.M) { + // Ensure no test accidentally spawns a real browser: stub the opener to a + // no-op (success) for the whole cmd test binary. + auth.SetBrowserOpener(func(string) error { return nil }) + os.Exit(m.Run()) +} diff --git a/internal/cmd/auth_login.go b/internal/cmd/auth_login.go new file mode 100644 index 00000000..992681c1 --- /dev/null +++ b/internal/cmd/auth_login.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "context" + "fmt" + "net/url" + "os" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/auth" + "github.com/ArmisSecurity/armis-cli/internal/output" + "github.com/ArmisSecurity/armis-cli/internal/progress" + "github.com/spf13/cobra" +) + +var ( + loginOrg string + loginClientID string +) + +var authLoginCmd = &cobra.Command{ + Use: "login", + Short: "Sign in to Armis Cloud via your browser", + Long: `Authenticate with Armis Cloud using browser-based SSO (OAuth2 Device +Authorization Grant, RFC 8628). + +The CLI requests a device code, opens your browser to the Armis sign-in page, +and waits while you authenticate with your corporate identity provider. On +success the access and refresh tokens are stored in a 0600 file under ~/.armis +and shared with the Armis MCP plugins. + +A tenant is required: pass --tenant-id or set ARMIS_TENANT_ID. + +If the browser cannot be opened automatically (for example over SSH), the CLI +prints a URL and a code to enter manually.`, + Example: ` # Sign in interactively + armis-cli auth login --tenant-id my-tenant + + # Skip org selection in the browser + armis-cli auth login --tenant-id my-tenant --org my-company`, + Args: cobra.NoArgs, + RunE: runAuthLogin, +} + +func init() { + authLoginCmd.Flags().StringVar(&loginOrg, "org", "", "Organization slug hint to skip org selection in the browser") + authLoginCmd.Flags().StringVar(&loginClientID, "client-id", auth.DefaultDeviceClientID, "OAuth2 client ID to authenticate as") + authCmd.AddCommand(authLoginCmd) +} + +func runAuthLogin(cmd *cobra.Command, _ []string) error { + if tenantID == "" { + return fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") + } + + issuer := getAPIBaseURL() + deviceClient, err := auth.NewDeviceClient(issuer, debug) + if err != nil { + return fmt.Errorf("failed to initialize login: %w", err) + } + + // Step 1: request a device code. Use a short timeout for this single call. + reqCtx, cancelReq := context.WithTimeout(cmd.Context(), 30*time.Second) + da, err := deviceClient.RequestDeviceCode(reqCtx, loginClientID, tenantID, "") + cancelReq() + if err != nil { + return fmt.Errorf("failed to start login: %w", err) + } + + // Step 2: send the user to the verification page. The browser URL carries + // the user_code, so the happy path needs no manual entry. --org is appended + // as a hint for the verification page to preselect the organization. + browseURL := withOrgHint(da.VerificationURIComplete, loginOrg) + opened := auth.OpenBrowser(browseURL) == nil + printVerificationInstructions(da, browseURL, opened) + + // Step 3: poll until approval, expiry, or denial. Bound the wait by the + // device code's lifetime. + pollCtx, cancelPoll := context.WithTimeout(cmd.Context(), time.Duration(da.ExpiresIn)*time.Second) + defer cancelPoll() + + spinner := progress.NewSpinner("Waiting for you to finish signing in…", noProgress) + spinner.Start() + stored, err := deviceClient.PollToken(pollCtx, da.DeviceCode, loginClientID, da.Interval) + spinner.Stop() + if err != nil { + return err + } + + // Step 4: persist the tokens for reuse by the CLI and MCP plugins, keyed by + // this environment (the API base URL) so multiple environments coexist. + stored.Issuer = issuer + store := auth.NewTokenStore() + if err := store.Save(issuer, stored); err != nil { + return fmt.Errorf("signed in, but failed to store credentials: %w", err) + } + + printLoginSuccess(stored) + return nil +} + +// withOrgHint appends an `org` query parameter to the verification URL when an +// org slug was supplied. A parse failure returns the URL unchanged. +func withOrgHint(rawURL, org string) string { + if org == "" { + return rawURL + } + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + q := u.Query() + q.Set("org", org) + u.RawQuery = q.Encode() + return u.String() +} + +// printVerificationInstructions tells the user where to authenticate, covering +// both the auto-opened-browser case and the manual fallback. +func printVerificationInstructions(da *auth.DeviceAuthorization, browseURL string, opened bool) { + if opened { + fmt.Fprintf(os.Stderr, "Opened your browser to complete sign-in.\n") + fmt.Fprintf(os.Stderr, "If it didn't open, visit:\n\n %s\n\n", browseURL) + fmt.Fprintf(os.Stderr, "Verify this code is shown: %s\n\n", output.GetStyles().Bold.Render(da.UserCode)) + return + } + fmt.Fprintf(os.Stderr, "To sign in, open the following URL in your browser:\n\n") + fmt.Fprintf(os.Stderr, " %s\n\n", da.VerificationURI) + fmt.Fprintf(os.Stderr, "and enter this code: %s\n\n", output.GetStyles().Bold.Render(da.UserCode)) +} + +// printLoginSuccess confirms the signed-in identity and tenant. +func printLoginSuccess(stored *auth.StoredToken) { + fmt.Fprintf(os.Stderr, "%s Signed in successfully.\n", output.IconSuccess) + if stored.Subject != "" { + fmt.Fprintf(os.Stderr, " Identity: %s\n", stored.Subject) + } + if stored.TenantID != "" { + fmt.Fprintf(os.Stderr, " Tenant: %s\n", stored.TenantID) + } +} diff --git a/internal/cmd/auth_logout.go b/internal/cmd/auth_logout.go new file mode 100644 index 00000000..2faa32ed --- /dev/null +++ b/internal/cmd/auth_logout.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/ArmisSecurity/armis-cli/internal/auth" + "github.com/ArmisSecurity/armis-cli/internal/output" + "github.com/spf13/cobra" +) + +var logoutAll bool + +var authLogoutCmd = &cobra.Command{ + Use: "logout", + Short: "Remove stored Armis credentials", + Long: `Remove the SSO tokens stored by 'armis-cli auth login'. + +By default only the token for the current environment is removed (the +environment is the resolved API base URL, e.g. set via --dev / --region / +ARMIS_API_URL). Use --all to remove tokens for every environment. + +This does not affect credentials supplied via environment variables or flags.`, + Args: cobra.NoArgs, + RunE: runAuthLogout, +} + +func init() { + authLogoutCmd.Flags().BoolVar(&logoutAll, "all", false, "Remove stored tokens for all environments") + authCmd.AddCommand(authLogoutCmd) +} + +func runAuthLogout(_ *cobra.Command, _ []string) error { + store := auth.NewTokenStore() + + if logoutAll { + envs := store.Environments() + if err := store.ClearAll(); err != nil { + return fmt.Errorf("failed to remove stored credentials: %w", err) + } + if len(envs) == 0 { + fmt.Fprintln(os.Stderr, "No stored credentials to remove.") + return nil + } + fmt.Fprintf(os.Stderr, "%s Signed out of %d environment(s).\n", output.IconSuccess, len(envs)) + return nil + } + + env := getAPIBaseURL() + + // Report whether anything was actually stored, so logout is informative + // rather than silently succeeding when not logged in. + existing, _ := store.Load(env) + if err := store.Clear(env); err != nil { + return fmt.Errorf("failed to remove stored credentials: %w", err) + } + + if existing == nil { + fmt.Fprintf(os.Stderr, "No stored credentials to remove for %s.\n", env) + return nil + } + fmt.Fprintf(os.Stderr, "%s Signed out of %s.\n", output.IconSuccess, env) + return nil +} diff --git a/internal/cmd/auth_test.go b/internal/cmd/auth_test.go index 2aa0c8be..74a74c8a 100644 --- a/internal/cmd/auth_test.go +++ b/internal/cmd/auth_test.go @@ -67,7 +67,7 @@ func TestRunAuth(t *testing.T) { clientSecret: "test-secret", setupServer: true, wantErr: true, - errContains: "--client-id is required", + errContains: "both --client-id and --client-secret must be provided", }, { name: "missing client-secret", @@ -75,7 +75,7 @@ func TestRunAuth(t *testing.T) { clientSecret: "", setupServer: true, wantErr: true, - errContains: "--client-secret is required", + errContains: "both --client-id and --client-secret must be provided", }, { name: "successful authentication", @@ -117,6 +117,10 @@ func TestRunAuth(t *testing.T) { } }) + // Redirect HOME to a temp dir so no real stored SSO token (~/.armis) + // short-circuits the credential resolution this test exercises. + t.Setenv("HOME", t.TempDir()) + // Clear legacy auth vars to ensure JWT path is taken token = "" tenantID = "" @@ -197,6 +201,8 @@ func TestRunAuth_InvalidEndpoint(t *testing.T) { } }) + t.Setenv("HOME", t.TempDir()) + // Clear legacy auth vars token = "" tenantID = "" diff --git a/internal/cmd/auth_whoami.go b/internal/cmd/auth_whoami.go new file mode 100644 index 00000000..169c6533 --- /dev/null +++ b/internal/cmd/auth_whoami.go @@ -0,0 +1,91 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/auth" + "github.com/spf13/cobra" +) + +var authWhoamiCmd = &cobra.Command{ + Use: "whoami", + Short: "Show the current Armis identity and auth method", + Long: `Display the identity, tenant, token expiry, and authentication method for +the currently active credentials, resolved in the same order the scan commands +use: stored SSO session, then client credentials, then legacy token.`, + Args: cobra.NoArgs, + RunE: runAuthWhoami, +} + +func init() { + authCmd.AddCommand(authWhoamiCmd) +} + +func runAuthWhoami(cmd *cobra.Command, _ []string) error { + provider, err := getAuthProvider() + if err != nil { + // getAuthProvider already returns a self-describing message (the + // no-credentials case lists the sign-in options); don't re-wrap it. + return err + } + + ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) + defer cancel() + + // Resolve the tenant; this also triggers a refresh if the token is stale, + // surfacing an expired-session error here rather than on the next scan. + tenant, err := provider.GetTenantID(ctx) + if err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "Environment: %s\n", getAPIBaseURL()) + fmt.Fprintf(os.Stderr, "Auth method: %s\n", describeAuthMethod(provider.AuthMethod())) + if id := provider.Identity(); id != "" { + fmt.Fprintf(os.Stderr, "Identity: %s\n", id) + } + if tenant != "" { + fmt.Fprintf(os.Stderr, "Tenant: %s\n", tenant) + } + if region, rerr := provider.GetRegion(ctx); rerr == nil && region != "" { + fmt.Fprintf(os.Stderr, "Region: %s\n", region) + } + if exp := provider.Expiry(); !exp.IsZero() { + fmt.Fprintf(os.Stderr, "Expires: %s (%s)\n", exp.Format(time.RFC3339), humanizeUntil(time.Until(exp))) + } + return nil +} + +// describeAuthMethod renders the auth method in user-facing terms. +func describeAuthMethod(m auth.AuthMethod) string { + switch m { + case auth.AuthMethodSSO: + return "SSO (browser login)" + case auth.AuthMethodClientCredentials: + return "client credentials" + case auth.AuthMethodBasic: + return "API token (legacy)" + default: + return string(m) + } +} + +// humanizeUntil renders a duration-until-expiry as a short, friendly phrase. +func humanizeUntil(d time.Duration) string { + if d <= 0 { + return "expired" + } + switch { + case d < time.Minute: + return "in less than a minute" + case d < time.Hour: + return fmt.Sprintf("in %d minutes", int(d.Minutes())) + case d < 24*time.Hour: + return fmt.Sprintf("in %d hours", int(d.Hours())) + default: + return fmt.Sprintf("in %d days", int(d.Hours()/24)) + } +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index bc3576b4..a0c2d8d1 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "os" + "strings" "sync" "time" @@ -50,6 +51,11 @@ var ( clientSecret string region string + // credFlagsExplicit is set in PersistentPreRunE when the user passed + // --client-id/--client-secret/--token explicitly. It lets those flags + // override a stored SSO token in getAuthProvider. + credFlagsExplicit bool + version = versionDev commit = "none" date = "unknown" @@ -129,6 +135,12 @@ var rootCmd = &cobra.Command{ if !cmd.Flags().Changed("client-secret") { clientSecret = os.Getenv("ARMIS_CLIENT_SECRET") } + // Record whether the user explicitly passed credential flags. When they + // did, those flags take precedence over any stored SSO token (an escape + // hatch for forcing client-credentials/Basic auth without logging out). + credFlagsExplicit = cmd.Flags().Changed("client-id") || + cmd.Flags().Changed("client-secret") || + cmd.Flags().Changed("token") if !cmd.Flags().Changed("region") { region = os.Getenv("ARMIS_REGION") } @@ -396,10 +408,26 @@ func resolveDataPlaneURL(ctx context.Context, authProvider *auth.AuthProvider) s return getAPIBaseURL() } -// getAuthProvider creates an AuthProvider based on the provided credentials. -// Priority: JWT auth (--client-id, --client-secret) > Basic auth (--token) +// getAuthProvider creates an AuthProvider based on the available credentials. +// +// Resolution order (PPSC-1037): +// 1. Stored SSO token (keychain / fallback file) from `armis-cli auth login`, +// unless the user explicitly passed --client-id/--client-secret/--token, +// which act as an escape hatch to force the credential path. +// 2. Client credentials (--client-id/--client-secret or ARMIS_CLIENT_ID/SECRET). +// 3. Legacy --token (Basic auth). +// 4. Otherwise an error pointing at `auth login` / env credentials. +// +// CI/CD is unaffected: with no stored token, resolution falls straight through +// to env-var client credentials exactly as before. func getAuthProvider() (*auth.AuthProvider, error) { - return auth.NewAuthProvider(auth.AuthConfig{ + if !credFlagsExplicit { + if provider, ok := storedAuthProvider(); ok { + return provider, nil + } + } + + provider, err := auth.NewAuthProvider(auth.AuthConfig{ ClientID: clientID, ClientSecret: clientSecret, BaseURL: getAPIBaseURL(), @@ -408,6 +436,50 @@ func getAuthProvider() (*auth.AuthProvider, error) { TenantID: tenantID, Debug: debug, }) + if err != nil { + // Improve the no-credentials message to mention SSO login. + return nil, augmentNoCredentialsError(err) + } + return provider, nil +} + +// storedAuthProvider builds an SSO-backed AuthProvider from a previously stored +// device-flow token, or returns ok=false when none is present (or it cannot be +// used), so callers fall through to credential-based auth. +func storedAuthProvider() (*auth.AuthProvider, bool) { + // The environment key is the resolved API base URL, so each environment + // (prod, dev, a local stack) has its own token entry. This is also where + // the refresh grant is sent, so the token's own issuer is not consulted. + env := getAPIBaseURL() + + store := auth.NewTokenStore() + stored, err := store.Load(env) + if err != nil || stored == nil { + return nil, false + } + + deviceClient, err := auth.NewDeviceClient(env, debug) + if err != nil { + return nil, false + } + provider, err := auth.NewProviderFromStored(store, deviceClient, env, stored) + if err != nil { + return nil, false + } + return provider, true +} + +// augmentNoCredentialsError replaces the auth package's generic +// "authentication required" error with a CLI-friendly, browser-login-first list +// of options. Other errors pass through unchanged. +func augmentNoCredentialsError(err error) error { + if err == nil || !strings.Contains(err.Error(), "authentication required") { + return err + } + return fmt.Errorf("not authenticated — use one of the following options:\n" + + " - run 'armis-cli auth login' to sign in with your company IdP\n" + + " - or set ARMIS_CLIENT_ID / ARMIS_CLIENT_SECRET (or --client-id / --client-secret) for JWT auth\n" + + " - or set ARMIS_API_TOKEN (or --token) for legacy auth") } func getPageLimit() (int, error) { From 425e0a93d55435d81ec1aa0cbde401a2daed89bb Mon Sep 17 00:00:00 2001 From: Dmitry Smirnov Date: Tue, 30 Jun 2026 11:10:36 -0400 Subject: [PATCH 2/5] feat: ARMIS_DEFAULT_LOGIN_METHOD env var to automatically trigger auth login sequence --- internal/cmd/agent_detection_collect.go | 2 +- internal/cmd/auth.go | 2 +- internal/cmd/auth_flow_test.go | 93 ++++++++++++++++++++++++- internal/cmd/auth_login.go | 35 +++++++--- internal/cmd/auth_whoami.go | 2 +- internal/cmd/root.go | 42 ++++++++++- internal/cmd/root_test.go | 2 +- internal/cmd/scan_image.go | 2 +- internal/cmd/scan_repo.go | 2 +- 9 files changed, 160 insertions(+), 22 deletions(-) diff --git a/internal/cmd/agent_detection_collect.go b/internal/cmd/agent_detection_collect.go index 0ba885ad..e0d097b5 100644 --- a/internal/cmd/agent_detection_collect.go +++ b/internal/cmd/agent_detection_collect.go @@ -34,7 +34,7 @@ func runAgentDetectionCollect(cmd *cobra.Command, _ []string) error { ctx, cancel := NewSignalContext() defer cancel() - authProvider, err := getAuthProvider() + authProvider, err := getAuthProvider(ctx) if err != nil { return fmt.Errorf("authentication failed: %w", err) } diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index 4c19dbc0..5909aa3d 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -56,7 +56,7 @@ func init() { } func runAuth(cmd *cobra.Command, _ []string) error { - provider, err := getAuthProvider() + provider, err := getAuthProvider(cmd.Context()) if err != nil { return fmt.Errorf("authentication failed: %w", err) } diff --git a/internal/cmd/auth_flow_test.go b/internal/cmd/auth_flow_test.go index ba36ae9b..7ee6a660 100644 --- a/internal/cmd/auth_flow_test.go +++ b/internal/cmd/auth_flow_test.go @@ -242,7 +242,7 @@ func TestStoredTokenTakesPrecedence(t *testing.T) { clientID = "env-client" clientSecret = "env-secret" - provider, err := getAuthProvider() + provider, err := getAuthProvider(context.Background()) if err != nil { t.Fatalf("getAuthProvider: %v", err) } @@ -254,7 +254,7 @@ func TestStoredTokenTakesPrecedence(t *testing.T) { func TestNoCredentialsErrorMentionsLogin(t *testing.T) { setupAuthTest(t, "https://moose.armis.com") // No stored token, no credentials. - _, err := getAuthProvider() + _, err := getAuthProvider(context.Background()) if err == nil { t.Fatal("expected error with no credentials") } @@ -263,6 +263,95 @@ func TestNoCredentialsErrorMentionsLogin(t *testing.T) { } } +// TestShouldAutoLoginSSO pins the gating rules: SSO auto-login fires only when +// ARMIS_DEFAULT_AUTH_METHOD=SSO and no other credential is configured. +func TestShouldAutoLoginSSO(t *testing.T) { + tests := []struct { + name string + env string + clientID string + token string + explicit bool + want bool + }{ + {name: "unset", env: "", want: false}, + {name: "sso, no creds", env: "SSO", want: true}, + {name: "sso lowercase", env: "sso", want: true}, + {name: "other value", env: "client-credentials", want: false}, + {name: "sso but client creds present", env: "sso", clientID: "id", want: false}, + {name: "sso but legacy token present", env: "sso", token: "tok", want: false}, + {name: "sso but explicit cred flags", env: "sso", explicit: true, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setupAuthTest(t, "https://moose.armis.com") + t.Setenv("ARMIS_DEFAULT_AUTH_METHOD", tt.env) + clientID = tt.clientID + token = tt.token + credFlagsExplicit = tt.explicit + + if got := shouldAutoLoginSSO(); got != tt.want { + t.Errorf("shouldAutoLoginSSO() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestAutoLoginSSOTriggersDeviceFlow verifies that, with ARMIS_DEFAULT_AUTH_METHOD=SSO +// and no stored token or credentials, getAuthProvider runs the device flow, +// persists the token, and returns an SSO-backed provider. +func TestAutoLoginSSOTriggersDeviceFlow(t *testing.T) { + jwt := deviceJWT(tenant7, "alice@example.com", "admin", time.Now().Add(time.Hour).Unix()) + var mu sync.Mutex + tokenCalls := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/oauth2/device": + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "WDJB-MJHT", + "verification_uri": "https://moose.armis.com/oauth2/device/verify", + "verification_uri_complete": "https://moose.armis.com/oauth2/device/verify?user_code=WDJB-MJHT", + "expires_in": 900, + "interval": 1, + }) + case "/oauth2/token": + mu.Lock() + tokenCalls++ + n := tokenCalls + mu.Unlock() + if n < 2 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": jwt, "token_type": "Bearer", + "expires_in": 3600, "refresh_token": "refresh-7", + }) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + setupAuthTest(t, srv.URL) + t.Setenv("ARMIS_DEFAULT_AUTH_METHOD", "SSO") + tenantID = tenant7 // device authorization requires a tenant + + provider, err := getAuthProvider(context.Background()) + if err != nil { + t.Fatalf("getAuthProvider with SSO auto-login: %v", err) + } + if provider.AuthMethod() != auth.AuthMethodSSO { + t.Errorf("AuthMethod = %q, want sso", provider.AuthMethod()) + } + if stored, _ := auth.NewTokenStore().Load(srv.URL); stored == nil || stored.RefreshToken != "refresh-7" { + t.Errorf("expected auto-login to persist token, got %+v", stored) + } +} + func TestMain(m *testing.M) { // Ensure no test accidentally spawns a real browser: stub the opener to a // no-op (success) for the whole cmd test binary. diff --git a/internal/cmd/auth_login.go b/internal/cmd/auth_login.go index 992681c1..91957e29 100644 --- a/internal/cmd/auth_login.go +++ b/internal/cmd/auth_login.go @@ -49,42 +49,55 @@ func init() { } func runAuthLogin(cmd *cobra.Command, _ []string) error { + _, err := performDeviceLogin(cmd.Context(), loginClientID, loginOrg) + return err +} + +// performDeviceLogin runs the OAuth2 device-authorization flow end to end: +// request a device code, open (or print) the verification URL, poll until the +// user approves, and persist the resulting tokens. It returns the stored token +// on success. It is shared by the `auth login` command and the +// ARMIS_DEFAULT_AUTH_METHOD=SSO auto-login path in getAuthProvider. +// +// ctx bounds the whole interactive flow, so callers should pass a long-lived +// context (e.g. the command context), not a short per-request timeout. +func performDeviceLogin(ctx context.Context, clientID, org string) (*auth.StoredToken, error) { if tenantID == "" { - return fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") + return nil, fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") } issuer := getAPIBaseURL() deviceClient, err := auth.NewDeviceClient(issuer, debug) if err != nil { - return fmt.Errorf("failed to initialize login: %w", err) + return nil, fmt.Errorf("failed to initialize login: %w", err) } // Step 1: request a device code. Use a short timeout for this single call. - reqCtx, cancelReq := context.WithTimeout(cmd.Context(), 30*time.Second) - da, err := deviceClient.RequestDeviceCode(reqCtx, loginClientID, tenantID, "") + reqCtx, cancelReq := context.WithTimeout(ctx, 30*time.Second) + da, err := deviceClient.RequestDeviceCode(reqCtx, clientID, tenantID, "") cancelReq() if err != nil { - return fmt.Errorf("failed to start login: %w", err) + return nil, fmt.Errorf("failed to start login: %w", err) } // Step 2: send the user to the verification page. The browser URL carries // the user_code, so the happy path needs no manual entry. --org is appended // as a hint for the verification page to preselect the organization. - browseURL := withOrgHint(da.VerificationURIComplete, loginOrg) + browseURL := withOrgHint(da.VerificationURIComplete, org) opened := auth.OpenBrowser(browseURL) == nil printVerificationInstructions(da, browseURL, opened) // Step 3: poll until approval, expiry, or denial. Bound the wait by the // device code's lifetime. - pollCtx, cancelPoll := context.WithTimeout(cmd.Context(), time.Duration(da.ExpiresIn)*time.Second) + pollCtx, cancelPoll := context.WithTimeout(ctx, time.Duration(da.ExpiresIn)*time.Second) defer cancelPoll() spinner := progress.NewSpinner("Waiting for you to finish signing in…", noProgress) spinner.Start() - stored, err := deviceClient.PollToken(pollCtx, da.DeviceCode, loginClientID, da.Interval) + stored, err := deviceClient.PollToken(pollCtx, da.DeviceCode, clientID, da.Interval) spinner.Stop() if err != nil { - return err + return nil, err } // Step 4: persist the tokens for reuse by the CLI and MCP plugins, keyed by @@ -92,11 +105,11 @@ func runAuthLogin(cmd *cobra.Command, _ []string) error { stored.Issuer = issuer store := auth.NewTokenStore() if err := store.Save(issuer, stored); err != nil { - return fmt.Errorf("signed in, but failed to store credentials: %w", err) + return nil, fmt.Errorf("signed in, but failed to store credentials: %w", err) } printLoginSuccess(stored) - return nil + return stored, nil } // withOrgHint appends an `org` query parameter to the verification URL when an diff --git a/internal/cmd/auth_whoami.go b/internal/cmd/auth_whoami.go index 169c6533..5d48c9a8 100644 --- a/internal/cmd/auth_whoami.go +++ b/internal/cmd/auth_whoami.go @@ -25,7 +25,7 @@ func init() { } func runAuthWhoami(cmd *cobra.Command, _ []string) error { - provider, err := getAuthProvider() + provider, err := getAuthProvider(cmd.Context()) if err != nil { // getAuthProvider already returns a self-describing message (the // no-credentials case lists the sign-in options); don't re-wrap it. diff --git a/internal/cmd/root.go b/internal/cmd/root.go index a0c2d8d1..55d0460e 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -416,17 +416,38 @@ func resolveDataPlaneURL(ctx context.Context, authProvider *auth.AuthProvider) s // which act as an escape hatch to force the credential path. // 2. Client credentials (--client-id/--client-secret or ARMIS_CLIENT_ID/SECRET). // 3. Legacy --token (Basic auth). -// 4. Otherwise an error pointing at `auth login` / env credentials. +// 4. When ARMIS_DEFAULT_AUTH_METHOD=SSO and no credentials are configured, +// trigger an interactive browser login (device flow) instead of erroring. +// 5. Otherwise an error pointing at `auth login` / env credentials. // // CI/CD is unaffected: with no stored token, resolution falls straight through -// to env-var client credentials exactly as before. -func getAuthProvider() (*auth.AuthProvider, error) { +// to env-var client credentials exactly as before. The SSO auto-login in step 4 +// only fires when no other credential is available, so it never overrides +// configured client credentials or a legacy token. +// +// ctx bounds any interactive login triggered here, so callers must pass a +// long-lived context (the command context), not a short per-request timeout. +func getAuthProvider(ctx context.Context) (*auth.AuthProvider, error) { if !credFlagsExplicit { if provider, ok := storedAuthProvider(); ok { return provider, nil } } + // Opt-in: when the user has asked for SSO as the default auth method and no + // other credentials are configured, sign in interactively rather than + // failing. This makes `armis-cli scan ...` self-bootstrap a session on a + // developer machine while leaving credentialed (CI) runs untouched. + if shouldAutoLoginSSO() { + if _, err := performDeviceLogin(ctx, auth.DefaultDeviceClientID, ""); err != nil { + return nil, err + } + if provider, ok := storedAuthProvider(); ok { + return provider, nil + } + return nil, fmt.Errorf("signed in, but no stored session was found for %s", getAPIBaseURL()) + } + provider, err := auth.NewAuthProvider(auth.AuthConfig{ ClientID: clientID, ClientSecret: clientSecret, @@ -443,6 +464,21 @@ func getAuthProvider() (*auth.AuthProvider, error) { return provider, nil } +// shouldAutoLoginSSO reports whether getAuthProvider should start an interactive +// device-flow login. It fires only when ARMIS_DEFAULT_AUTH_METHOD=SSO (case +// insensitive) and no other credentials are configured — no explicit credential +// flags, no client credentials, and no legacy token — so it never shadows a +// working CI/service-account setup. +func shouldAutoLoginSSO() bool { + if !strings.EqualFold(os.Getenv("ARMIS_DEFAULT_AUTH_METHOD"), "sso") { + return false + } + if credFlagsExplicit { + return false + } + return clientID == "" && clientSecret == "" && token == "" +} + // storedAuthProvider builds an SSO-backed AuthProvider from a previously stored // device-flow token, or returns ok=false when none is present (or it cannot be // used), so callers fall through to credential-based auth. diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index 321e462b..b48962b7 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -870,7 +870,7 @@ func TestGetAuthProvider_NoCredentials(t *testing.T) { token = "" tenantID = "" - _, err := getAuthProvider() + _, err := getAuthProvider(context.Background()) if err == nil { t.Error("expected error when no credentials are provided") } diff --git a/internal/cmd/scan_image.go b/internal/cmd/scan_image.go index 8fbda903..b9e6a505 100644 --- a/internal/cmd/scan_image.go +++ b/internal/cmd/scan_image.go @@ -74,7 +74,7 @@ var scanImageCmd = &cobra.Command{ return image.ErrRuntimeNotFound } - authProvider, err := getAuthProvider() + authProvider, err := getAuthProvider(cmd.Context()) if err != nil { return err } diff --git a/internal/cmd/scan_repo.go b/internal/cmd/scan_repo.go index 2e9242f2..4347e1a3 100644 --- a/internal/cmd/scan_repo.go +++ b/internal/cmd/scan_repo.go @@ -55,7 +55,7 @@ var scanRepoCmd = &cobra.Command{ return fmt.Errorf("path is not a directory: %s", repoPath) } - authProvider, err := getAuthProvider() + authProvider, err := getAuthProvider(cmd.Context()) if err != nil { return err } From 0ab65c7e67c1326c37a7201398808a3fd8f8e7bd Mon Sep 17 00:00:00 2001 From: Dmitry Smirnov Date: Tue, 30 Jun 2026 11:13:06 -0400 Subject: [PATCH 3/5] remove unused org parameter in login command --- internal/cmd/auth_flow_test.go | 1 - internal/cmd/auth_login.go | 37 ++++++---------------------------- internal/cmd/root.go | 2 +- 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/internal/cmd/auth_flow_test.go b/internal/cmd/auth_flow_test.go index 7ee6a660..bb9eea09 100644 --- a/internal/cmd/auth_flow_test.go +++ b/internal/cmd/auth_flow_test.go @@ -48,7 +48,6 @@ func setupAuthTest(t *testing.T, serverURL string) { tenantID = origGlobals.tenantID credFlagsExplicit = false noProgress = false - loginOrg = "" loginClientID = auth.DefaultDeviceClientID logoutAll = false }) diff --git a/internal/cmd/auth_login.go b/internal/cmd/auth_login.go index 91957e29..de9e5f40 100644 --- a/internal/cmd/auth_login.go +++ b/internal/cmd/auth_login.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "net/url" "os" "time" @@ -13,10 +12,7 @@ import ( "github.com/spf13/cobra" ) -var ( - loginOrg string - loginClientID string -) +var loginClientID string var authLoginCmd = &cobra.Command{ Use: "login", @@ -34,22 +30,18 @@ A tenant is required: pass --tenant-id or set ARMIS_TENANT_ID. If the browser cannot be opened automatically (for example over SSH), the CLI prints a URL and a code to enter manually.`, Example: ` # Sign in interactively - armis-cli auth login --tenant-id my-tenant - - # Skip org selection in the browser - armis-cli auth login --tenant-id my-tenant --org my-company`, + armis-cli auth login --tenant-id my-tenant`, Args: cobra.NoArgs, RunE: runAuthLogin, } func init() { - authLoginCmd.Flags().StringVar(&loginOrg, "org", "", "Organization slug hint to skip org selection in the browser") authLoginCmd.Flags().StringVar(&loginClientID, "client-id", auth.DefaultDeviceClientID, "OAuth2 client ID to authenticate as") authCmd.AddCommand(authLoginCmd) } func runAuthLogin(cmd *cobra.Command, _ []string) error { - _, err := performDeviceLogin(cmd.Context(), loginClientID, loginOrg) + _, err := performDeviceLogin(cmd.Context(), loginClientID) return err } @@ -61,7 +53,7 @@ func runAuthLogin(cmd *cobra.Command, _ []string) error { // // ctx bounds the whole interactive flow, so callers should pass a long-lived // context (e.g. the command context), not a short per-request timeout. -func performDeviceLogin(ctx context.Context, clientID, org string) (*auth.StoredToken, error) { +func performDeviceLogin(ctx context.Context, clientID string) (*auth.StoredToken, error) { if tenantID == "" { return nil, fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") } @@ -81,9 +73,8 @@ func performDeviceLogin(ctx context.Context, clientID, org string) (*auth.Stored } // Step 2: send the user to the verification page. The browser URL carries - // the user_code, so the happy path needs no manual entry. --org is appended - // as a hint for the verification page to preselect the organization. - browseURL := withOrgHint(da.VerificationURIComplete, org) + // the user_code, so the happy path needs no manual entry. + browseURL := da.VerificationURIComplete opened := auth.OpenBrowser(browseURL) == nil printVerificationInstructions(da, browseURL, opened) @@ -112,22 +103,6 @@ func performDeviceLogin(ctx context.Context, clientID, org string) (*auth.Stored return stored, nil } -// withOrgHint appends an `org` query parameter to the verification URL when an -// org slug was supplied. A parse failure returns the URL unchanged. -func withOrgHint(rawURL, org string) string { - if org == "" { - return rawURL - } - u, err := url.Parse(rawURL) - if err != nil { - return rawURL - } - q := u.Query() - q.Set("org", org) - u.RawQuery = q.Encode() - return u.String() -} - // printVerificationInstructions tells the user where to authenticate, covering // both the auto-opened-browser case and the manual fallback. func printVerificationInstructions(da *auth.DeviceAuthorization, browseURL string, opened bool) { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 55d0460e..93504ffe 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -439,7 +439,7 @@ func getAuthProvider(ctx context.Context) (*auth.AuthProvider, error) { // failing. This makes `armis-cli scan ...` self-bootstrap a session on a // developer machine while leaving credentialed (CI) runs untouched. if shouldAutoLoginSSO() { - if _, err := performDeviceLogin(ctx, auth.DefaultDeviceClientID, ""); err != nil { + if _, err := performDeviceLogin(ctx, auth.DefaultDeviceClientID); err != nil { return nil, err } if provider, ok := storedAuthProvider(); ok { From bc289b1f8ad7d9c893ec1b6a67cd48f63a60094f Mon Sep 17 00:00:00 2001 From: Dmitry Smirnov Date: Tue, 30 Jun 2026 11:28:18 -0400 Subject: [PATCH 4/5] Docs for the new auth method --- README.md | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 963b342c..ca7f47e3 100644 --- a/README.md +++ b/README.md @@ -904,7 +904,12 @@ pipelines: ## Environment Variables -**JWT Authentication (Recommended):** +Pick the authentication method that matches the environment: + +- **CI/CD and other non-interactive environments** — use client credentials (`ARMIS_CLIENT_ID` / `ARMIS_CLIENT_SECRET`). They authenticate without a browser, which is what automated pipelines need. +- **Developer machines and other interactive environments** — use SSO (`ARMIS_DEFAULT_AUTH_METHOD=SSO`). The CLI signs in through your company's identity provider in the browser, so no long-lived secret has to be stored on the machine. + +**Client Credentials (recommended for CI/CD):** | Variable | Description | |----------|-------------| @@ -912,14 +917,22 @@ pipelines: | `ARMIS_CLIENT_SECRET` | Client secret for JWT authentication | | `ARMIS_REGION` | Armis cloud region (equivalent to `--region` flag) | -When using JWT authentication, the tenant ID is automatically extracted from the token. +When using client credentials, the tenant ID is automatically extracted from the token. + +**SSO (recommended for interactive use):** + +| Variable | Description | +|----------|-------------| +| `ARMIS_DEFAULT_AUTH_METHOD` | Set to `SSO` to sign in through your company's configured identity provider when no other credentials are present (requires `ARMIS_TENANT_ID` or `--tenant-id`) | + +You can also sign in explicitly at any time with `armis-cli auth login`; setting `ARMIS_DEFAULT_AUTH_METHOD=SSO` just triggers that sign-in automatically on the first command that needs credentials. **Basic Authentication (Legacy):** | Variable | Description | |----------|-------------| | `ARMIS_API_TOKEN` | API token for Basic authentication | -| `ARMIS_TENANT_ID` | Tenant identifier (required only with Basic auth) | +| `ARMIS_TENANT_ID` | Tenant identifier (required with Basic auth or SSO) | **General:** From 1c87a04e925b7d4d53da7cb1b2725a2a4b0de5ee Mon Sep 17 00:00:00 2001 From: Dmitry Smirnov Date: Tue, 30 Jun 2026 11:49:52 -0400 Subject: [PATCH 5/5] fix tests on windows, document access management for .armis/.sessions --- README.md | 1 + internal/auth/tokenstore.go | 9 ++++++--- internal/auth/tokenstore_test.go | 10 +++++++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ca7f47e3..f9cc43d1 100644 --- a/README.md +++ b/README.md @@ -955,6 +955,7 @@ You can also sign in explicitly at any time with `armis-cli auth login`; setting - Use JWT authentication (client ID/secret) for production — it supports automatic token refresh and does not require a separate tenant ID - Rotate credentials periodically - Credentials are never logged or exposed in output + - SSO session tokens (`armis-cli auth login`) are stored per-user in `~/.armis/.sessions` — owner-only (`0600`) on macOS/Linux, protected by the user-profile ACL on Windows - **Secure Transport**: All API communication uses HTTPS - **Automatic Cleanup**: Temporary files are cleaned up after use - **CI Detection**: Progress bars automatically disabled in CI environments diff --git a/internal/auth/tokenstore.go b/internal/auth/tokenstore.go index 15af5942..81b70115 100644 --- a/internal/auth/tokenstore.go +++ b/internal/auth/tokenstore.go @@ -24,9 +24,12 @@ import ( // A plain file (not the OS keychain) is the deliberate choice: the MCP plugins // are Python, and the refresh-token rotation + reuse-detection on the backend // requires a SINGLE source of truth (a divergent second store would replay a -// rotated token and get the whole token family revoked). The file is 0600 in a -// 0700 directory; protection at rest relies on the OS account + full-disk -// encryption (FileVault/BitLocker/LUKS), matching the AWS/gcloud/kubectl model. +// rotated token and get the whole token family revoked). This matches the +// AWS/gcloud/kubectl/gh model of a per-user credential file. +// +// At rest: Unix writes 0600 in a 0700 ~/.armis. On Windows those mode bits are a +// no-op (NTFS uses ACLs; os.Stat reports 0666); confidentiality there relies on +// the %USERPROFILE% ACL ~/.armis inherits, same as the tools above. // // FILE SHAPE — a JSON array of per-environment entries, so one dev machine can // hold tokens for several Armis environments at once (prod, dev, a local stack): diff --git a/internal/auth/tokenstore_test.go b/internal/auth/tokenstore_test.go index 61050246..85bf9a46 100644 --- a/internal/auth/tokenstore_test.go +++ b/internal/auth/tokenstore_test.go @@ -3,6 +3,7 @@ package auth import ( "os" "path/filepath" + "runtime" "testing" "time" ) @@ -35,14 +36,17 @@ func TestTokenStoreRoundTrip(t *testing.T) { t.Fatalf("Save: %v", err) } - // File should exist with 0600 perms in a 0700 dir. path := filepath.Join(dir, tokenStoreFileName) info, err := os.Stat(path) if err != nil { t.Fatalf("expected token file: %v", err) } - if perm := info.Mode().Perm(); perm != 0o600 { - t.Errorf("token file perm = %o, want 600", perm) + // Must be 0600 on Unix. On Windows mode bits are a no-op (os.Stat reports + // 0666); confidentiality relies on the profile ACL — see tokenstore.go. + if runtime.GOOS != "windows" { + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("token file perm = %o, want 600", perm) + } } got, err := store.Load(envProd)