From 94d8949c7751e578529f406b5e44c576e00c4822 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 25 May 2026 12:52:33 -0600 Subject: [PATCH 1/2] First attempt at migrating in the latest changes --- bluesky_httpserver/_authentication.py | 139 +++++++++++++----- bluesky_httpserver/authentication/__init__.py | 2 + bluesky_httpserver/authenticators.py | 135 ++++++++++++++++- bluesky_httpserver/database/core.py | 9 ++ bluesky_httpserver/routers/core_api.py | 15 +- bluesky_httpserver/schemas.py | 1 + .../tests/test_auth_for_websockets.py | 24 ++- .../tests/test_authenticators.py | 84 +++++++++-- .../tests/test_oidc_authenticators.py | 63 +++++++- 9 files changed, 409 insertions(+), 63 deletions(-) diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index c1144f5..992c5fb 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -4,7 +4,7 @@ import uuid as uuid_module import warnings from datetime import datetime, timedelta -from typing import Optional +from typing import Any, Optional from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn @@ -36,6 +36,7 @@ from .database import orm from .database.core import ( create_user, + get_or_create_principal, latest_principal_activity, lookup_valid_api_key, lookup_valid_pending_session_by_device_code, @@ -140,28 +141,53 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt -def decode_token(token, secret_keys): - credentials_exception = HTTPException( - status_code=401, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) +def _decode_token_with_secret_keys(token, secret_keys): # The first key in settings.secret_keys is used for *encoding*. # All keys are tried for *decoding* until one works or they all - # fail. They supports key rotation. + # fail. They support key rotation. for secret_key in secret_keys: try: payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) - break + return payload except ExpiredSignatureError: # Do not let this be caught below with the other JWTError types. raise except JWTError: # Try the next key in the key rotation. continue - else: - raise credentials_exception - return payload + return None + + +def decode_token(token, secret_keys, proxied_authenticator=None): + credentials_exception = HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + payload = _decode_token_with_secret_keys(token, secret_keys) + if payload is not None: + return payload + if proxied_authenticator is not None: + return proxied_authenticator.decode_token(token) + raise credentials_exception + + +def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]: + if "scp" in decoded_access_token: + scp = decoded_access_token["scp"] + return set(scp) if isinstance(scp, list) else set(scp.split(" ")) + if "scope" in decoded_access_token: + return set(decoded_access_token["scope"].split(" ")) + return set() + + +def _get_proxied_authenticator(authenticators): + if not authenticators: + return None + for authenticator in authenticators.values(): + if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"): + return authenticator + return None async def get_api_key( @@ -280,27 +306,49 @@ def get_current_principal( request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key}) elif access_token is not None: try: - payload = decode_token(access_token, settings.secret_keys) + payload = decode_token( + access_token, + settings.secret_keys, + _get_proxied_authenticator(authenticators), + ) except ExpiredSignatureError: raise HTTPException( status_code=401, detail="Access token has expired. Refresh token.", headers=headers_for_401, ) - principal = schemas.Principal( - uuid=uuid_module.UUID(hex=payload["sub"]), - type=payload["sub_typ"], - identities=[ - schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] - ], - ) - - # scopes = payload["scp"] - - # Combine scopes for all identities (it is expected to be only one identity). - ids = [_["id"] for _ in payload["ids"] if _["idp"] in settings.authentication_provider_names] - scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) + token_scopes = _extract_scopes(payload) + if "sub_typ" in payload and "ids" in payload: + principal = schemas.Principal( + uuid=uuid_module.UUID(hex=payload["sub"]), + type=payload["sub_typ"], + identities=[ + schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"] + ], + ) + ids = [ + _["id"] + for _ in payload["ids"] + if (_["idp"] in settings.authentication_provider_names) + and api_access_manager.is_user_known(_["id"]) + ] + else: + identity_id = payload.get("user") or payload["sub"] + provider = ( + settings.authentication_provider_names[0] + if settings.authentication_provider_names + else _DEFAULT_ANONYMOUS_PROVIDER_NAME + ) + with get_sessionmaker(settings.database_settings)() as db: + principal_orm = get_or_create_principal(db, provider, identity_id) + principal = schemas.Principal( + uuid=principal_orm.uuid, + type="user", + identities=[schemas.Identity(id=identity_id, provider=provider)], + ) + ids = [identity_id] if api_access_manager.is_user_known(identity_id) else [] + scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) if ids else set(token_scopes) roles_sets = [api_access_manager.get_user_roles(_) for _ in ids] roles = set.union(*roles_sets) if roles_sets else set() @@ -361,7 +409,7 @@ def get_current_principal( return principal -def get_current_principal_websocket( +async def get_current_principal_websocket( websocket: WebSocket, scopes: str, ): @@ -373,13 +421,31 @@ def get_current_principal_websocket( auth_header = websocket.headers.get("Authorization", "") access_token, api_key = None, None - # Currently we do not support authentication with tokens - # if auth_header.startswith("Bearer "): - # access_token = auth_header[len("Bearer") :].strip() - if auth_header.startswith("ApiKey "): - api_key = auth_header[len("ApiKey") :].strip() + scheme, param = get_authorization_scheme_param(auth_header) + if scheme.lower() == "bearer": + access_token = param + elif scheme.lower() == "apikey": + api_key = param + + if access_token is None: + access_token = websocket.query_params.get("access_token") + if api_key is None: + api_key = websocket.query_params.get("api_key") principal = None + websocket.state.already_accepted = False + no_credentials = (access_token is None) and (api_key is None) + if no_credentials and not settings.allow_anonymous_access: + try: + await websocket.accept() + websocket.state.already_accepted = True + message = await asyncio.wait_for(websocket.receive_json(), timeout=1) + if isinstance(message, dict) and message.get("type") == "auth": + access_token = message.get("access_token") + api_key = message.get("api_key") + except Exception: + return None + try: principal = get_current_principal( request=websocket, @@ -558,11 +624,14 @@ async def authorize_redirect( """Redirect browser to OAuth provider for authentication.""" redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + requested_scopes = {"openid", "offline_access"} + requested_scopes.update(getattr(authenticator, "extra_scopes", [])) params = { "client_id": authenticator.client_id, "response_type": "code", - "scope": "openid profile email", + "scope": " ".join(sorted(requested_scopes)), "redirect_uri": redirect_uri, + "prompt": "login", } if state: params["state"] = state @@ -595,7 +664,9 @@ async def device_code_authorize( params={ "client_id": authenticator.client_id, "response_type": "code", - "scope": "openid profile email", + "scope": " ".join( + sorted({"openid", "offline_access", *getattr(authenticator, "extra_scopes", [])}) + ), "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", "state": pending_session["user_code"].replace("-", ""), } diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index 85d835e..3475cd1 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,4 +1,5 @@ from .._authentication import ( + _extract_scopes, base_authentication_router, build_auth_code_route, build_authorize_route, @@ -21,6 +22,7 @@ "ExternalAuthenticator", "InternalAuthenticator", "UserSessionState", + "_extract_scopes", "get_current_principal", "get_current_principal_websocket", "base_authentication_router", diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index a58fedf..dfb4466 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -4,9 +4,10 @@ import logging import re import secrets +import uuid from collections.abc import Iterable from datetime import timedelta -from typing import Any, List, Mapping, Optional, cast +from typing import Any, Dict, List, Mapping, Optional, cast import httpx from cachetools import TTLCache, cached @@ -189,17 +190,18 @@ def device_authorization_endpoint(self) -> str: def end_session_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) - @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + @cached(TTLCache(maxsize=1, ttl=timedelta(hours=1).total_seconds())) def keys(self) -> List[str]: return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) - def decode_token(self, token: str) -> dict[str, Any]: + def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: return jwt.decode( - token, + id_token, key=self.keys(), algorithms=self.id_token_signing_alg_values_supported, audience=self._audience, issuer=self.issuer, + access_token=access_token, ) async def authenticate(self, request: Request) -> Optional[UserSessionState]: @@ -223,13 +225,14 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: logger.error("Authentication error: %r", response_body) return None id_token = response_body["id_token"] + access_token = response_body.get("access_token") # NOTE: We decode the id_token, not access_token, because: # 1. The id_token is the OIDC identity assertion meant for the client # 2. Some providers (like Microsoft Entra) return opaque access_tokens # that cannot be decoded with the JWKS keys when the resource is # a first-party Microsoft API (e.g., Graph API with User.Read scope) try: - verified_body = self.decode_token(id_token) + verified_body = self.decode_token(id_token, access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", @@ -310,18 +313,139 @@ def oauth2_schema(self) -> OAuth2: return self._oidc_bearer +class EntraAuthenticator(ProxiedOIDCAuthenticator): + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + extra_scopes: Optional[List[str]] = None, + confirmation_message: str = "", + scopes_map: Optional[Dict[str, list[str]]] = None, + client_secret: str = "", + redirect_on_success: Optional[str] = None, + ): + self.scopes_map = scopes_map if scopes_map is not None else {} + self.extra_scopes = extra_scopes or [] + super().__init__( + audience, + client_id, + well_known_uri, + device_flow_client_id, + scopes=None, + confirmation_message=confirmation_message, + ) + if client_secret: + self._client_secret = Secret(client_secret) + self.redirect_on_success = redirect_on_success + + @property + def scopes(self): + mapped = set() + for tiled_scopes in self.scopes_map.values(): + mapped.update(tiled_scopes) + return list(mapped) + + @scopes.setter + def scopes(self, value): + pass + + def decode_token(self, id_token: str, access_token: Optional[str] = None) -> dict[str, Any]: + claims = super().decode_token(id_token, access_token) + original_sub = claims.get("sub") + issuer = claims.get("iss", "") + claims["sub"] = uuid.uuid5(uuid.NAMESPACE_URL, f"{issuer}|{original_sub}").hex + claims["entra_sub"] = original_sub + + claims["entra_username"] = ( + claims.get("nameID") or claims.get("preferred_username") or claims.get("upn") or claims.get("email") + ) + + if user := claims.get("entra_username"): + user = user.strip() + if "\\" in user: + user = user.rsplit("\\", 1)[-1] + elif "@" in user: + user = user.split("@", 1)[0] + else: + user = original_sub + logger.warning( + "EntraAuthenticator: no human-readable username claim found in token " + "(checked nameID, preferred_username, upn, email). Falling back to Entra sub=%r.", + original_sub, + ) + claims["user"] = user + + scp_raw = claims.get("scp", "") + tiled_scope_set = set() + if scp_raw: + for scope in scp_raw.split(" "): + mapped_scopes = self.scopes_map.get(scope) + if mapped_scopes is None: + logger.warning("Unmapped Entra scope in 'scp': %s", scope) + continue + tiled_scope_set.update(mapped_scopes) + else: + for mapped_scopes in self.scopes_map.values(): + tiled_scope_set.update(mapped_scopes) + claims["scope"] = " ".join(tiled_scope_set) + return claims + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning("Authentication failed: No authorization code parameter provided.") + return None + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + extra_scopes=self.extra_scopes, + ) + response_body = response.json() + if response.is_error: + logger.error("Authentication error: %r", response_body) + return None + id_token = response_body["id_token"] + access_token = response_body.get("access_token") + refresh_token = response_body.get("refresh_token") + try: + verified_body = self.decode_token(id_token, access_token) + except JWTError: + logger.exception( + "Authentication error. Unverified token: %r", + jwt.get_unverified_claims(id_token), + ) + return None + username = verified_body.get("user") or verified_body["sub"] + state: dict[str, Any] = {} + if access_token: + state["entra_access_token"] = access_token + if refresh_token: + state["entra_refresh_token"] = refresh_token + return UserSessionState(username, state) + + async def exchange_code( token_uri: str, auth_code: str, client_id: str, client_secret: str, redirect_uri: str, + extra_scopes: Optional[List[str]] = None, ) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ + scopes = {"openid", "offline_access"} + if extra_scopes: + scopes.update(extra_scopes) auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, @@ -331,6 +455,7 @@ async def exchange_code( "redirect_uri": redirect_uri, "code": auth_code, "client_secret": client_secret, + "scope": " ".join(sorted(scopes)), }, headers={"Authorization": f"Basic {auth_value}"}, ) diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 52d102f..a394fdd 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -209,6 +209,15 @@ def create_user(db, identity_provider, id): return principal +def get_or_create_principal(db, identity_provider, id): + identity = db.query(Identity).filter(Identity.id == id).filter(Identity.provider == identity_provider).first() + if identity is None: + principal = create_user(db, identity_provider, id) + else: + principal = identity.principal + return principal + + def lookup_valid_session(db, session_id): if isinstance(session_id, int): # Old versions of tiled used an integer sid. diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 397972b..b70ca46 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -1140,12 +1140,13 @@ def is_alive(self): @router.websocket("/console_output/ws") async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.console_output_stream.add_queue(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1166,12 +1167,13 @@ async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): @router.websocket("/status/ws") async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_status(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() @@ -1193,12 +1195,13 @@ async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): @router.websocket("/info/ws") async def info_ws(websocket: WebSocket, scopes=["read:monitor"]): - principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + principal = await get_current_principal_websocket(websocket=websocket, scopes=scopes) if not principal: await websocket.close(code=4001, reason="Invalid token") return - await websocket.accept() + if not getattr(websocket.state, "already_accepted", False): + await websocket.accept() q = SR.system_info_stream.add_queue_info(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index f1d9fcb..05a1764 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -190,6 +190,7 @@ class AboutAuthenticationProvider(pydantic.BaseModel): mode: AuthenticationMode links: Dict[str, str] confirmation_message: Optional[str] = None + extra_scopes: Optional[List[str]] = None class AboutAuthenticationLinks(pydantic.BaseModel): diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index 3d26e22..17fb498 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -50,12 +50,13 @@ class _ReceiveSystemInfoSocket(threading.Thread): save messages to the buffer. """ - def __init__(self, *, endpoint, api_key=None, token=None, **kwargs): + def __init__(self, *, endpoint, api_key=None, token=None, auth_message=None, **kwargs): super().__init__(**kwargs) self.received_data_buffer = [] self._exit = False self._api_key = api_key self._token = token + self._auth_message = auth_message self._endpoint = endpoint def run(self): @@ -69,6 +70,8 @@ def run(self): try: with connect(websocket_uri, additional_headers=additional_headers) as websocket: + if self._auth_message is not None: + websocket.send(json.dumps(self._auth_message)) while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) @@ -94,7 +97,10 @@ def __del__(self): # fmt: off -@pytest.mark.parametrize("ws_auth_type", ["apikey", "apikey_invalid", "none"]) +@pytest.mark.parametrize( + "ws_auth_type", + ["apikey", "apikey_invalid", "token", "token_invalid", "none", "first_message_apikey", "first_message_token"], +) # fmt: on def test_websocket_auth_01( tmpdir, @@ -135,10 +141,14 @@ def test_websocket_auth_01( ws_params = {"api_key": api_key} elif ws_auth_type == "apikey_invalid": ws_params = {"api_key": "InvalidApiKey"} - # elif ws_auth_type == "token": - # ws_params = {"token": token} - # elif ws_auth_type == "token_invalid": - # ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "token": + ws_params = {"token": token} + elif ws_auth_type == "token_invalid": + ws_params = {"token": "InvalidToken"} + elif ws_auth_type == "first_message_apikey": + ws_params = {"auth_message": {"type": "auth", "api_key": api_key}} + elif ws_auth_type == "first_message_token": + ws_params = {"auth_message": {"type": "auth", "access_token": token}} else: assert False, f"Unknown authentication type: {ws_auth_type!r}" @@ -164,7 +174,7 @@ def test_websocket_auth_01( buffer = rsc.received_data_buffer if ws_auth_type in ("none", "apikey_invalid", "token_invalid"): assert len(buffer) == 0 - elif ws_auth_type in ("apikey", "token"): + elif ws_auth_type in ("apikey", "token", "first_message_apikey", "first_message_token"): assert len(buffer) > 0 for msg in buffer: assert "time" in msg, msg diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 7b7dd4b..7397550 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -10,8 +10,16 @@ from jose.backends import RSAKey from respx import MockRouter from starlette.datastructures import URL, QueryParams +from starlette.requests import Request -from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState +from .._authentication import build_authorize_route +from ..authenticators import ( + LDAPAuthenticator, + OIDCAuthenticator, + ProxiedOIDCAuthenticator, + UserSessionState, + exchange_code, +) LDAP_TEST_HOST = os.environ.get("QSERVER_TEST_LDAP_HOST", "localhost") LDAP_TEST_PORT = int(os.environ.get("QSERVER_TEST_LDAP_PORT", "1389")) @@ -233,22 +241,21 @@ async def test_OIDCAuthenticator_mock( mock_request = create_mock_oidc_request({"code": "test-auth-code"}) - def mock_jwt_decode(*args, **kwargs): - return mock_jwt_payload - - def mock_jwk_construct(*args, **kwargs): - class MockJWK: - pass + decode_calls = {} - return MockJWK() + def mock_decode_token(id_token, access_token=None): + decode_calls["id_token"] = id_token + decode_calls["access_token"] = access_token + return mock_jwt_payload - monkeypatch.setattr("jose.jwt.decode", mock_jwt_decode) - monkeypatch.setattr("jose.jwk.construct", mock_jwk_construct) + monkeypatch.setattr(authenticator, "decode_token", mock_decode_token) user_session = await authenticator.authenticate(mock_request) assert user_session is not None assert user_session.user_name == "0009-0008-8698-7745" + assert decode_calls["id_token"] == "mock-id-token" + assert decode_calls["access_token"] == "mock-access-token" @pytest.mark.asyncio @@ -293,3 +300,60 @@ async def test_OIDCAuthenticator_token_exchange_failure( result = await authenticator.authenticate(mock_request) assert result is None + + +@pytest.mark.asyncio +async def test_exchange_code_requests_offline_access(monkeypatch): + captured = {} + + def mock_post(*, url, data, headers): + captured["url"] = url + captured["data"] = data + captured["headers"] = headers + return httpx.Response(200, json={"id_token": "X", "access_token": "Y"}) + + monkeypatch.setattr("httpx.post", mock_post) + + await exchange_code( + token_uri="https://idp.example/token", + auth_code="authcode", + client_id="client-id", + client_secret="client-secret", + redirect_uri="https://server.example/callback", + extra_scopes=["api://example/access_as_user"], + ) + + assert captured["url"] == "https://idp.example/token" + assert set(captured["data"]["scope"].split(" ")) == { + "openid", + "offline_access", + "api://example/access_as_user", + } + + +@pytest.mark.asyncio +async def test_authorize_route_requests_extra_scopes_and_prompt(): + class _Authenticator: + client_id = "test-client" + extra_scopes = ["api://example/access_as_user"] + authorization_endpoint = httpx.URL("https://idp.example/auth") + + route = build_authorize_route(_Authenticator(), "oidc") + request = Request( + { + "type": "http", + "scheme": "http", + "path": "/api/auth/provider/oidc/authorize", + "root_path": "", + "query_string": b"", + "headers": [(b"host", b"localhost:8000")], + "server": ("localhost", 8000), + "client": ("127.0.0.1", 54321), + } + ) + response = await route(request) + location = response.headers["location"] + assert "prompt=login" in location + assert "offline_access" in location + assert "openid" in location + assert "api%3A%2F%2Fexample%2Faccess_as_user" in location diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py index f3249cd..cedf611 100644 --- a/bluesky_httpserver/tests/test_oidc_authenticators.py +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -10,7 +10,7 @@ from jose.backends import RSAKey from respx import MockRouter -from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator +from bluesky_httpserver.authenticators import EntraAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator @pytest.fixture @@ -217,3 +217,64 @@ def test_proxied_oidc_with_scopes( assert authenticator.scopes == ["openid", "profile", "email"] assert authenticator.device_flow_client_id == "test_cli_client" + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestEntraAuthenticator: + def test_entra_scope_mapping_and_username( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + private_key, _ = keys + authenticator = EntraAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes_map={"User.Read": ["read:monitor"]}, + ) + token_claims = { + "aud": "test_client", + "exp": time.time() + 1500, + "iat": time.time() - 1, + "iss": "https://example.com/realms/example", + "sub": "entra-subject", + "preferred_username": "alice@example.org", + "scp": "User.Read", + } + encoded = encrypt_token(token_claims, private_key) + decoded = authenticator.decode_token(encoded) + assert decoded["user"] == "alice" + assert set(decoded["scope"].split(" ")) == {"read:monitor"} + assert decoded["entra_sub"] == "entra-subject" + + def test_entra_unmapped_scope_warning( + self, + caplog, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + private_key, _ = keys + authenticator = EntraAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes_map={"Known.Scope": ["read:monitor"]}, + ) + token_claims = { + "aud": "test_client", + "exp": time.time() + 1500, + "iat": time.time() - 1, + "iss": "https://example.com/realms/example", + "sub": "entra-subject", + "scp": "Unknown.Scope", + } + encoded = encrypt_token(token_claims, private_key) + with caplog.at_level("WARNING"): + decoded = authenticator.decode_token(encoded) + assert decoded["scope"] == "" + assert any("Unmapped Entra scope" in record.message for record in caplog.records) From 6d2576f496c231996f2bf9aea771fe4d51f0b599 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 25 May 2026 13:26:17 -0600 Subject: [PATCH 2/2] Adding aditional test cases --- .../tests/test_auth_websocket_helpers.py | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 bluesky_httpserver/tests/test_auth_websocket_helpers.py diff --git a/bluesky_httpserver/tests/test_auth_websocket_helpers.py b/bluesky_httpserver/tests/test_auth_websocket_helpers.py new file mode 100644 index 0000000..0183e06 --- /dev/null +++ b/bluesky_httpserver/tests/test_auth_websocket_helpers.py @@ -0,0 +1,129 @@ +from types import SimpleNamespace + +import pytest +from fastapi import HTTPException + +from bluesky_httpserver import _authentication as auth + + +class _FakeWebSocket: + def __init__(self, *, headers=None, query_params=None, app=None, first_message=None, receive_error=None): + self.headers = headers or {} + self.query_params = query_params or {} + self.app = app + self.state = SimpleNamespace() + self._first_message = first_message + self._receive_error = receive_error + self.accepted = False + + async def accept(self): + self.accepted = True + + async def receive_json(self): + if self._receive_error is not None: + raise self._receive_error + return self._first_message + + +def _make_app(*, allow_anonymous_access): + settings = SimpleNamespace(allow_anonymous_access=allow_anonymous_access) + return SimpleNamespace( + dependency_overrides={ + auth.get_settings: lambda: settings, + auth.get_authenticators: lambda: {}, + auth.get_api_access_manager: lambda: object(), + } + ) + + +@pytest.mark.asyncio +async def test_websocket_first_message_auth_uses_api_key(monkeypatch): + app = _make_app(allow_anonymous_access=False) + websocket = _FakeWebSocket(app=app, first_message={"type": "auth", "api_key": "secret-key"}) + + expected_principal = object() + + def fake_get_current_principal(**kwargs): + assert kwargs["access_token"] is None + assert kwargs["api_key"] == "secret-key" + return expected_principal + + monkeypatch.setattr(auth, "get_current_principal", fake_get_current_principal) + + principal = await auth.get_current_principal_websocket(websocket=websocket, scopes=["read:monitor"]) + + assert principal is expected_principal + assert websocket.accepted is True + assert websocket.state.already_accepted is True + + +@pytest.mark.asyncio +async def test_websocket_first_message_non_auth_is_rejected(monkeypatch): + app = _make_app(allow_anonymous_access=False) + websocket = _FakeWebSocket(app=app, first_message={"type": "subscribe", "path": "foo"}) + + def fake_get_current_principal(**kwargs): + assert kwargs["access_token"] is None + assert kwargs["api_key"] is None + raise HTTPException(status_code=401, detail="Invalid credentials") + + monkeypatch.setattr(auth, "get_current_principal", fake_get_current_principal) + + principal = await auth.get_current_principal_websocket(websocket=websocket, scopes=["read:monitor"]) + + assert principal is None + assert websocket.accepted is True + assert websocket.state.already_accepted is True + + +@pytest.mark.asyncio +async def test_websocket_query_param_access_token_forwarded(monkeypatch): + app = _make_app(allow_anonymous_access=True) + websocket = _FakeWebSocket(app=app, query_params={"access_token": "query-token"}) + + def fake_get_current_principal(**kwargs): + return kwargs["access_token"] + + monkeypatch.setattr(auth, "get_current_principal", fake_get_current_principal) + + principal = await auth.get_current_principal_websocket(websocket=websocket, scopes=["read:monitor"]) + + assert principal == "query-token" + assert websocket.accepted is False + assert websocket.state.already_accepted is False + + +@pytest.mark.asyncio +async def test_websocket_header_token_takes_precedence_over_query(monkeypatch): + app = _make_app(allow_anonymous_access=True) + websocket = _FakeWebSocket( + app=app, + headers={"Authorization": "Bearer header-token"}, + query_params={"access_token": "query-token"}, + ) + + def fake_get_current_principal(**kwargs): + return kwargs["access_token"] + + monkeypatch.setattr(auth, "get_current_principal", fake_get_current_principal) + + principal = await auth.get_current_principal_websocket(websocket=websocket, scopes=["read:monitor"]) + + assert principal == "header-token" + + +@pytest.mark.asyncio +async def test_websocket_first_message_receive_error_returns_none(monkeypatch): + app = _make_app(allow_anonymous_access=False) + websocket = _FakeWebSocket(app=app, receive_error=TimeoutError("timed out")) + + def fake_get_current_principal(**kwargs): + raise AssertionError("get_current_principal should not be called") + + monkeypatch.setattr(auth, "get_current_principal", fake_get_current_principal) + + principal = await auth.get_current_principal_websocket(websocket=websocket, scopes=["read:monitor"]) + + assert principal is None + assert websocket.accepted is True + assert websocket.state.already_accepted is True