Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
import asyncio
import copy
import datetime
import functools
import inspect
import logging
import os
import threading
from typing import NamedTuple, Optional, TYPE_CHECKING

from google.auth import _helpers
from google.auth import environment_vars

if TYPE_CHECKING:
import google.auth.credentials
Expand All @@ -34,25 +31,6 @@
_LOGGER = logging.getLogger(__name__)


@functools.lru_cache()
def is_regional_access_boundary_enabled():
"""Checks if Regional Access Boundary is enabled via environment variable.

The environment variable is interpreted as a boolean with the following
(case-insensitive) rules:
- "true", "1" are considered true.
- Any other value (or unset) is considered false.

Returns:
bool: True if Regional Access Boundary is enabled, False otherwise.
"""
value = os.environ.get(environment_vars.GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED)
if value is None:
return False

return value.lower() in ("true", "1")


# The default lifetime for a cached Regional Access Boundary.
DEFAULT_REGIONAL_ACCESS_BOUNDARY_TTL = datetime.timedelta(hours=6)

Expand Down
8 changes: 3 additions & 5 deletions packages/google-auth/google/auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,11 +841,9 @@ def from_info(cls, info, **kwargs):
Raises:
ValueError: For invalid parameters.
"""
aws_security_credentials_supplier = info.get(
"aws_security_credentials_supplier"
)
kwargs.update(
{"aws_security_credentials_supplier": aws_security_credentials_supplier}
kwargs.setdefault(
"aws_security_credentials_supplier",
info.get("aws_security_credentials_supplier"),
)
return super(Credentials, cls).from_info(info, **kwargs)

Expand Down
17 changes: 8 additions & 9 deletions packages/google-auth/google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,13 @@ def _is_regional_endpoint(self, url):
try:
# Do not perform a lookup if the request is for a regional endpoint.
hostname = urlparse(url).hostname
if hostname and (
hostname.endswith(".rep.googleapis.com")
or hostname.endswith(".rep.sandbox.googleapis.com")
if hostname and hostname.endswith(
Comment thread
lsirac marked this conversation as resolved.
(
".rep.googleapis.com",
".rep.sandbox.googleapis.com",
".rep.mtls.googleapis.com",
".rep.mtls.sandbox.googleapis.com",
)
):
return True
except (ValueError, TypeError, AttributeError):
Expand Down Expand Up @@ -484,16 +488,11 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url):
def _is_regional_access_boundary_lookup_required(self):
"""Checks if a Regional Access Boundary lookup is required.

A lookup is required if the feature is enabled via an environment
variable and the universe domain is supported.
A lookup is required if the universe domain is supported.

Returns:
bool: True if a Regional Access Boundary lookup is required, False otherwise.
"""
# Check if the feature is enabled.
if not _regional_access_boundary_utils.is_regional_access_boundary_enabled():
return False

# Skip for non-default universe domains.
if self.universe_domain != DEFAULT_UNIVERSE_DOMAIN:
return False
Expand Down
3 changes: 0 additions & 3 deletions packages/google-auth/google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@
AWS_REGION = "AWS_REGION"
AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION"

GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED"
"""Environment variable controlling whether to enable trust boundary feature.
The default value is false. Users have to explicitly set this value to true."""

GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG"
"""Environment variable defining the location of Google API certificate config
Expand Down
11 changes: 11 additions & 0 deletions packages/google-auth/google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,17 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url):
HTTP requests.
url (str): The URL of the request.
"""
if (
self._should_initialize_impersonated_credentials()
and self.service_account_email
):
self._impersonated_credentials = self._initialize_impersonated_credentials()
if getattr(self, "token", None):
self._impersonated_credentials.token = self.token
if getattr(self, "expiry", None):
self._impersonated_credentials.expiry = self.expiry
self._rab_manager = self._impersonated_credentials._rab_manager

if getattr(self, "_impersonated_credentials", None):
self._impersonated_credentials._maybe_start_regional_access_boundary_refresh(
request, url
Expand Down
3 changes: 1 addition & 2 deletions packages/google-auth/google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ def from_info(cls, info, **kwargs):
Raises:
ValueError: For invalid parameters.
"""
subject_token_supplier = info.get("subject_token_supplier")
kwargs.update({"subject_token_supplier": subject_token_supplier})
kwargs.setdefault("subject_token_supplier", info.get("subject_token_supplier"))
return super(Credentials, cls).from_info(info, **kwargs)

@classmethod
Expand Down
50 changes: 35 additions & 15 deletions packages/google-auth/tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import base64
import datetime
import re
from unittest import mock

import pytest # type: ignore
Expand Down Expand Up @@ -206,6 +207,7 @@ def test_before_request_refreshes(self, get):
"access_token": "token",
"expires_in": 500,
},
"googleapis.com",
]

# Credentials should start as invalid
Expand Down Expand Up @@ -252,7 +254,11 @@ def test_with_universe_domain(self):
assert creds.universe_domain == "universe_domain"
assert creds._universe_domain_cached

def test_token_usage_metrics(self):
@mock.patch(
"google.auth.compute_engine._metadata.get_universe_domain",
return_value="googleapis.com",
)
def test_token_usage_metrics(self, mock_get_universe_domain):
self.credentials.token = "token"
self.credentials.expiry = None

Expand Down Expand Up @@ -410,11 +416,7 @@ def test_build_regional_access_boundary_lookup_url_no_email(
url = creds._build_regional_access_boundary_lookup_url()
assert url is None

@mock.patch(
"google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled",
return_value=True,
)
def test_is_regional_access_boundary_lookup_required(self, mock_enabled):
def test_is_regional_access_boundary_lookup_required(self):
creds = self.credentials
creds._universe_domain_cached = True

Expand Down Expand Up @@ -442,15 +444,11 @@ def test_build_regional_access_boundary_lookup_url_with_invalid_email(self):
url = creds._build_regional_access_boundary_lookup_url()
assert url is None

@mock.patch(
"google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled",
return_value=True,
)
@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
def test_regional_access_boundary_disabled_state_transitions(
self, mock_get_service_account_info, mock_enabled
self, mock_get_service_account_info
):
mock_get_service_account_info.return_value = {
"email": "spiffe://trust-domain/ns/ns/sa/sa",
Expand Down Expand Up @@ -769,6 +767,15 @@ def test_with_target_audience_integration(self):
json={},
)

# mock allowedLocations for Regional Access Boundary
responses.add(
responses.GET,
re.compile(r".*/allowedLocations$"),
status=200,
content_type="application/json",
json={"encodedLocations": "0xABC"},
)

# mock token for credentials
responses.add(
responses.GET,
Expand All @@ -787,8 +794,10 @@ def test_with_target_audience_integration(self):
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob",
re.compile(
r"https://iamcredentials\.(mtls\.)?googleapis\.com/v1/projects/-/"
r"serviceAccounts/service-account@example\.com:signBlob"
),
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down Expand Up @@ -951,12 +960,23 @@ def test_with_quota_project_integration(self):
json={},
)

# mock allowedLocations for Regional Access Boundary
responses.add(
responses.GET,
re.compile(r".*/allowedLocations$"),
status=200,
content_type="application/json",
json={"encodedLocations": "0xABC"},
)

# mock sign blob endpoint
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob",
re.compile(
r"https://iamcredentials\.(mtls\.)?googleapis\.com/v1/projects/-/"
r"serviceAccounts/service-account@example\.com:signBlob"
),
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down
Loading
Loading