diff --git a/CHANGELOG.md b/CHANGELOG.md index febbef45..d72de1f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ All notable changes to `mcp/sdk` will be documented in this file. * Add optional `title` parameter to `Builder::addResource()` and `Builder::addResourceTemplate()` for MCP spec compliance * [BC Break] `Builder::addResource()` signature changed — `$title` parameter added between `$name` and `$description`. Callers using positional arguments must switch to named arguments. * [BC Break] `Builder::addResourceTemplate()` signature changed — `$title` parameter added between `$name` and `$description`. Callers using positional arguments must switch to named arguments. +* Add `CorsMiddleware`, `DnsRebindingProtectionMiddleware`, and `ProtocolVersionMiddleware` for `StreamableHttpTransport`, composed automatically as the default stack via `StreamableHttpTransport::defaultMiddleware()` +* **[BC BREAK]** `StreamableHttpTransport` constructor: `$corsHeaders` parameter removed; CORS is now configured via `CorsMiddleware`. The `$middleware` parameter is nullable — `null` (or omitted) installs the default stack; `[]` disables all defaults. Default `Access-Control-Allow-Origin` is no longer set (was `*`). 0.5.0 ----- diff --git a/docs/transports.md b/docs/transports.md index a68875d9..3708e350 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -110,8 +110,8 @@ $transport = new StreamableHttpTransport( - **`request`** (required): `ServerRequestInterface` - The incoming PSR-7 HTTP request - **`responseFactory`** (optional): `ResponseFactoryInterface` - PSR-17 factory for creating HTTP responses. Auto-discovered if not provided. - **`streamFactory`** (optional): `StreamFactoryInterface` - PSR-17 factory for creating response body streams. Auto-discovered if not provided. -- **`corsHeaders`** (optional): `array` - Custom CORS headers to override defaults. Merges with secure defaults. Defaults to `[]`. - **`logger`** (optional): `LoggerInterface` - PSR-3 logger for debugging. Defaults to `NullLogger`. +- **`middleware`** (optional): `iterable|null` - PSR-15 middleware chain. `null` (omitted) installs the [default stack](#default-middleware). `[]` disables all defaults — useful when the surrounding application already handles CORS, host validation, etc. ### PSR-17 Auto-Discovery @@ -137,56 +137,109 @@ $psr17Factory = new Psr17Factory(); $transport = new StreamableHttpTransport($request, $psr17Factory, $psr17Factory); ``` -### CORS Configuration +### Default Middleware + +When the `middleware` argument is omitted (or set to `null`), the transport installs a secure default stack: -The transport sets secure CORS defaults that can be customized or disabled: +| Order | Middleware | Purpose | +|-------|------------|---------| +| 1 | `CorsMiddleware` | Applies CORS headers to every response. By default does **not** set `Access-Control-Allow-Origin` (cross-origin requests are blocked). | +| 2 | `DnsRebindingProtectionMiddleware` | Validates `Origin`/`Host` against an allowlist. Defaults to localhost variants only. | +| 3 | `ProtocolVersionMiddleware` | Rejects requests carrying an unsupported `MCP-Protocol-Version` header with `400 Bad Request`. | ```php -// Default CORS headers (backward compatible) -$transport = new StreamableHttpTransport($request, $responseFactory, $streamFactory); +// Zero-config, secure-by-default — local servers get full protection automatically. +$transport = new StreamableHttpTransport($request); +``` -// Restrict to specific origin -$transport = new StreamableHttpTransport( - $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => 'https://myapp.com'] -); +The default stack can be inspected and recomposed via the public factory: + +```php +$middleware = StreamableHttpTransport::defaultMiddleware(); +``` + +### CORS Configuration + +CORS is handled by `CorsMiddleware`. To enable cross-origin browser requests, configure it explicitly and pass it +in place of (or alongside) the defaults: -// Disable CORS for proxy scenarios +```php +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; +use Mcp\Server\Transport\StreamableHttpTransport; + +// Reflect a specific origin $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => ''] + middleware: [ + new CorsMiddleware(allowedOrigins: ['https://myapp.com']), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), + ], ); -// Custom headers with logger +// Allow all origins (development only) $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [ - 'Access-Control-Allow-Origin' => 'https://api.example.com', - 'Access-Control-Max-Age' => '86400' + middleware: [ + new CorsMiddleware(allowedOrigins: ['*']), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), ], - $logger ); ``` -Default CORS headers: -- `Access-Control-Allow-Origin: *` -- `Access-Control-Allow-Methods: GET, POST, DELETE, OPTIONS` -- `Access-Control-Allow-Headers: Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept` +When the allowlist is a concrete set of origins (not `['*']`), `CorsMiddleware` automatically adds `Vary: Origin` +so shared caches/CDNs do not serve a response generated for one origin to a request from another. + +Headers already present on a response (e.g. set by inner middleware) are preserved — `CorsMiddleware` only adds +defaults when they are absent. + +> [!IMPORTANT] +> `Access-Control-Allow-Origin: *` is incompatible with credentialed browser requests (those carrying +> `Authorization`, cookies, or client certificates). If your MCP server runs OAuth/Bearer auth and serves +> a browser client, configure `allowedOrigins` with the explicit origin(s) you trust rather than `['*']`. +> The middleware reflects the matching origin verbatim, which is the form browsers accept with credentials. -### PSR-15 Middleware +### DNS Rebinding Protection -`StreamableHttpTransport` can run a PSR-15 middleware chain before it processes the request. Middleware can log, -enforce auth, or short-circuit with a response for any HTTP method. +`DnsRebindingProtectionMiddleware` validates the `Origin` header against an allowlist (falling back to `Host` +when `Origin` is absent). The default allowlist is localhost-only: + +```php +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; + +new DnsRebindingProtectionMiddleware(allowedHosts: ['myapp.local', 'mcp.internal']); +``` + +If the server is fronted by a reverse proxy that already validates `Host`, drop this middleware from the chain +or supply a permissive allowlist. + +### Protocol Version Validation + +`ProtocolVersionMiddleware` rejects requests whose `MCP-Protocol-Version` header is not in the SDK's supported +set with `400 Bad Request`. Requests without the header pass through, since the `initialize` round-trip and some +legacy clients do not send it. + +```php +use Mcp\Schema\Enum\ProtocolVersion; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; + +// Only accept the latest spec version +new ProtocolVersionMiddleware(supportedVersions: [ProtocolVersion::V2025_11_25]); +``` + +### Custom PSR-15 Middleware + +`StreamableHttpTransport` accepts any PSR-15 middleware chain. To extend the defaults, spread them and append +your own middleware — the defaults stay outermost so CORS headers are applied to every response, including +short-circuited ones: ```php use Mcp\Server\Transport\StreamableHttpTransport; use Psr\Http\Message\ResponseFactoryInterface; +use Psr\Http\Message\ResponseInterface; use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Server\MiddlewareInterface; use Psr\Http\Server\RequestHandlerInterface; @@ -197,7 +250,7 @@ final class AuthMiddleware implements MiddlewareInterface { } - public function process(ServerRequestInterface $request, RequestHandlerInterface $handler) + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface { if (!$request->hasHeader('Authorization')) { return $this->responses->createResponse(401); @@ -209,15 +262,40 @@ final class AuthMiddleware implements MiddlewareInterface $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [], - $logger, - [new AuthMiddleware($responseFactory)], + logger: $logger, + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + new AuthMiddleware($responseFactory), + ], ); ``` -If middleware returns a response, the transport will still ensure CORS headers are present unless you set them yourself. +To selectively drop one default (for example DNS rebinding when running behind a proxy), filter the default list: + +```php +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\StreamableHttpTransport; + +$transport = new StreamableHttpTransport( + $request, + middleware: [ + ...array_filter( + StreamableHttpTransport::defaultMiddleware(), + fn ($m) => !$m instanceof DnsRebindingProtectionMiddleware, + ), + new AuthMiddleware($responseFactory), + ], +); +``` + +Pass `middleware: []` to disable every default and run only your own chain: + +```php +$transport = new StreamableHttpTransport( + $request, + middleware: [new AuthMiddleware($responseFactory)], +); +``` ### Architecture diff --git a/examples/server/oauth-keycloak/server.php b/examples/server/oauth-keycloak/server.php index bdd22b90..fdaae7a0 100644 --- a/examples/server/oauth-keycloak/server.php +++ b/examples/server/oauth-keycloak/server.php @@ -58,7 +58,12 @@ $transport = new StreamableHttpTransport( (new Psr17Factory())->createServerRequestFromGlobals(), logger: logger(), - middleware: [$metadataMiddleware, $authMiddleware, new OAuthRequestMetaMiddleware()], + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + $metadataMiddleware, + $authMiddleware, + new OAuthRequestMetaMiddleware(), + ], ); $response = $server->run($transport); diff --git a/examples/server/oauth-microsoft/server.php b/examples/server/oauth-microsoft/server.php index 419817cc..c4fae598 100644 --- a/examples/server/oauth-microsoft/server.php +++ b/examples/server/oauth-microsoft/server.php @@ -81,7 +81,13 @@ $transport = new StreamableHttpTransport( (new Psr17Factory())->createServerRequestFromGlobals(), logger: logger(), - middleware: [$oauthProxyMiddleware, $metadataMiddleware, $authMiddleware, new OAuthRequestMetaMiddleware()], + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + $oauthProxyMiddleware, + $metadataMiddleware, + $authMiddleware, + new OAuthRequestMetaMiddleware(), + ], ); $response = $server->run($transport); diff --git a/src/Server/Transport/Http/JsonRpcErrorResponse.php b/src/Server/Transport/Http/JsonRpcErrorResponse.php new file mode 100644 index 00000000..592d1793 --- /dev/null +++ b/src/Server/Transport/Http/JsonRpcErrorResponse.php @@ -0,0 +1,41 @@ +createResponse($statusCode) + ->withHeader('Content-Type', 'application/json') + ->withBody($streamFactory->createStream($body)); + } +} diff --git a/src/Server/Transport/Http/Middleware/CorsMiddleware.php b/src/Server/Transport/Http/Middleware/CorsMiddleware.php new file mode 100644 index 00000000..cc11f539 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/CorsMiddleware.php @@ -0,0 +1,152 @@ + + */ +final class CorsMiddleware implements MiddlewareInterface +{ + private readonly bool $isWildcard; + private readonly bool $varyOnOrigin; + private readonly string $allowedMethodsHeader; + private readonly string $allowedHeadersHeader; + private readonly ?string $exposedHeadersHeader; + + /** + * @param list $allowedOrigins Origins permitted for cross-origin requests. Empty disables `Access-Control-Allow-Origin`. Use `['*']` to allow any origin. + * @param list $allowedMethods Methods advertised via `Access-Control-Allow-Methods` (preflight only) + * @param list $allowedHeaders Headers advertised via `Access-Control-Allow-Headers` (preflight only) + * @param list $exposedHeaders Headers advertised via `Access-Control-Expose-Headers` + * @param bool $allowCredentials Whether to emit `Access-Control-Allow-Credentials: true`. Incompatible with `allowedOrigins: ['*']` — combining them throws. + */ + public function __construct( + private readonly array $allowedOrigins = [], + array $allowedMethods = ['GET', 'POST', 'DELETE', 'OPTIONS'], + array $allowedHeaders = [ + 'Accept', + 'Authorization', + 'Content-Type', + 'Last-Event-ID', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + StreamableHttpTransport::SESSION_HEADER, + ], + array $exposedHeaders = [StreamableHttpTransport::SESSION_HEADER], + private readonly bool $allowCredentials = false, + ) { + $this->isWildcard = \in_array('*', $allowedOrigins, true); + + if ($this->isWildcard && $allowCredentials) { + throw new InvalidArgumentException('Access-Control-Allow-Origin: * is incompatible with Access-Control-Allow-Credentials: true. Configure an explicit allowedOrigins list when credentialed requests are required.'); + } + + $this->varyOnOrigin = [] !== $allowedOrigins && !$this->isWildcard; + $this->allowedMethodsHeader = implode(', ', $allowedMethods); + $this->allowedHeadersHeader = implode(', ', $allowedHeaders); + $this->exposedHeadersHeader = [] === $exposedHeaders ? null : implode(', ', $exposedHeaders); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $response = $handler->handle($request); + + $allowedOrigin = $this->resolveAllowedOrigin($request->getHeaderLine('Origin')); + if (null !== $allowedOrigin && !$response->hasHeader('Access-Control-Allow-Origin')) { + $response = $response->withHeader('Access-Control-Allow-Origin', $allowedOrigin); + } + + if ($this->allowCredentials && null !== $allowedOrigin && !$response->hasHeader('Access-Control-Allow-Credentials')) { + $response = $response->withHeader('Access-Control-Allow-Credentials', 'true'); + } + + if ($this->varyOnOrigin) { + $response = $this->ensureVaryOrigin($response); + } + + if ($this->isPreflight($request)) { + if (!$response->hasHeader('Access-Control-Allow-Methods')) { + $response = $response->withHeader('Access-Control-Allow-Methods', $this->allowedMethodsHeader); + } + + if (!$response->hasHeader('Access-Control-Allow-Headers')) { + $response = $response->withHeader('Access-Control-Allow-Headers', $this->allowedHeadersHeader); + } + } + + if (null !== $this->exposedHeadersHeader && !$response->hasHeader('Access-Control-Expose-Headers')) { + $response = $response->withHeader('Access-Control-Expose-Headers', $this->exposedHeadersHeader); + } + + return $response; + } + + private function isPreflight(ServerRequestInterface $request): bool + { + return 'OPTIONS' === $request->getMethod() + && $request->hasHeader('Access-Control-Request-Method'); + } + + private function resolveAllowedOrigin(string $origin): ?string + { + if ([] === $this->allowedOrigins) { + return null; + } + + if ($this->isWildcard) { + return '*'; + } + + if ('' !== $origin && \in_array($origin, $this->allowedOrigins, true)) { + return $origin; + } + + return null; + } + + private function ensureVaryOrigin(ResponseInterface $response): ResponseInterface + { + $current = $response->getHeaderLine('Vary'); + + if ('' === $current) { + return $response->withHeader('Vary', 'Origin'); + } + + if ('*' === trim($current)) { + return $response; + } + + $tokens = array_map('strtolower', array_map('trim', explode(',', $current))); + if (\in_array('origin', $tokens, true)) { + return $response; + } + + return $response->withHeader('Vary', $current.', Origin'); + } +} diff --git a/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php new file mode 100644 index 00000000..490aa373 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php @@ -0,0 +1,111 @@ + + */ +final class DnsRebindingProtectionMiddleware implements MiddlewareInterface +{ + private ResponseFactoryInterface $responseFactory; + private StreamFactoryInterface $streamFactory; + + /** @var list */ + private readonly array $allowedHosts; + + /** + * @param list $allowedHosts Hostnames (without port) that are permitted. Defaults to localhost variants. + * IPv6 addresses must be bracketed (e.g. `[::1]`) — that is the canonical form returned by `parse_url`. + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory (auto-discovered if null) + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory (auto-discovered if null) + */ + public function __construct( + array $allowedHosts = ['localhost', '127.0.0.1', '[::1]'], + ?ResponseFactoryInterface $responseFactory = null, + ?StreamFactoryInterface $streamFactory = null, + ) { + $this->allowedHosts = array_values(array_map('strtolower', $allowedHosts)); + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $origin = $request->getHeaderLine('Origin'); + if ('' !== $origin) { + if (!$this->isAllowedOrigin($origin)) { + return $this->createForbiddenResponse('Forbidden: Invalid Origin header.'); + } + + return $handler->handle($request); + } + + $host = $request->getHeaderLine('Host'); + if ('' !== $host && !$this->isAllowedHost($host)) { + return $this->createForbiddenResponse('Forbidden: Invalid Host header.'); + } + + return $handler->handle($request); + } + + private function isAllowedOrigin(string $origin): bool + { + $host = parse_url($origin, \PHP_URL_HOST); + if (!\is_string($host) || '' === $host) { + return false; + } + + return \in_array(strtolower($host), $this->allowedHosts, true); + } + + private function isAllowedHost(string $host): bool + { + if (str_starts_with($host, '[')) { + $closingBracket = strpos($host, ']'); + if (false === $closingBracket) { + return false; + } + $hostname = substr($host, 0, $closingBracket + 1); + } else { + $hostname = explode(':', $host, 2)[0]; + } + + return \in_array(strtolower($hostname), $this->allowedHosts, true); + } + + private function createForbiddenResponse(string $message): ResponseInterface + { + return $this->responseFactory + ->createResponse(403) + ->withHeader('Content-Type', 'text/plain') + ->withBody($this->streamFactory->createStream($message)); + } +} diff --git a/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php b/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php new file mode 100644 index 00000000..6d60e1ab --- /dev/null +++ b/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php @@ -0,0 +1,98 @@ + + */ +final class ProtocolVersionMiddleware implements MiddlewareInterface +{ + private ResponseFactoryInterface $responseFactory; + private StreamFactoryInterface $streamFactory; + + /** @var list */ + private readonly array $supportedVersions; + + /** + * @param list|null $supportedVersions Versions the server accepts. Defaults to all values of {@see ProtocolVersion}. + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory (auto-discovered if null) + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory (auto-discovered if null) + */ + public function __construct( + ?array $supportedVersions = null, + ?ResponseFactoryInterface $responseFactory = null, + ?StreamFactoryInterface $streamFactory = null, + ) { + $versions = $supportedVersions ?? ProtocolVersion::cases(); + $this->supportedVersions = array_values(array_map(static fn (ProtocolVersion $v): string => $v->value, $versions)); + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $headerValue = $request->getHeaderLine(StreamableHttpTransport::PROTOCOL_VERSION_HEADER); + + // Spec backwards-compat: when the header is absent, the server SHOULD assume + // protocol version 2025-03-26 — the release in which Streamable HTTP and the + // header itself were introduced. This is deliberately lower than the SDK's + // own default (V2025_06_18) so clients predating the header convention still + // get a deterministic protocol version applied. Servers that whitelist only + // newer versions in $supportedVersions will reject such requests with 400. + $version = '' === $headerValue ? ProtocolVersion::V2025_03_26->value : $headerValue; + + if (\in_array($version, $this->supportedVersions, true)) { + return $handler->handle($request); + } + + $message = '' === $headerValue + ? \sprintf( + 'Missing %s header; backwards-compat default %s is not accepted. Supported versions: %s.', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + $version, + implode(', ', $this->supportedVersions), + ) + : \sprintf( + 'Unsupported %s header value: %s. Supported versions: %s.', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + $headerValue, + implode(', ', $this->supportedVersions), + ); + + return JsonRpcErrorResponse::create($this->responseFactory, $this->streamFactory, 400, Error::forInvalidParams($message)); + } +} diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index 62e82ae4..ab84b092 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -14,6 +14,9 @@ use Http\Discovery\Psr17FactoryDiscovery; use Mcp\Exception\InvalidArgumentException; use Mcp\Schema\JsonRpc\Error; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; use Mcp\Server\Transport\Http\MiddlewareRequestHandler; use Psr\Http\Message\ResponseFactoryInterface; use Psr\Http\Message\ResponseInterface; @@ -30,16 +33,8 @@ */ class StreamableHttpTransport extends BaseTransport { - private const SESSION_HEADER = 'Mcp-Session-Id'; - - private const ALLOWED_HEADER = [ - 'Accept', - 'Authorization', - 'Content-Type', - 'Last-Event-ID', - 'Mcp-Protocol-Version', - self::SESSION_HEADER, - ]; + public const SESSION_HEADER = 'Mcp-Session-Id'; + public const PROTOCOL_VERSION_HEADER = 'Mcp-Protocol-Version'; private ResponseFactoryInterface $responseFactory; private StreamFactoryInterface $streamFactory; @@ -47,44 +42,48 @@ class StreamableHttpTransport extends BaseTransport private ?string $immediateResponse = null; private ?int $immediateStatusCode = null; - /** @var array */ - private array $corsHeaders; - /** @var list */ - private array $middleware = []; + private array $middleware; /** - * @param array $corsHeaders - * @param iterable $middleware + * @param iterable|null $middleware `null` installs {@see self::defaultMiddleware()}; `[]` disables all middleware */ public function __construct( private ServerRequestInterface $request, ?ResponseFactoryInterface $responseFactory = null, ?StreamFactoryInterface $streamFactory = null, - array $corsHeaders = [], ?LoggerInterface $logger = null, - iterable $middleware = [], + ?iterable $middleware = null, ) { parent::__construct($logger); $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); - $this->corsHeaders = array_merge([ - 'Access-Control-Allow-Origin' => '*', - 'Access-Control-Allow-Methods' => 'GET, POST, DELETE, OPTIONS', - 'Access-Control-Allow-Headers' => implode(',', self::ALLOWED_HEADER), - 'Access-Control-Expose-Headers' => self::SESSION_HEADER, - ], $corsHeaders); - - foreach ($middleware as $m) { - if (!$m instanceof MiddlewareInterface) { - throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); + if (null === $middleware) { + $this->middleware = self::defaultMiddleware(); + } else { + $this->middleware = self::normalizeMiddleware($middleware); + if ([] === $this->middleware) { + $this->logger->warning('Streamable HTTP transport started with an empty middleware list. Default security protections (CORS, DNS rebinding, protocol version validation) are disabled. Pass null (or omit the argument) to use the secure defaults, or include them via [...StreamableHttpTransport::defaultMiddleware(), $yourMiddleware].'); } - $this->middleware[] = $m; } } + /** + * Secure default middleware stack applied when no `$middleware` is provided to the constructor. + * + * @return list + */ + public static function defaultMiddleware(): array + { + return [ + new CorsMiddleware(), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), + ]; + } + public function send(string $data, array $context): void { $this->immediateResponse = $data; @@ -98,7 +97,7 @@ public function listen(): ResponseInterface \Closure::fromCallable([$this, 'handleRequest']), ); - return $this->withCorsHeaders($handler->handle($this->request)); + return $handler->handle($this->request); } protected function handleOptionsRequest(): ResponseInterface @@ -274,15 +273,22 @@ protected function createErrorResponse(Error $jsonRpcError, int $statusCode): Re return $response; } - protected function withCorsHeaders(ResponseInterface $response): ResponseInterface + /** + * @param iterable $middleware + * + * @return list + */ + private static function normalizeMiddleware(iterable $middleware): array { - foreach ($this->corsHeaders as $name => $value) { - if (!$response->hasHeader($name)) { - $response = $response->withHeader($name, $value); + $normalized = []; + foreach ($middleware as $m) { + if (!$m instanceof MiddlewareInterface) { + throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); } + $normalized[] = $m; } - return $response; + return $normalized; } private function handleRequest(ServerRequestInterface $request): ResponseInterface diff --git a/tests/Conformance/conformance-baseline.yml b/tests/Conformance/conformance-baseline.yml index 61f9783f..efda80ab 100644 --- a/tests/Conformance/conformance-baseline.yml +++ b/tests/Conformance/conformance-baseline.yml @@ -1,5 +1,4 @@ -server: - - dns-rebinding-protection +server: [] client: - elicitation-sep1034-client-defaults diff --git a/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php new file mode 100644 index 00000000..863b669e --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php @@ -0,0 +1,269 @@ +factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); + // Non-preflight: Methods/Headers must NOT be emitted per CORS spec. + $this->assertFalse($response->hasHeader('Access-Control-Allow-Methods')); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Headers')); + } + + #[TestDox('preflight request receives Access-Control-Allow-Methods and Access-Control-Allow-Headers')] + public function testPreflightReceivesMethodAndHeaderAdvertisements(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->preflightRequest('https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertNotSame('', $response->getHeaderLine('Access-Control-Allow-Headers')); + } + + #[TestDox('non-preflight OPTIONS request does not receive Methods/Headers advertisements')] + public function testPlainOptionsIsNotTreatedAsPreflight(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + // OPTIONS without `Access-Control-Request-Method` is not a CORS preflight. + $request = $this->factory->createServerRequest('OPTIONS', 'https://example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Methods')); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Headers')); + } + + #[TestDox('wildcard allowedOrigins sets Access-Control-Allow-Origin to *')] + public function testWildcardOrigin(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('matching Origin is reflected back')] + public function testMatchingOriginIsReflected(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['https://app.example.com', 'https://staging.example.com'], + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('https://app.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('non-matching Origin is not echoed')] + public function testNonMatchingOriginIsBlocked(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + } + + #[TestDox('does not overwrite headers set by inner middleware')] + public function testPreExistingHeadersAreNotOverwritten(): void + { + $inner = $this->handlerReturning(200, [ + 'Access-Control-Allow-Origin' => 'https://override.example.com', + 'Access-Control-Allow-Methods' => 'POST', + ]); + + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->preflightRequest(); + + $response = $middleware->process($request, $inner); + + $this->assertSame('https://override.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('POST', $response->getHeaderLine('Access-Control-Allow-Methods')); + } + + #[TestDox('exposed headers can be omitted')] + public function testEmptyExposedHeadersAreNotSet(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*'], exposedHeaders: []); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Expose-Headers')); + } + + #[TestDox('adds Vary: Origin when reflecting a specific origin to protect caches')] + public function testVaryOriginIsAddedForReflectedOrigin(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('adds Vary: Origin even when origin is rejected so caches do not poison')] + public function testVaryOriginIsAddedEvenWhenOriginDoesNotMatch(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertSame('Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('does not add Vary when Access-Control-Allow-Origin is wildcard')] + public function testVaryOriginIsNotAddedForWildcard(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertFalse($response->hasHeader('Vary')); + } + + #[TestDox('does not add Vary when no allowed origins are configured')] + public function testVaryOriginIsNotAddedWhenAllowedOriginsEmpty(): void + { + $middleware = new CorsMiddleware(); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Vary')); + } + + #[TestDox('preserves existing Vary value when appending Origin')] + public function testVaryOriginAppendsToExistingVary(): void + { + $inner = $this->handlerReturning(200, ['Vary' => 'Accept-Encoding']); + + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('Accept-Encoding, Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('does not duplicate Origin in existing Vary header')] + public function testVaryOriginIsNotDuplicated(): void + { + $inner = $this->handlerReturning(200, ['Vary' => 'Accept-Encoding, Origin']); + + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('Accept-Encoding, Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('does not treat a substring match like Origin-Other as the Origin token')] + public function testVarySubstringDoesNotPreventAppending(): void + { + // `Origin-Resource-Policy` contains the substring "origin" but is a different token — + // tokenized comparison must still treat the response as missing the `Origin` value. + $inner = $this->handlerReturning(200, ['Vary' => 'Origin-Resource-Policy']); + + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('Origin-Resource-Policy, Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('allowCredentials emits Access-Control-Allow-Credentials when an origin matches')] + public function testAllowCredentialsHeaderEmitted(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['https://app.example.com'], + allowCredentials: true, + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('https://app.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('true', $response->getHeaderLine('Access-Control-Allow-Credentials')); + } + + #[TestDox('allowCredentials does not emit credentials header when no origin matches')] + public function testAllowCredentialsSkippedWhenOriginUnmatched(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['https://app.example.com'], + allowCredentials: true, + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Credentials')); + } + + #[TestDox('combining wildcard origin with allowCredentials throws')] + public function testWildcardWithCredentialsRejected(): void + { + $this->expectException(InvalidArgumentException::class); + + new CorsMiddleware(allowedOrigins: ['*'], allowCredentials: true); + } + + private function preflightRequest(string $origin = 'https://app.example.com'): ServerRequestInterface + { + return $this->factory + ->createServerRequest('OPTIONS', 'https://example.com') + ->withHeader('Origin', $origin) + ->withHeader('Access-Control-Request-Method', 'POST') + ->withHeader('Access-Control-Request-Headers', 'Content-Type'); + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php new file mode 100644 index 00000000..720a0f27 --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php @@ -0,0 +1,145 @@ + ['http://localhost:8000']; + yield 'IPv4 loopback' => ['http://127.0.0.1:3000']; + yield 'IPv6 loopback (bracketed)' => ['http://[::1]:8000']; + } + + #[DataProvider('allowedOriginProvider')] + #[TestDox('allows request with localhost Origin variant: $origin')] + public function testAllowsLocalhostOrigin(string $origin): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', $origin); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('rejects non-allowed Origin with 403')] + public function testRejectsForeignOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(403, $response->getStatusCode()); + $this->assertSame('text/plain', $response->getHeaderLine('Content-Type')); + $this->assertStringContainsString('Origin', (string) $response->getBody()); + } + + #[TestDox('Origin header takes precedence over Host')] + public function testOriginPrecedenceOverHost(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://localhost:8000') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('validates Host header when Origin is absent')] + public function testFallbackToHostValidation(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://evil/') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('strips port from Host header when validating')] + public function testHostPortIsStripped(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost:8000'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('IPv6 Host with port is parsed correctly')] + public function testIpv6HostWithPort(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', '[::1]:8080'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('custom allowed hosts permit non-localhost names')] + public function testCustomAllowedHosts(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['myapp.local'], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://myapp.local/') + ->withHeader('Origin', 'http://myapp.local:3000'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('host comparison is case-insensitive')] + public function testCaseInsensitive(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['MyApp.Local'], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://myapp.local/') + ->withHeader('Origin', 'http://MYAPP.LOCAL:80'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('request without Origin or Host is allowed')] + public function testNoOriginNoHostPasses(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/')->withoutHeader('Host'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php b/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php new file mode 100644 index 00000000..3f5ba2fe --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php @@ -0,0 +1,57 @@ +factory = new Psr17Factory(); + $this->passthroughHandler = $this->handlerReturning(200); + } + + /** + * @param array $headers extra headers to set on the response (already-set CORS headers etc.) + */ + protected function handlerReturning(int $status, array $headers = []): RequestHandlerInterface + { + return new class($this->factory, $status, $headers) implements RequestHandlerInterface { + /** @param array $headers */ + public function __construct( + private ResponseFactoryInterface $factory, + private int $status, + private array $headers, + ) { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + $response = $this->factory->createResponse($this->status); + foreach ($this->headers as $name => $value) { + $response = $response->withHeader($name, $value); + } + + return $response; + } + }; + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php new file mode 100644 index 00000000..ad216d56 --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php @@ -0,0 +1,104 @@ +factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('rejects missing header when 2025-03-26 backwards-compat default is not in supportedVersions')] + public function testMissingHeaderRejectedByStrictServer(): void + { + $middleware = new ProtocolVersionMiddleware( + supportedVersions: [ProtocolVersion::V2025_11_25], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + } + + #[TestDox('accepts every version declared in the ProtocolVersion enum')] + public function testAcceptsSupportedVersions(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + + foreach (ProtocolVersion::cases() as $version) { + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, $version->value); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode(), 'Expected '.$version->value.' to be accepted.'); + } + } + + #[TestDox('rejects unsupported well-formed version with 400')] + public function testRejectsUnsupportedVersion(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, '1900-01-01'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + $this->assertSame('application/json', $response->getHeaderLine('Content-Type')); + } + + #[TestDox('rejects malformed version with 400')] + public function testRejectsMalformedVersion(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, 'not-a-version'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + } + + #[TestDox('accepts only the supportedVersions whitelist when provided')] + public function testRestrictedSupportedVersions(): void + { + $middleware = new ProtocolVersionMiddleware( + supportedVersions: [ProtocolVersion::V2025_11_25], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + + $accepted = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, ProtocolVersion::V2025_11_25->value); + $rejected = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, ProtocolVersion::V2024_11_05->value); + + $this->assertSame(200, $middleware->process($accepted, $this->passthroughHandler)->getStatusCode()); + $this->assertSame(400, $middleware->process($rejected, $this->passthroughHandler)->getStatusCode()); + } +} diff --git a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php index 7d9cd484..497338fc 100644 --- a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php +++ b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php @@ -11,9 +11,12 @@ namespace Mcp\Tests\Unit\Server\Transport; +use Mcp\Exception\InvalidArgumentException; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; use Mcp\Server\Transport\StreamableHttpTransport; use Nyholm\Psr7\Factory\Psr17Factory; -use PHPUnit\Framework\Attributes\DataProvider; use PHPUnit\Framework\Attributes\TestDox; use PHPUnit\Framework\TestCase; use Psr\Http\Message\ResponseFactoryInterface; @@ -21,120 +24,201 @@ use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Server\MiddlewareInterface; use Psr\Http\Server\RequestHandlerInterface; +use Psr\Log\LoggerInterface; final class StreamableHttpTransportTest extends TestCase { - public static function corsHeaderProvider(): iterable + private Psr17Factory $factory; + + protected function setUp(): void { - yield 'GET (middleware returns 401)' => ['GET', false, 401]; - yield 'POST (middleware returns 401)' => ['POST', false, 401]; - yield 'DELETE (middleware returns 401)' => ['DELETE', false, 401]; - yield 'OPTIONS (middleware delegates -> transport handles preflight)' => ['OPTIONS', true, 204]; - yield 'GET (middleware delegates -> transport handles preflight)' => ['GET', true, 405]; - yield 'POST (middleware delegates -> transport handles preflight)' => ['POST', true, 202]; - yield 'DELETE (middleware delegates -> transport handles preflight)' => ['DELETE', true, 400]; + $this->factory = new Psr17Factory(); } - #[DataProvider('corsHeaderProvider')] - #[TestDox('CORS headers are always applied')] - public function testCorsHeader(string $method, bool $middlewareDelegatesToTransport, int $expectedStatusCode): void + #[TestDox('default middleware is applied when none is passed')] + public function testDefaultMiddlewareIsAppliedWhenOmitted(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest($method, 'https://example.com'); - - $middleware = new class($factory, $expectedStatusCode, $middlewareDelegatesToTransport) implements MiddlewareInterface { - public function __construct( - private ResponseFactoryInterface $responseFactory, - private int $expectedStatusCode, - private bool $middlewareDelegatesToTransport, - ) { - } + // Preflight: OPTIONS + Access-Control-Request-Method — CorsMiddleware advertises Methods/Headers only on preflight. + $request = $this->factory + ->createServerRequest('OPTIONS', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader('Access-Control-Request-Method', 'POST'); - public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface - { - if ($this->middlewareDelegatesToTransport) { - return $handler->handle($request); - } + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); - return $this->responseFactory->createResponse($this->expectedStatusCode); - } - }; + $response = $transport->listen(); + + $this->assertSame(204, $response->getStatusCode()); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); // secure-by-default + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertNotSame('', $response->getHeaderLine('Access-Control-Allow-Headers')); + $this->assertNotSame('', $response->getHeaderLine('Access-Control-Expose-Headers')); + } + + #[TestDox('default middleware blocks non-localhost Origin')] + public function testDefaultMiddlewareBlocksRebindingAttempt(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader('Origin', 'http://evil.example.com'); + + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); + + $response = $transport->listen(); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('default middleware rejects unsupported MCP-Protocol-Version')] + public function testDefaultMiddlewareRejectsUnsupportedProtocolVersion(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, '1900-01-01'); + + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); + + $response = $transport->listen(); + + $this->assertSame(400, $response->getStatusCode()); + } + + #[TestDox('explicit empty middleware list disables defaults and emits a warning log')] + public function testEmptyMiddlewareListDisablesDefaultsAndWarns(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'evil.example.com') + ->withHeader('Origin', 'http://evil.example.com'); + + $logger = $this->createMock(LoggerInterface::class); + $logger->expects($this->once()) + ->method('warning') + ->with($this->stringContains('empty middleware list')); $transport = new StreamableHttpTransport( $request, - $factory, - $factory, + $this->factory, + $this->factory, + $logger, [], - null, - [$middleware], ); $response = $transport->listen(); - $this->assertSame($expectedStatusCode, $response->getStatusCode(), $response->getBody()->getContents()); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Origin')); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); - $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); - - $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); - $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') - ); - $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); + // No CORS, no DNS rebinding check — transport just answers. + $this->assertNotSame(403, $response->getStatusCode()); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Methods')); } - #[TestDox('transport replaces existing CORS headers on the response')] - public function testCorsHeadersAreReplacedWhenAlreadyPresent(): void + #[TestDox('null middleware does not trigger the empty-list warning')] + public function testNullMiddlewareDoesNotWarn(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest('GET', 'https://example.com'); + $request = $this->factory + ->createServerRequest('OPTIONS', 'http://localhost/') + ->withHeader('Host', 'localhost'); - $middleware = new class($factory) implements MiddlewareInterface { - public function __construct(private ResponseFactoryInterface $responses) - { - } + $logger = $this->createMock(LoggerInterface::class); + $logger->expects($this->never())->method('warning'); - public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface - { - return $this->responses->createResponse(200) - ->withHeader('Access-Control-Allow-Origin', 'https://another.com'); - } - }; + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory, $logger); + $transport->listen(); + } + + #[TestDox('custom middleware composes with default stack via spread')] + public function testDefaultsCanBeSpreadAndExtended(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost'); $transport = new StreamableHttpTransport( $request, - $factory, - $factory, - [], + $this->factory, + $this->factory, null, - [$middleware], + [ + ...StreamableHttpTransport::defaultMiddleware(), + $this->stubAuth401(), + ], ); $response = $transport->listen(); - $this->assertSame(200, $response->getStatusCode()); + $this->assertSame(401, $response->getStatusCode()); + // CORS middleware is outermost — Expose-Headers is emitted on all responses, including 401. + $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); + } + + #[TestDox('defaults can be filtered to drop DNS rebinding for proxy deployments')] + public function testDefaultsCanBeFilteredToDropDnsRebinding(): void + { + // Behind a reverse proxy: real Host is api.myapp.com, browser Origin is myapp.com. + // DnsRebindingProtectionMiddleware default (localhost-only) would 403 this — drop it. + $request = $this->factory + ->createServerRequest('POST', 'http://api.myapp.com/') + ->withHeader('Host', 'api.myapp.com') + ->withHeader('Origin', 'https://myapp.com'); - $this->assertSame('https://another.com', $response->getHeaderLine('Access-Control-Allow-Origin')); - $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') + $transport = new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [ + ...array_filter( + StreamableHttpTransport::defaultMiddleware(), + static fn (MiddlewareInterface $m): bool => !$m instanceof DnsRebindingProtectionMiddleware, + ), + $this->stubAuth401(), + ], ); + + $response = $transport->listen(); + + // Auth short-circuits with 401 — proves DNS rebinding didn't reject the request first. + $this->assertSame(401, $response->getStatusCode()); + // CORS middleware is still in the chain — Expose-Headers attached to the 401. $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); } + #[TestDox('configured CorsMiddleware reflects matching Origin')] + public function testConfiguredCorsReflectsMatchingOrigin(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader('Origin', 'https://myapp.example.com'); + + $transport = new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [ + new CorsMiddleware(allowedOrigins: ['https://myapp.example.com']), + new DnsRebindingProtectionMiddleware(allowedHosts: ['localhost']), + new ProtocolVersionMiddleware(), + ], + ); + + $response = $transport->listen(); + + $this->assertSame('https://myapp.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + #[TestDox('middleware runs before transport handles the request')] public function testMiddlewareRunsBeforeTransportHandlesRequest(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest('OPTIONS', 'https://example.com'); + $request = $this->factory->createServerRequest('OPTIONS', 'http://localhost/') + ->withHeader('Host', 'localhost'); $state = new \stdClass(); $state->called = false; - $middleware = new class($state) implements MiddlewareInterface { + $spy = new class($state) implements MiddlewareInterface { public function __construct(private \stdClass $state) { } @@ -149,11 +233,10 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $transport = new StreamableHttpTransport( $request, - $factory, - $factory, - [], + $this->factory, + $this->factory, null, - [$middleware], + [$spy], ); $response = $transport->listen(); @@ -161,4 +244,34 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $this->assertTrue($state->called); $this->assertSame(204, $response->getStatusCode()); } + + #[TestDox('non-middleware entries are rejected')] + public function testInvalidMiddlewareEntryThrows(): void + { + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $this->expectException(InvalidArgumentException::class); + + new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [new \stdClass()], // @phpstan-ignore-line argument.type + ); + } + + private function stubAuth401(): MiddlewareInterface + { + return new class($this->factory) implements MiddlewareInterface { + public function __construct(private ResponseFactoryInterface $factory) + { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + return $this->factory->createResponse(401); + } + }; + } }