Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions auth/authorization_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ type AuthorizationCodeHandlerConfig struct {
// See [AuthorizationCodeFetcher] for details.
AuthorizationCodeFetcher AuthorizationCodeFetcher

// Scopes optionally restricts the requested scopes to this allowlist,
// intersecting it with the scopes discovered from metadata/challenge
// (preserving order). This lets a client drop an advertised scope it does
// not want — e.g. Gmail's gmail.metadata, which the Gmail API refuses to
// combine with the search "q" parameter even alongside gmail.readonly. An
// allowlist matching nothing is ignored so the flow never requests an empty
// set; offline_access is exempt.
Scopes []string

// RequestRefreshToken indicates that the client intends to use refresh
// tokens and is capable of storing them securely.
//
Expand Down Expand Up @@ -293,6 +302,12 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ
requestedScopes = prm.ScopesSupported
}

// Apply the configured allowlist before offline_access and the step-up
// union below so neither is affected.
if len(h.config.Scopes) > 0 {
requestedScopes = intersectScopes(requestedScopes, h.config.Scopes)
}

// SEP-2207: when the client desires refresh tokens and the Authorization
// Server advertises offline_access support, add it to the requested scopes.
if h.config.RequestRefreshToken &&
Expand Down Expand Up @@ -358,6 +373,22 @@ func scopesFromChallenges(cs []oauthex.Challenge) []string {
return nil
}

// intersectScopes returns the members of scopes also in allow, preserving order.
// If nothing matches it returns scopes unchanged, so a bad allowlist cannot
// leave the client requesting no scopes at all.
func intersectScopes(scopes, allow []string) []string {
keep := make([]string, 0, len(scopes))
for _, s := range scopes {
if slices.Contains(allow, s) {
keep = append(keep, s)
}
}
if len(keep) == 0 {
return scopes
}
return keep
}

// errorFromChallenges returns the error from the given "WWW-Authenticate" header challenges.
// It only looks at challenges with the "Bearer" scheme.
func errorFromChallenges(cs []oauthex.Challenge) string {
Expand Down
117 changes: 117 additions & 0 deletions auth/authorization_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,123 @@ func TestAuthorize_OfflineAccessScope(t *testing.T) {
}
}

func TestAuthorize_ScopesAllowlist(t *testing.T) {
// advertised is delivered via the WWW-Authenticate "scope" challenge, which
// the handler treats as the requested scopes before applying the allowlist.
const advertised = "gmail.metadata gmail.readonly gmail.compose"
tests := []struct {
name string
allowlist []string
wantScopes string
}{
{
name: "FiltersToAllowlist",
allowlist: []string{"gmail.compose", "gmail.readonly"},
wantScopes: "gmail.readonly gmail.compose",
},
{
name: "EmptyIntersectionFailsOpen",
allowlist: []string{"drive.readonly"},
wantScopes: advertised,
},
{
name: "NoAllowlistLeavesScopesUnchanged",
allowlist: nil,
wantScopes: advertised,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{
RegistrationConfig: &oauthtest.RegistrationConfig{
PreregisteredClients: map[string]oauthtest.ClientInfo{
"test_client_id": {
Secret: "test_client_secret",
RedirectURIs: []string{"http://localhost:12345/callback"},
},
},
},
})
authServer.Start(t)

resourceMux := http.NewServeMux()
resourceServer := httptest.NewServer(resourceMux)
t.Cleanup(resourceServer.Close)
resourceURL := resourceServer.URL + "/resource"
resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
AuthorizationServers: []string{authServer.URL()},
}))

var capturedAuthURL string
handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{
RedirectURL: "http://localhost:12345/callback",
PreregisteredClient: &oauthex.ClientCredentials{
ClientID: "test_client_id",
ClientSecretAuth: &oauthex.ClientSecretAuth{ClientSecret: "test_client_secret"},
},
Scopes: tt.allowlist,
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
capturedAuthURL = args.URL
return nil, fmt.Errorf("stop after capturing URL")
},
})
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler failed: %v", err)
}

req := httptest.NewRequest(http.MethodGet, resourceURL, nil)
resp := &http.Response{
StatusCode: http.StatusUnauthorized,
Header: make(http.Header),
Body: http.NoBody,
Request: req,
}
resp.Header.Set("WWW-Authenticate", fmt.Sprintf(
"Bearer resource_metadata=%s/.well-known/oauth-protected-resource/resource, scope=%q",
resourceServer.URL, advertised))

handler.Authorize(context.Background(), req, resp)

if capturedAuthURL == "" {
t.Fatal("AuthorizationCodeFetcher was not called")
}
u, err := url.Parse(capturedAuthURL)
if err != nil {
t.Fatalf("failed to parse captured auth URL: %v", err)
}
// Compare as a set: UnionScopes (applied downstream) returns map keys,
// so the order of the requested scope parameter is not deterministic.
got := strings.Fields(u.Query().Get("scope"))
want := strings.Fields(tt.wantScopes)
slices.Sort(got)
slices.Sort(want)
if !slices.Equal(got, want) {
t.Errorf("requested scopes = %v, want %v (any order)", got, want)
}
})
}
}

func TestIntersectScopes(t *testing.T) {
for _, tt := range []struct {
name string
scopes, allow []string
want []string
}{
{"filters and preserves scopes order", []string{"a", "b", "c"}, []string{"c", "a"}, []string{"a", "c"}},
{"empty intersection returns input unchanged", []string{"a", "b"}, []string{"z"}, []string{"a", "b"}},
{"empty allow returns input unchanged", []string{"a", "b"}, nil, []string{"a", "b"}},
} {
t.Run(tt.name, func(t *testing.T) {
if got := intersectScopes(tt.scopes, tt.allow); !slices.Equal(got, tt.want) {
t.Errorf("intersectScopes(%v, %v) = %v, want %v", tt.scopes, tt.allow, got, tt.want)
}
})
}
}

// validConfig for test to create an AuthorizationCodeHandler using its constructor.
// Values that are relevant to the test should be set explicitly.
func validConfig() *AuthorizationCodeHandlerConfig {
Expand Down
Loading