Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 105 additions & 34 deletions bluesky_httpserver/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid as uuid_module
import warnings
from datetime import datetime, timedelta
from typing import Optional
from typing import Any, Optional

from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket
from fastapi.openapi.models import APIKey, APIKeyIn
Expand Down Expand Up @@ -36,6 +36,7 @@
from .database import orm
from .database.core import (
create_user,
get_or_create_principal,
latest_principal_activity,
lookup_valid_api_key,
lookup_valid_pending_session_by_device_code,
Expand Down Expand Up @@ -140,28 +141,53 @@ def create_refresh_token(session_id, secret_key, expires_delta):
return encoded_jwt


def decode_token(token, secret_keys):
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
def _decode_token_with_secret_keys(token, secret_keys):
# The first key in settings.secret_keys is used for *encoding*.
# All keys are tried for *decoding* until one works or they all
# fail. They supports key rotation.
# fail. They support key rotation.
for secret_key in secret_keys:
try:
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
break
return payload
except ExpiredSignatureError:
# Do not let this be caught below with the other JWTError types.
raise
except JWTError:
# Try the next key in the key rotation.
continue
else:
raise credentials_exception
return payload
return None


def decode_token(token, secret_keys, proxied_authenticator=None):
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
payload = _decode_token_with_secret_keys(token, secret_keys)
if payload is not None:
return payload
if proxied_authenticator is not None:
return proxied_authenticator.decode_token(token)
raise credentials_exception


def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]:
if "scp" in decoded_access_token:
scp = decoded_access_token["scp"]
return set(scp) if isinstance(scp, list) else set(scp.split(" "))
if "scope" in decoded_access_token:
return set(decoded_access_token["scope"].split(" "))
return set()


def _get_proxied_authenticator(authenticators):
if not authenticators:
return None
for authenticator in authenticators.values():
if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"):
return authenticator
return None


async def get_api_key(
Expand Down Expand Up @@ -280,27 +306,49 @@ def get_current_principal(
request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key})
elif access_token is not None:
try:
payload = decode_token(access_token, settings.secret_keys)
payload = decode_token(
access_token,
settings.secret_keys,
_get_proxied_authenticator(authenticators),
)
except ExpiredSignatureError:
raise HTTPException(
status_code=401,
detail="Access token has expired. Refresh token.",
headers=headers_for_401,
)
principal = schemas.Principal(
uuid=uuid_module.UUID(hex=payload["sub"]),
type=payload["sub_typ"],
identities=[
schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"]
],
)

# scopes = payload["scp"]

# Combine scopes for all identities (it is expected to be only one identity).
ids = [_["id"] for _ in payload["ids"] if _["idp"] in settings.authentication_provider_names]
scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids])
token_scopes = _extract_scopes(payload)
if "sub_typ" in payload and "ids" in payload:
principal = schemas.Principal(
uuid=uuid_module.UUID(hex=payload["sub"]),
type=payload["sub_typ"],
identities=[
schemas.Identity(id=identity["id"], provider=identity["idp"]) for identity in payload["ids"]
],
)
ids = [
_["id"]
for _ in payload["ids"]
if (_["idp"] in settings.authentication_provider_names)
and api_access_manager.is_user_known(_["id"])
]
else:
identity_id = payload.get("user") or payload["sub"]
provider = (
settings.authentication_provider_names[0]
if settings.authentication_provider_names
else _DEFAULT_ANONYMOUS_PROVIDER_NAME
)
with get_sessionmaker(settings.database_settings)() as db:
principal_orm = get_or_create_principal(db, provider, identity_id)
principal = schemas.Principal(
uuid=principal_orm.uuid,
type="user",
identities=[schemas.Identity(id=identity_id, provider=provider)],
)
ids = [identity_id] if api_access_manager.is_user_known(identity_id) else []

scopes = set.union(*[api_access_manager.get_user_scopes(_) for _ in ids]) if ids else set(token_scopes)
roles_sets = [api_access_manager.get_user_roles(_) for _ in ids]
roles = set.union(*roles_sets) if roles_sets else set()

Expand Down Expand Up @@ -361,7 +409,7 @@ def get_current_principal(
return principal


def get_current_principal_websocket(
async def get_current_principal_websocket(
websocket: WebSocket,
scopes: str,
):
Expand All @@ -373,13 +421,31 @@ def get_current_principal_websocket(

auth_header = websocket.headers.get("Authorization", "")
access_token, api_key = None, None
# Currently we do not support authentication with tokens
# if auth_header.startswith("Bearer "):
# access_token = auth_header[len("Bearer") :].strip()
if auth_header.startswith("ApiKey "):
api_key = auth_header[len("ApiKey") :].strip()
scheme, param = get_authorization_scheme_param(auth_header)
if scheme.lower() == "bearer":
access_token = param
elif scheme.lower() == "apikey":
api_key = param

if access_token is None:
access_token = websocket.query_params.get("access_token")
if api_key is None:
api_key = websocket.query_params.get("api_key")

principal = None
websocket.state.already_accepted = False
no_credentials = (access_token is None) and (api_key is None)
if no_credentials and not settings.allow_anonymous_access:
try:
await websocket.accept()
websocket.state.already_accepted = True
message = await asyncio.wait_for(websocket.receive_json(), timeout=1)
if isinstance(message, dict) and message.get("type") == "auth":
access_token = message.get("access_token")
api_key = message.get("api_key")
except Exception:
return None

try:
principal = get_current_principal(
request=websocket,
Expand Down Expand Up @@ -558,11 +624,14 @@ async def authorize_redirect(
"""Redirect browser to OAuth provider for authentication."""
redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code"

requested_scopes = {"openid", "offline_access"}
requested_scopes.update(getattr(authenticator, "extra_scopes", []))
params = {
"client_id": authenticator.client_id,
"response_type": "code",
"scope": "openid profile email",
"scope": " ".join(sorted(requested_scopes)),
"redirect_uri": redirect_uri,
"prompt": "login",
}
if state:
params["state"] = state
Expand Down Expand Up @@ -595,7 +664,9 @@ async def device_code_authorize(
params={
"client_id": authenticator.client_id,
"response_type": "code",
"scope": "openid profile email",
"scope": " ".join(
sorted({"openid", "offline_access", *getattr(authenticator, "extra_scopes", [])})
),
"redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code",
"state": pending_session["user_code"].replace("-", ""),
}
Expand Down
2 changes: 2 additions & 0 deletions bluesky_httpserver/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .._authentication import (
_extract_scopes,
base_authentication_router,
build_auth_code_route,
build_authorize_route,
Expand All @@ -21,6 +22,7 @@
"ExternalAuthenticator",
"InternalAuthenticator",
"UserSessionState",
"_extract_scopes",
"get_current_principal",
"get_current_principal_websocket",
"base_authentication_router",
Expand Down
Loading
Loading