Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/google-auth/google/auth/aio/transport/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ async def _do_configure():
self._auth_request = AiohttpRequest(session=new_session)

await old_auth_request.close()
else:
self._is_mtls = False

except (
exceptions.ClientCertError,
Expand Down
7 changes: 2 additions & 5 deletions packages/google-auth/google/auth/transport/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from google.auth import exceptions
from google.auth.transport import _mtls_helper
from google.auth.transport import mtls
from google.oauth2 import service_account

try:
Expand Down Expand Up @@ -295,11 +296,7 @@ def __init__(self):
if not use_client_cert:
self._is_mtls = False
else:
# Load client SSL credentials.
metadata_path = _mtls_helper._check_config_path(
_mtls_helper.CONTEXT_AWARE_METADATA_PATH
)
self._is_mtls = metadata_path is not None
self._is_mtls = mtls.has_default_client_cert_source()

@property
def ssl_credentials(self):
Expand Down
1 change: 1 addition & 0 deletions packages/google-auth/google/auth/transport/urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def configure_mtls_channel(self, client_cert_callback=None):
self._cached_cert = cert
else:
self.http = _make_default_http()
self._is_mtls = False
except (
exceptions.ClientCertError,
ImportError,
Expand Down
37 changes: 37 additions & 0 deletions packages/google-auth/tests/transport/aio/test_sessions_mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from google.auth import exceptions
from google.auth.aio import credentials
from google.auth.aio import transport
from google.auth.aio.transport import sessions

# This is the valid "workload" format the library expects
Expand Down Expand Up @@ -140,3 +141,39 @@ def mock_callback():
await session.configure_mtls_channel(client_cert_callback=mock_callback)

assert session._is_mtls is True

@pytest.mark.asyncio
async def test_configure_mtls_channel_custom_request(self):
"""
Tests that if _auth_request is not an AiohttpRequest, _is_mtls is set to False
because we can't configure the custom request with mTLS.
"""
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}
), mock.patch("os.path.exists") as mock_exists, mock.patch(
"builtins.open", mock.mock_open(read_data=json.dumps(VALID_WORKLOAD_CONFIG))
), mock.patch(
"google.auth.aio.transport.mtls.get_client_cert_and_key"
) as mock_helper, mock.patch(
"google.auth.aio.transport.mtls.make_client_cert_ssl_context"
) as mock_make_context:
mock_exists.return_value = True
mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data")

mock_context = mock.Mock(spec=ssl.SSLContext)
mock_make_context.return_value = mock_context

mock_creds = mock.AsyncMock(spec=credentials.Credentials)
mock_auth_request = mock.AsyncMock(spec=transport.Request)
session = sessions.AsyncAuthorizedSession(
mock_creds, auth_request=mock_auth_request
)

await session.configure_mtls_channel()

# If the request handler is not an AiohttpRequest, the library cannot configure
# the connection to use mTLS, so _is_mtls must be False to reflect this unconfigured state.
assert session._is_mtls is False
mock_make_context.assert_called_once_with(
b"fake_cert_data", b"fake_key_data"
)
35 changes: 35 additions & 0 deletions packages/google-auth/tests/transport/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,41 @@ def test_get_client_ssl_credentials_success(
certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
)

@mock.patch(
"google.auth.transport.mtls.has_default_client_cert_source", autospec=True
)
def test_get_client_ssl_credentials_workload_cert(
self,
mock_has_default_client_cert_source,
mock_check_config_path,
mock_load_json_file,
mock_get_client_ssl_credentials,
mock_ssl_channel_credentials,
):
# Mock that context-aware metadata does not exist, but workload cert config does.
mock_check_config_path.return_value = None
mock_has_default_client_cert_source.return_value = True
mock_get_client_ssl_credentials.return_value = (
True,
PUBLIC_CERT_BYTES,
PRIVATE_KEY_BYTES,
None,
)

with mock.patch.dict(
os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
):
ssl_credentials = google.auth.transport.grpc.SslCredentials()

# If a workload certificate config exists on the device (and use_client_cert is true),
# is_mtls must be True and get_client_ssl_credentials should be invoked.
assert ssl_credentials.ssl_credentials is not None
assert ssl_credentials.is_mtls
Comment on lines +495 to +500

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ssl_credentials = google.auth.transport.grpc.SslCredentials()
# If a workload certificate config exists on the device (and use_client_cert is true),
# is_mtls must be True and get_client_ssl_credentials should be invoked.
assert ssl_credentials.ssl_credentials is not None
assert ssl_credentials.is_mtls
ssl_credentials = google.auth.transport.grpc.SslCredentials()
# If a workload certificate config exists on the device (and use_client_cert is true),
# is_mtls must be True and get_client_ssl_credentials should be invoked.
assert ssl_credentials.ssl_credentials is not None
assert ssl_credentials.is_mtls

mock_get_client_ssl_credentials.assert_called_once()
mock_ssl_channel_credentials.assert_called_once_with(
certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
)

def test_get_client_ssl_credentials_without_client_cert_env(
self,
mock_check_config_path,
Expand Down
3 changes: 3 additions & 0 deletions packages/google-auth/tests/transport/test_urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def test_configure_mtls_channel_non_mtls(
is_mtls = authed_http.configure_mtls_channel()

assert not is_mtls
# If client certificate and key are not found, the transport falls back to
# a standard connection. _is_mtls must be False to reflect this fallback state.
assert authed_http._is_mtls is False
mock_get_client_cert_and_key.assert_called_once()
mock_make_mutual_tls_http.assert_not_called()

Expand Down
Loading