diff --git a/examples/browser_routing.py b/examples/browser_routing.py new file mode 100644 index 00000000..68627ab9 --- /dev/null +++ b/examples/browser_routing.py @@ -0,0 +1,25 @@ +"""Example: direct-to-VM browser routing for process exec and raw HTTP.""" + +from typing import Any, cast + +import httpx + +from kernel import Kernel + + +def main() -> None: + with Kernel() as client: + browsers = cast(Any, client.browsers) + browser = browsers.create(headless=True) + try: + response = cast(httpx.Response, browsers.request(browser.session_id, "GET", "https://example.com")) + print("status", response.status_code) + + with browsers.stream(browser.session_id, "GET", "https://example.com") as streamed: + print("streamed-bytes", len(streamed.read())) + finally: + browsers.delete_by_id(browser.session_id) + + +if __name__ == "__main__": + main() diff --git a/src/kernel/_client.py b/src/kernel/_client.py index 75fe4b64..e70bf1c4 100644 --- a/src/kernel/_client.py +++ b/src/kernel/_client.py @@ -29,6 +29,13 @@ SyncAPIClient, AsyncAPIClient, ) +from .lib.browser_routing.routing import ( + BrowserRouteCache, + BrowserRoutingConfig, + strip_direct_vm_auth, + rewrite_direct_vm_options, + browser_routing_config_from_env, +) if TYPE_CHECKING: from .resources import ( @@ -79,8 +86,10 @@ class Kernel(SyncAPIClient): # client options api_key: str + browser_route_cache: BrowserRouteCache _environment: Literal["production", "development"] | NotGiven + _browser_routing: BrowserRoutingConfig def __init__( self, @@ -105,6 +114,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + _browser_route_cache: BrowserRouteCache | None = None, ) -> None: """Construct a new synchronous Kernel client instance. @@ -154,6 +164,8 @@ def __init__( custom_query=default_query, _strict_response_validation=_strict_response_validation, ) + self.browser_route_cache = _browser_route_cache or BrowserRouteCache() + self._browser_routing = browser_routing_config_from_env() @cached_property def deployments(self) -> DeploymentsResource: @@ -266,6 +278,15 @@ def default_headers(self) -> dict[str, str | Omit]: **self._custom_headers, } + @override + def _prepare_options(self, options: Any) -> Any: + options = cast(Any, super()._prepare_options(options)) + return rewrite_direct_vm_options(options, cache=self.browser_route_cache, config=self._browser_routing) + + @override + def _prepare_request(self, request: httpx.Request) -> None: + strip_direct_vm_auth(request, cache=self.browser_route_cache) + def copy( self, *, @@ -279,6 +300,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + _browser_route_cache: BrowserRouteCache | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -312,6 +334,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + _browser_route_cache=_browser_route_cache or self.browser_route_cache, **_extra_kwargs, ) @@ -356,8 +379,10 @@ def _make_status_error( class AsyncKernel(AsyncAPIClient): # client options api_key: str + browser_route_cache: BrowserRouteCache _environment: Literal["production", "development"] | NotGiven + _browser_routing: BrowserRoutingConfig def __init__( self, @@ -382,6 +407,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + _browser_route_cache: BrowserRouteCache | None = None, ) -> None: """Construct a new async AsyncKernel client instance. @@ -431,6 +457,8 @@ def __init__( custom_query=default_query, _strict_response_validation=_strict_response_validation, ) + self.browser_route_cache = _browser_route_cache or BrowserRouteCache() + self._browser_routing = browser_routing_config_from_env() @cached_property def deployments(self) -> AsyncDeploymentsResource: @@ -543,6 +571,15 @@ def default_headers(self) -> dict[str, str | Omit]: **self._custom_headers, } + @override + async def _prepare_options(self, options: Any) -> Any: + options = cast(Any, await super()._prepare_options(options)) + return rewrite_direct_vm_options(options, cache=self.browser_route_cache, config=self._browser_routing) + + @override + async def _prepare_request(self, request: httpx.Request) -> None: + strip_direct_vm_auth(request, cache=self.browser_route_cache) + def copy( self, *, @@ -556,6 +593,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + _browser_route_cache: BrowserRouteCache | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -589,6 +627,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + _browser_route_cache=_browser_route_cache or self.browser_route_cache, **_extra_kwargs, ) diff --git a/src/kernel/lib/browser_routing/__init__.py b/src/kernel/lib/browser_routing/__init__.py new file mode 100644 index 00000000..bdec2fc8 --- /dev/null +++ b/src/kernel/lib/browser_routing/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +__all__: list[str] = [] diff --git a/src/kernel/lib/browser_routing/raw_http.py b/src/kernel/lib/browser_routing/raw_http.py new file mode 100644 index 00000000..5e644e39 --- /dev/null +++ b/src/kernel/lib/browser_routing/raw_http.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import IO, Any, Union, Mapping, cast +from contextlib import contextmanager, asynccontextmanager +from collections.abc import Iterable, Iterator, AsyncIterator + +import httpx + +from .util import sanitize_curl_raw_params +from .routing import BrowserRoute +from ..._types import Body, Timeout, NotGiven, not_given +from ..._models import FinalRequestOptions + +BrowserRawContent = Union[bytes, bytearray, memoryview, str, IO[bytes], Iterable[bytes]] + + +def request_via_browser_route( + parent: Any, + route: BrowserRoute, + method: str, + url: str, + *, + content: BrowserRawContent | None = None, + json: Body | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | Timeout | None | NotGiven = not_given, +) -> httpx.Response: + if json is not None and content is not None: + raise TypeError("Passing both `json` and `content` is not supported") + query: dict[str, object] = {**sanitize_curl_raw_params(params), "url": url, "jwt": route.jwt} + options = FinalRequestOptions.construct( + method=method.upper(), + url=route.base_url.rstrip("/") + "/curl/raw", + params=query, + headers=headers or {}, + content=_normalize_binary_content(content), + json_data=json, + timeout=_normalize_timeout(timeout), + ) + return cast(httpx.Response, parent.request(httpx.Response, options)) + + +@contextmanager +def stream_via_browser_route( + parent: Any, + route: BrowserRoute, + method: str, + url: str, + *, + content: BrowserRawContent | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | Timeout | None | NotGiven = not_given, +) -> Iterator[httpx.Response]: + query: dict[str, Any] = sanitize_curl_raw_params(params) + query["jwt"] = route.jwt + query["url"] = url + request_headers = {k: v for k, v in parent.default_headers.items() if isinstance(v, str)} + if content is None: + request_headers.pop("Content-Type", None) + if headers: + request_headers.update(headers) + request_headers.pop("Authorization", None) + effective_timeout = parent.timeout if isinstance(timeout, NotGiven) else timeout + with parent._client.stream( + method.upper(), + route.base_url.rstrip("/") + "/curl/raw", + params=query, + headers=request_headers, + content=_normalize_binary_content(content), + timeout=_normalize_timeout(effective_timeout), + ) as response: + yield response + + +async def async_request_via_browser_route( + parent: Any, + route: BrowserRoute, + method: str, + url: str, + *, + content: BrowserRawContent | None = None, + json: Body | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | Timeout | None | NotGiven = not_given, +) -> httpx.Response: + if json is not None and content is not None: + raise TypeError("Passing both `json` and `content` is not supported") + query: dict[str, object] = {**sanitize_curl_raw_params(params), "url": url, "jwt": route.jwt} + options = FinalRequestOptions.construct( + method=method.upper(), + url=route.base_url.rstrip("/") + "/curl/raw", + params=query, + headers=headers or {}, + content=_normalize_binary_content(content), + json_data=json, + timeout=_normalize_timeout(timeout), + ) + return cast(httpx.Response, await parent.request(httpx.Response, options)) + + +@asynccontextmanager +async def async_stream_via_browser_route( + parent: Any, + route: BrowserRoute, + method: str, + url: str, + *, + content: BrowserRawContent | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | Timeout | None | NotGiven = not_given, +) -> AsyncIterator[httpx.Response]: + query: dict[str, Any] = sanitize_curl_raw_params(params) + query["jwt"] = route.jwt + query["url"] = url + request_headers = {k: v for k, v in parent.default_headers.items() if isinstance(v, str)} + if content is None: + request_headers.pop("Content-Type", None) + if headers: + request_headers.update(headers) + request_headers.pop("Authorization", None) + effective_timeout = parent.timeout if isinstance(timeout, NotGiven) else timeout + async with parent._client.stream( + method.upper(), + route.base_url.rstrip("/") + "/curl/raw", + params=query, + headers=request_headers, + content=_normalize_binary_content(content), + timeout=_normalize_timeout(effective_timeout), + ) as response: + yield response + + +def _normalize_timeout(timeout: float | Timeout | None | NotGiven) -> float | Timeout | None: + return None if isinstance(timeout, NotGiven) else timeout + + +def _normalize_binary_content(content: BrowserRawContent | None) -> bytes | IO[bytes] | Iterable[bytes] | None: + if content is None: + return None + if isinstance(content, str): + return content.encode() + if isinstance(content, bytearray): + return bytes(content) + if isinstance(content, memoryview): + return content.tobytes() + return content diff --git a/src/kernel/lib/browser_routing/routing.py b/src/kernel/lib/browser_routing/routing.py new file mode 100644 index 00000000..ff1a8dfc --- /dev/null +++ b/src/kernel/lib/browser_routing/routing.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import os +from typing import Any +from dataclasses import field, dataclass + +import httpx + +from .util import ( + jwt_from_cdp_ws_url, + base_url_from_browser_like, + cdp_ws_url_from_browser_like, + session_id_from_browser_like, +) +from ..._compat import model_copy +from ..._models import FinalRequestOptions + + +@dataclass +class BrowserRoute: + session_id: str + base_url: str + jwt: str + + +@dataclass +class BrowserRoutingConfig: + subresources: tuple[str, ...] = field(default_factory=tuple) + + +def browser_routing_config_from_env() -> BrowserRoutingConfig: + raw = os.environ.get("KERNEL_BROWSER_ROUTING_SUBRESOURCES") + if raw is None: + return BrowserRoutingConfig(subresources=("curl",)) + if raw.strip() == "": + return BrowserRoutingConfig() + + return BrowserRoutingConfig(subresources=tuple(part.strip() for part in raw.split(",") if part.strip())) + + +class BrowserRouteCache: + def __init__(self) -> None: + self._routes: dict[str, BrowserRoute] = {} + + def get(self, session_id: str) -> BrowserRoute | None: + return self._routes.get(_normalize_session_id(session_id)) + + def set(self, route: BrowserRoute) -> None: + normalized_session_id = _normalize_session_id(route.session_id) + self._routes[normalized_session_id] = BrowserRoute( + session_id=normalized_session_id, + base_url=route.base_url.strip().rstrip("/") + "/", + jwt=route.jwt.strip(), + ) + + def delete(self, session_id: str) -> None: + self._routes.pop(_normalize_session_id(session_id), None) + + def values(self) -> list[BrowserRoute]: + return list(self._routes.values()) + + +def browser_route_from_browser(browser: Any) -> BrowserRoute | None: + try: + session_id = session_id_from_browser_like(browser) + except TypeError: + return None + + base_url = base_url_from_browser_like(browser) + if not base_url: + return None + + jwt = None + try: + jwt = jwt_from_cdp_ws_url(cdp_ws_url_from_browser_like(browser)) + except Exception: + jwt = None + if not jwt: + return None + + return BrowserRoute(session_id=session_id, base_url=base_url, jwt=jwt) + + +def _normalize_session_id(session_id: str) -> str: + return session_id.strip() + + +def rewrite_direct_vm_options( + options: FinalRequestOptions, + *, + cache: BrowserRouteCache, + config: BrowserRoutingConfig, +) -> FinalRequestOptions: + match = match_direct_vm_path(options.url) + if match is None: + return options + + session_id, subresource, suffix = match + if subresource not in set(config.subresources): + return options + + route = cache.get(session_id) + if route is None: + return options + + rewritten = model_copy(options) + rewritten.url = f"{route.base_url.rstrip('/')}/{subresource}{suffix}" + + params: dict[str, object] = {} + params.update(options.params) + params["jwt"] = route.jwt + rewritten.params = params or options.params + return rewritten + + +def strip_direct_vm_auth(request: httpx.Request, *, cache: BrowserRouteCache) -> None: + raw = str(request.url) + for route in cache.values(): + if raw.startswith(route.base_url.rstrip("/") + "/"): + request.headers.pop("Authorization", None) + return + + +def match_direct_vm_path(path: str) -> tuple[str, str, str] | None: + if "://" in path: + return None + + parts = [part for part in path.strip("/").split("/") if part] + for index in range(len(parts) - 2): + if parts[index] != "browsers": + continue + session_id = parts[index + 1] + subresource = parts[index + 2] + if not session_id or not subresource: + return None + suffix = "" + if index + 3 < len(parts): + suffix = "/" + "/".join(parts[index + 3 :]) + return session_id, subresource, suffix + return None diff --git a/src/kernel/lib/browser_routing/util.py b/src/kernel/lib/browser_routing/util.py new file mode 100644 index 00000000..ecfb7331 --- /dev/null +++ b/src/kernel/lib/browser_routing/util.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any, Mapping, cast +from urllib.parse import parse_qs, urlparse + +# Query keys reserved for /curl/raw; user-supplied `params` must not override these. +CURL_RAW_RESERVED_QUERY_KEYS: frozenset[str] = frozenset({"url", "jwt"}) + + +def sanitize_curl_raw_params(params: Mapping[str, object] | None) -> dict[str, object]: + """Drop reserved keys from user params so they cannot override the target URL or auth.""" + if not params: + return {} + return {k: v for k, v in dict(params).items() if k not in CURL_RAW_RESERVED_QUERY_KEYS} + + +def jwt_from_cdp_ws_url(cdp_ws_url: str) -> str | None: + parsed = urlparse(cdp_ws_url) + values = parse_qs(parsed.query).get("jwt") + if not values: + return None + return values[0] + + +def session_id_from_browser_like(browser: Any) -> str: + sid = getattr(browser, "session_id", None) + if isinstance(sid, str) and sid: + return sid + if isinstance(browser, Mapping): + mapping = cast(Mapping[str, object], browser) + value = mapping.get("session_id") + if isinstance(value, str) and value: + return value + raise TypeError("browser object must have a non-empty session_id") + + +def base_url_from_browser_like(browser: Any) -> str | None: + base_url = getattr(browser, "base_url", None) + if isinstance(base_url, str) and base_url.strip(): + return base_url.strip().rstrip("/") + "/" + if isinstance(browser, Mapping): + mapping = cast(Mapping[str, object], browser) + value = mapping.get("base_url") + if isinstance(value, str) and value.strip(): + return value.strip().rstrip("/") + "/" + return None + + +def cdp_ws_url_from_browser_like(browser: Any) -> str: + cdp_ws_url = getattr(browser, "cdp_ws_url", None) + if isinstance(cdp_ws_url, str) and cdp_ws_url: + return cdp_ws_url + if isinstance(browser, Mapping): + mapping = cast(Mapping[str, object], browser) + value = mapping.get("cdp_ws_url") + if isinstance(value, str) and value: + return value + raise TypeError("browser object must have a non-empty cdp_ws_url") diff --git a/src/kernel/resources/browsers/browsers.py b/src/kernel/resources/browsers/browsers.py index 228e653a..e8524579 100644 --- a/src/kernel/resources/browsers/browsers.py +++ b/src/kernel/resources/browsers/browsers.py @@ -3,7 +3,8 @@ from __future__ import annotations import typing_extensions -from typing import Dict, Mapping, Iterable, Optional, cast +from typing import Dict, Mapping, Iterable, Iterator, Optional, AsyncIterator, cast +from contextlib import contextmanager, asynccontextmanager from typing_extensions import Literal import httpx @@ -78,8 +79,15 @@ ) from ...pagination import SyncOffsetPagination, AsyncOffsetPagination from ..._base_client import AsyncPaginator, make_request_options +from ...lib.browser_routing.routing import browser_route_from_browser from ...types.browser_curl_response import BrowserCurlResponse from ...types.browser_list_response import BrowserListResponse +from ...lib.browser_routing.raw_http import ( + stream_via_browser_route, + request_via_browser_route, + async_stream_via_browser_route, + async_request_via_browser_route, +) from ...types.browser_create_response import BrowserCreateResponse from ...types.browser_update_response import BrowserUpdateResponse from ...types.browser_persistence_param import BrowserPersistenceParam @@ -219,7 +227,7 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ - return self._post( + result = self._post( "/browsers", body=maybe_transform( { @@ -242,6 +250,10 @@ def create( ), cast_to=BrowserCreateResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result def retrieve( self, @@ -271,7 +283,7 @@ def retrieve( """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") - return self._get( + result = self._get( path_template("/browsers/{id}", id=id), options=make_request_options( extra_headers=extra_headers, @@ -284,6 +296,10 @@ def retrieve( ), cast_to=BrowserRetrieveResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result def update( self, @@ -325,7 +341,7 @@ def update( """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") - return self._patch( + result = self._patch( path_template("/browsers/{id}", id=id), body=maybe_transform( { @@ -341,6 +357,10 @@ def update( ), cast_to=BrowserUpdateResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result def list( self, @@ -383,7 +403,7 @@ def list( timeout: Override the client-level default timeout for this request, in seconds """ - return self._get_api_list( + page = self._get_api_list( "/browsers", page=SyncOffsetPagination[BrowserListResponse], options=make_request_options( @@ -404,6 +424,11 @@ def list( ), model=BrowserListResponse, ) + for item in page.items: + route = browser_route_from_browser(item) + if route is not None: + self._client.browser_route_cache.set(route) + return page @typing_extensions.deprecated("deprecated") def delete( @@ -510,6 +535,64 @@ def curl( cast_to=BrowserCurlResponse, ) + def request( + self, + id: str, + method: str, + url: str, + *, + content: bytes | bytearray | memoryview | str | Iterable[bytes] | None = None, + json: Body | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> httpx.Response: + route = self._client.browser_route_cache.get(id) + if route is None: + raise ValueError( + f"browser route cache does not contain session {id}; create, retrieve, or list the browser before calling browsers.request" + ) + return request_via_browser_route( + self._client, + route, + method, + url, + content=content, + json=json, + headers=headers, + params=params, + timeout=timeout, + ) + + @contextmanager + def stream( + self, + id: str, + method: str, + url: str, + *, + content: bytes | bytearray | memoryview | str | Iterable[bytes] | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Iterator[httpx.Response]: + route = self._client.browser_route_cache.get(id) + if route is None: + raise ValueError( + f"browser route cache does not contain session {id}; create, retrieve, or list the browser before calling browsers.stream" + ) + with stream_via_browser_route( + self._client, + route, + method, + url, + content=content, + headers=headers, + params=params, + timeout=timeout, + ) as resp: + yield resp + def delete_by_id( self, id: str, @@ -719,7 +802,7 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ - return await self._post( + result = await self._post( "/browsers", body=await async_maybe_transform( { @@ -742,6 +825,10 @@ async def create( ), cast_to=BrowserCreateResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result async def retrieve( self, @@ -771,7 +858,7 @@ async def retrieve( """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") - return await self._get( + result = await self._get( path_template("/browsers/{id}", id=id), options=make_request_options( extra_headers=extra_headers, @@ -784,6 +871,10 @@ async def retrieve( ), cast_to=BrowserRetrieveResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result async def update( self, @@ -825,7 +916,7 @@ async def update( """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") - return await self._patch( + result = await self._patch( path_template("/browsers/{id}", id=id), body=await async_maybe_transform( { @@ -841,6 +932,10 @@ async def update( ), cast_to=BrowserUpdateResponse, ) + route = browser_route_from_browser(result) + if route is not None: + self._client.browser_route_cache.set(route) + return result def list( self, @@ -883,7 +978,7 @@ def list( timeout: Override the client-level default timeout for this request, in seconds """ - return self._get_api_list( + page = self._get_api_list( "/browsers", page=AsyncOffsetPagination[BrowserListResponse], options=make_request_options( @@ -904,6 +999,12 @@ def list( ), model=BrowserListResponse, ) + typed_page = cast(AsyncOffsetPagination[BrowserListResponse], page) + for item in typed_page.items: + route = browser_route_from_browser(item) + if route is not None: + self._client.browser_route_cache.set(route) + return page @typing_extensions.deprecated("deprecated") async def delete( @@ -1012,6 +1113,64 @@ async def curl( cast_to=BrowserCurlResponse, ) + async def request( + self, + id: str, + method: str, + url: str, + *, + content: bytes | bytearray | memoryview | str | Iterable[bytes] | None = None, + json: Body | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> httpx.Response: + route = self._client.browser_route_cache.get(id) + if route is None: + raise ValueError( + f"browser route cache does not contain session {id}; create, retrieve, or list the browser before calling browsers.request" + ) + return await async_request_via_browser_route( + self._client, + route, + method, + url, + content=content, + json=json, + headers=headers, + params=params, + timeout=timeout, + ) + + @asynccontextmanager + async def stream( + self, + id: str, + method: str, + url: str, + *, + content: bytes | bytearray | memoryview | str | Iterable[bytes] | None = None, + headers: Mapping[str, str] | None = None, + params: Mapping[str, object] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncIterator[httpx.Response]: + route = self._client.browser_route_cache.get(id) + if route is None: + raise ValueError( + f"browser route cache does not contain session {id}; create, retrieve, or list the browser before calling browsers.stream" + ) + async with async_stream_via_browser_route( + self._client, + route, + method, + url, + content=content, + headers=headers, + params=params, + timeout=timeout, + ) as resp: + yield resp + async def delete_by_id( self, id: str, diff --git a/tests/test_browser_routing.py b/tests/test_browser_routing.py new file mode 100644 index 00000000..39d7e0cb --- /dev/null +++ b/tests/test_browser_routing.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +import httpx +import respx +import pytest + +from kernel import Kernel +from kernel.lib.browser_routing.util import jwt_from_cdp_ws_url +from kernel.lib.browser_routing.routing import ( + BrowserRoute, + BrowserRouteCache, + browser_route_from_browser, + browser_routing_config_from_env, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +api_key = "sk-123" + + +def _fake_browser() -> dict[str, object]: + return { + "session_id": "sess-1", + "base_url": "http://browser-session.test/browser/kernel", + "cdp_ws_url": "wss://browser-session.test/browser/cdp?jwt=token-abc", + "webdriver_ws_url": "wss://x", + "created_at": "2020-01-01T00:00:00Z", + "headless": True, + "stealth": False, + "timeout_seconds": 60, + } + + +def _cache_browser(client: Kernel) -> None: + route = browser_route_from_browser(_fake_browser()) + assert route is not None + client.browser_route_cache.set(route) + + +def test_jwt_from_cdp_ws_url() -> None: + assert jwt_from_cdp_ws_url("wss://h/browser/cdp?jwt=abc%2Fdef&x=1") == "abc/def" + + +@respx.mock +def test_routes_allowlisted_browser_subresources_directly_to_vm(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("KERNEL_BROWSER_ROUTING_SUBRESOURCES", "process") + route = respx.post("http://browser-session.test/browser/kernel/process/exec").mock( + return_value=httpx.Response(200, json={"exit_code": 0, "stdout_b64": "", "stderr_b64": ""}) + ) + with Kernel( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + ) as client: + _cache_browser(client) + out = client.browsers.process.exec("sess-1", command="echo", args=["hi"]) + + assert route.called + request = cast(httpx.Request, cast(Any, route.calls[0]).request) + assert request.url.params.get("jwt") == "token-abc" + assert request.headers.get("Authorization") is None + assert out.exit_code == 0 + + +@respx.mock +def test_skips_direct_vm_routing_outside_allowlist(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("KERNEL_BROWSER_ROUTING_SUBRESOURCES", "computer") + route = respx.post(f"{base_url}/browsers/sess-1/process/exec").mock( + return_value=httpx.Response(200, json={"exit_code": 0, "stdout_b64": "", "stderr_b64": ""}) + ) + with Kernel( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + ) as client: + _cache_browser(client) + client.browsers.process.exec("sess-1", command="echo", args=["hi"]) + + assert route.called + + +@respx.mock +def test_browser_request_uses_curl_raw() -> None: + route = respx.get("http://browser-session.test/browser/kernel/curl/raw").mock( + return_value=httpx.Response(200, content=b"ok") + ) + with Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True) as client: + _cache_browser(client) + response = client.browsers.request("sess-1", "GET", "https://example.com", params={"timeout_ms": 5000}) + + assert response.status_code == 200 + assert response.content == b"ok" + request = cast(httpx.Request, cast(Any, route.calls[0]).request) + assert "curl/raw" in str(request.url) + assert request.url.params.get("jwt") == "token-abc" + + +@respx.mock +def test_browser_request_params_cannot_override_target_url_or_jwt() -> None: + route = respx.get("http://browser-session.test/browser/kernel/curl/raw").mock( + return_value=httpx.Response(200, content=b"ok") + ) + with Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True) as client: + _cache_browser(client) + client.browsers.request( + "sess-1", + "GET", + "https://example.com", + params={"url": "https://evil.example", "jwt": "other", "timeout_ms": 1}, + ) + + request = cast(httpx.Request, cast(Any, route.calls[0]).request) + assert str(request.url.params.get("url")) == "https://example.com" + assert str(request.url.params.get("jwt")) == "token-abc" + assert str(request.url.params.get("timeout_ms")) == "1" + + +def test_browser_request_requires_cached_route() -> None: + with Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True) as client: + _cache_browser(client) + client.browser_route_cache.delete("sess-1") + with pytest.raises(ValueError, match="route cache"): + client.browsers.request("sess-1", "GET", "https://example.com") + + +def test_browser_route_cache_normalizes_session_id_keys() -> None: + cache = BrowserRouteCache() + cache.set( + BrowserRoute( + session_id=" sess-1 ", + base_url=" http://browser-session.test/browser/kernel/ ", + jwt=" token-abc ", + ) + ) + + route = cache.get("sess-1") + assert route is not None + assert route.session_id == "sess-1" + assert route.base_url == "http://browser-session.test/browser/kernel/" + assert route.jwt == "token-abc" + + cache.delete("sess-1") + assert cache.get("sess-1") is None + + +def test_browser_route_from_browser_requires_base_url_and_jwt() -> None: + assert browser_route_from_browser({**_fake_browser(), "base_url": None}) is None + assert browser_route_from_browser({**_fake_browser(), "cdp_ws_url": None}) is None + + +def test_browser_routing_config_from_env_defaults_to_curl(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("KERNEL_BROWSER_ROUTING_SUBRESOURCES", raising=False) + assert browser_routing_config_from_env().subresources == ("curl",) + + +def test_browser_routing_config_from_env_empty_string_disables_routing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("KERNEL_BROWSER_ROUTING_SUBRESOURCES", "") + assert browser_routing_config_from_env().subresources == ()