diff --git a/pyproject.toml b/pyproject.toml index 4f09349ce..6067afb4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,8 @@ server = [ "python-json-logger>=3.1.0", "prometheus-client", "grpcio>=1.50", + "protobuf>=6.33.5", + "smg-grpc-proto>=0.4.7", ] aws = [ "boto3>=1.38.13", diff --git a/src/dstack/_internal/cli/commands/server.py b/src/dstack/_internal/cli/commands/server.py index a9040274d..255f92e9a 100644 --- a/src/dstack/_internal/cli/commands/server.py +++ b/src/dstack/_internal/cli/commands/server.py @@ -80,6 +80,9 @@ def _command(self, args: argparse.Namespace): os.environ["DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT"] = "1" if args.token: os.environ["DSTACK_SERVER_ADMIN_TOKEN"] = args.token + # Hide noisy "Other threads are currently calling into gRPC, skipping fork() handlers" + # messages in server logs. Users can still change this with GRPC_VERBOSITY. + os.environ.setdefault("GRPC_VERBOSITY", "ERROR") uvicorn_log_level = os.getenv("DSTACK_SERVER_UVICORN_LOG_LEVEL", "ERROR").lower() reload_disabled = os.getenv("DSTACK_SERVER_RELOAD_DISABLED") is not None diff --git a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py new file mode 100644 index 000000000..bc6f6cffe --- /dev/null +++ b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py @@ -0,0 +1,57 @@ +"""SSH-tunneled gRPC channel target to a job's service port (UDS).""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any + +import grpc + +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + UnixSocket, +) +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.services.ssh import container_ssh_tunnel +from dstack._internal.utils.common import get_or_error + +SSH_CONNECT_TIMEOUT = timedelta(seconds=10) +# Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES). +_MAX_GRPC_MESSAGE_BYTES = 256 * 1024 +_GRPC_CHANNEL_OPTIONS = ( + ("grpc.max_receive_message_length", _MAX_GRPC_MESSAGE_BYTES), + ("grpc.max_send_message_length", _MAX_GRPC_MESSAGE_BYTES), +) + + +@asynccontextmanager +async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]: + options = { + **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), + } + job_spec = get_job_spec(job) + with TemporaryDirectory() as temp_dir: + # Keep the same socket file name as the HTTP helper for consistency. + app_socket_path = (Path(temp_dir) / "replica.sock").absolute() + async with container_ssh_tunnel( + job=job, + forwarded_sockets=[ + SocketPair( + remote=IPSocket("localhost", get_or_error(job_spec.service_port)), + local=UnixSocket(app_socket_path), + ), + ], + options=options, + ): + target = f"unix://{app_socket_path}" + channel = grpc.aio.insecure_channel(target, options=_GRPC_CHANNEL_OPTIONS) + try: + yield channel + finally: + await channel.close() diff --git a/src/dstack/_internal/server/services/runs/router_worker_sync.py b/src/dstack/_internal/server/services/runs/router_worker_sync.py index 2fc9add74..910dc8d57 100644 --- a/src/dstack/_internal/server/services/runs/router_worker_sync.py +++ b/src/dstack/_internal/server/services/runs/router_worker_sync.py @@ -1,10 +1,25 @@ """Reconcile SGLang router /workers with dstack's registered worker replicas (async, SSH-tunneled).""" import json -from typing import Any, Dict, List, Literal, Optional, TypedDict +from typing import Any, List, Literal, Optional, TypedDict from urllib.parse import urlsplit, urlunsplit -from httpx import AsyncClient, Response +import grpc +from google.protobuf.json_format import MessageToDict +from httpx import ( + AsyncClient, + ConnectError, + ConnectTimeout, + ReadTimeout, + RemoteProtocolError, + Response, +) +from smg_grpc_proto import ( + sglang_scheduler_pb2, + sglang_scheduler_pb2_grpc, + vllm_engine_pb2, + vllm_engine_pb2_grpc, +) from typing_extensions import NotRequired from dstack._internal.core.errors import SSHError @@ -12,6 +27,9 @@ from dstack._internal.core.models.runs import JobStatus, RunSpec, get_service_port from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_spec +from dstack._internal.server.services.jobs.job_replica_grpc_client import ( + get_service_replica_grpc_client, +) from dstack._internal.server.services.jobs.job_replica_http_client import ( get_service_replica_client, ) @@ -29,6 +47,7 @@ _MAX_WORKERS_RESPONSE_BYTES = 2 * 1024 * 1024 _MAX_WORKERS_COMMAND_ACK_BYTES = 64 * 1024 _MAX_WORKERS_LIST_ITEMS = 8192 +_GRPC_DISCOVERY_TIMEOUT = 30.0 class _ResponseTooLargeError(Exception): @@ -82,15 +101,24 @@ async def _request_json_limited( return None -class _WorkerPayloadResult(TypedDict): - status: Literal["ready", "not_ready"] - payload: Optional[Dict[str, Any]] - - class _TargetWorker(TypedDict): url: str worker_type: str bootstrap_port: NotRequired[Optional[int]] + connection_mode: NotRequired[str] + runtime_type: NotRequired[str] + kv_connector: NotRequired[str] + kv_role: NotRequired[str] + + +class _WorkerPayloadResult(TypedDict): + status: Literal["ready", "not_ready"] + worker: Optional[_TargetWorker] + + +_ConnectionMode = Literal["grpc", "http"] +_RuntimeType = Literal["sglang", "vllm"] +_GRPC_RUNTIME_TYPES: tuple[_RuntimeType, ...] = ("sglang", "vllm") def run_model_has_sglang_router_replica_group(run_model: RunModel) -> bool: @@ -121,6 +149,70 @@ def _normalize_worker_url(url: str) -> str: return urlunsplit((parts.scheme, parts.netloc, path, parts.query, parts.fragment)) +def _get_connection_mode_from_workers( + current_workers: List[dict], +) -> Optional[_ConnectionMode]: + # PD services register multiple workers (e.g. prefill and decode). We expect + # every listed worker to use the same connection_mode (all grpc or all http), + # not a mix of protocols on one router. + modes: set[str] = set() + for worker in current_workers: + mode = worker.get("connection_mode") + if isinstance(mode, str) and mode in ("http", "grpc"): + modes.add(mode) + if modes == {"grpc"}: + return "grpc" + if modes == {"http"}: + return "http" + return None + + +def _get_runtime_type_from_workers( + current_workers: List[dict], +) -> Optional[_RuntimeType]: + # We expect every listed gRPC worker to share the same runtime_type + # (all sglang or all vllm), not a mix of runtimes on one router. + runtimes: set[str] = set() + for worker in current_workers: + # For HTTP workers,there is no “pick vLLM vs SGLang gRPC stub” step, + # so runtime_type is irrelevant for HTTP workers. + if worker.get("connection_mode") != "grpc": + continue + runtime_type = worker.get("runtime_type") + if isinstance(runtime_type, str) and runtime_type in _GRPC_RUNTIME_TYPES: + runtimes.add(runtime_type) + if runtimes == {"sglang"}: + return "sglang" + if runtimes == {"vllm"}: + return "vllm" + return None + + +def _is_expected_router_workers_fetch_error(error: Exception) -> bool: + """SMG router may not accept HTTP yet during startup.""" + if isinstance( + error, + ( + RemoteProtocolError, + ConnectError, + ConnectTimeout, + ReadTimeout, + TimeoutError, + ), + ): + return True + if isinstance(error, OSError) and error.errno in {61, 111}: + return True + return False + + +def _log_router_workers_fetch_failure(error: Exception) -> None: + if _is_expected_router_workers_fetch_error(error): + logger.debug("Router /workers not ready yet: %r", error) + return + logger.exception("Error getting router /workers") + + async def _get_router_workers(client: AsyncClient) -> List[dict]: try: data = await _request_json_limited( @@ -144,8 +236,8 @@ async def _get_router_workers(client: AsyncClient) -> List[dict]: return [w for w in workers if isinstance(w, dict)] except _ResponseTooLargeError: logger.warning("Router /workers response exceeded size limit") - except Exception: - logger.exception("Error getting router /workers") + except Exception as e: + _log_router_workers_fetch_failure(e) return [] @@ -154,11 +246,24 @@ async def _add_worker_to_router( url: str, worker_type: str = "regular", bootstrap_port: Optional[int] = None, + *, + connection_mode: Optional[str] = None, + runtime_type: Optional[str] = None, + kv_connector: Optional[str] = None, + kv_role: Optional[str] = None, ) -> bool: try: payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port + if connection_mode is not None: + payload["connection_mode"] = connection_mode + if runtime_type is not None: + payload["runtime_type"] = runtime_type + if kv_connector is not None: + payload["kv_connector"] = kv_connector + if kv_role is not None: + payload["kv_role"] = kv_role body = await _request_json_limited( client, "POST", @@ -199,11 +304,12 @@ async def _remove_worker_from_router_by_id( async def _update_workers_in_router_replica( client: AsyncClient, target_workers: List[_TargetWorker], + *, + current_workers: List[dict], ) -> None: - current = await _get_router_workers(client) current_urls: set[str] = set() current_ids_by_norm_url: dict[str, str] = {} - for w in current: + for w in current_workers: u = w.get("url") if not isinstance(u, str) or not u: continue @@ -223,6 +329,10 @@ async def _update_workers_in_router_replica( tw["url"], tw["worker_type"], tw.get("bootstrap_port"), + connection_mode=tw.get("connection_mode"), + runtime_type=tw.get("runtime_type"), + kv_connector=tw.get("kv_connector"), + kv_role=tw.get("kv_role"), ) if not ok: logger.warning("Failed to add worker %s, continuing with others", tw["url"]) @@ -237,7 +347,26 @@ async def _update_workers_in_router_replica( logger.warning("Failed to remove worker %s, continuing with others", url) -async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPayloadResult: +def _vllm_kv_role_to_worker_type(kv_role: str) -> str: + if kv_role == "kv_producer": + return "prefill" + if kv_role == "kv_consumer": + return "decode" + return "regular" + + +def _is_expected_grpc_discovery_error(error: Exception) -> bool: + """Expected while a gRPC worker is still starting or the wrong stub is probed.""" + if isinstance(error, grpc.aio.AioRpcError): + return error.code() in ( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.UNIMPLEMENTED, + ) + return False + + +async def _get_http_worker(job_model: JobModel, *, worker_url: str) -> _WorkerPayloadResult: try: async with get_service_replica_client(job_model) as client: data = await _request_json_limited( @@ -249,43 +378,189 @@ async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPa ) if isinstance(data, dict): if data.get("status") != "ready": - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} mode = data.get("disaggregation_mode", "") if mode == "prefill": bootstrap_port = data.get("disaggregation_bootstrap_port") - return { - "status": "ready", - "payload": { - "url": worker_url, - "worker_type": "prefill", - "bootstrap_port": bootstrap_port, - }, + worker: _TargetWorker = { + "url": worker_url, + "worker_type": "prefill", + "connection_mode": "http", + "runtime_type": "sglang", } + if bootstrap_port is not None: + worker["bootstrap_port"] = bootstrap_port + return {"status": "ready", "worker": worker} if mode == "decode": return { "status": "ready", - "payload": {"url": worker_url, "worker_type": "decode"}, + "worker": { + "url": worker_url, + "worker_type": "decode", + "connection_mode": "http", + "runtime_type": "sglang", + }, } return { "status": "ready", - "payload": {"url": worker_url, "worker_type": "regular"}, + "worker": { + "url": worker_url, + "worker_type": "regular", + "connection_mode": "http", + "runtime_type": "sglang", + }, } except _ResponseTooLargeError: logger.warning("server_info response too large for worker %s", worker_url) + except RemoteProtocolError as e: + logger.debug("HTTP server_info not available for worker %s: %r", worker_url, e) except Exception as e: logger.exception("Could not fetch server_info for worker %s: %r", worker_url, e) - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} + + +async def _get_grpc_server_info( + channel: grpc.aio.Channel, + runtime_type: _RuntimeType, +) -> Any: + if runtime_type == "sglang": + stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel) + request = sglang_scheduler_pb2.GetServerInfoRequest() + else: + stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) + request = vllm_engine_pb2.GetServerInfoRequest() + return await stub.GetServerInfo(request, timeout=_GRPC_DISCOVERY_TIMEOUT) + + +async def _discover_grpc_server_info( + channel: grpc.aio.Channel, +) -> tuple[Optional[_RuntimeType], Optional[Any]]: + # Bootstrap only: router workers list has no runtime_type yet. + for runtime_type in _GRPC_RUNTIME_TYPES: + try: + response = await _get_grpc_server_info(channel, runtime_type) + except Exception as e: + if _is_expected_grpc_discovery_error(e): + continue + raise + return runtime_type, response + return None, None + + +def _grpc_server_info_to_worker( + worker_url: str, + runtime_type: _RuntimeType, + response: Any, +) -> _TargetWorker: + if runtime_type == "vllm": + kv_role = response.kv_role or "" + kv_connector = response.kv_connector or "" + worker: _TargetWorker = { + "url": worker_url, + "connection_mode": "grpc", + "runtime_type": runtime_type, + "worker_type": _vllm_kv_role_to_worker_type(kv_role), + } + if kv_connector: + worker["kv_connector"] = kv_connector + if kv_role: + worker["kv_role"] = kv_role + return worker + + server_args = ( + MessageToDict(response.server_args, preserving_proto_field_name=True) + if response.server_args is not None + else {} + ) + mode = server_args.get("disaggregation_mode") + worker_type = mode if mode in ("prefill", "decode") else "regular" + worker = { + "url": worker_url, + "connection_mode": "grpc", + "runtime_type": runtime_type, + "worker_type": worker_type, + } + if worker_type == "prefill": + bootstrap_port = server_args.get("disaggregation_bootstrap_port") + if bootstrap_port is not None: + worker["bootstrap_port"] = int(bootstrap_port) + return worker + + +async def _get_grpc_worker( + job_model: JobModel, + *, + worker_url: str, + runtime_type: Optional[_RuntimeType] = None, +) -> _WorkerPayloadResult: + try: + async with get_service_replica_grpc_client(job_model) as channel: + if runtime_type is not None: + try: + response = await _get_grpc_server_info(channel, runtime_type) + except Exception as e: + if _is_expected_grpc_discovery_error(e): + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "worker": None} + raise + else: + runtime_type, response = await _discover_grpc_server_info(channel) + if runtime_type is None or response is None: + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "worker": None} + except Exception as e: + logger.exception( + "Could not fetch gRPC GetServerInfo for worker %s: %r", + worker_url, + e, + ) + return {"status": "not_ready", "worker": None} + + worker = _grpc_server_info_to_worker(worker_url, runtime_type, response) + return {"status": "ready", "worker": worker} + + +async def _get_worker( + job_model: JobModel, + *, + http_worker_url: str, + grpc_worker_url: str, + connection_mode: Optional[_ConnectionMode] = None, + runtime_type: Optional[_RuntimeType] = None, +) -> _WorkerPayloadResult: + if connection_mode == "grpc": + return await _get_grpc_worker( + job_model, worker_url=grpc_worker_url, runtime_type=runtime_type + ) + if connection_mode == "http": + return await _get_http_worker(job_model, worker_url=http_worker_url) + # Router workers list is empty and no connection_mode discovered. + try: + result = await _get_http_worker(job_model, worker_url=http_worker_url) + except RemoteProtocolError as e: + logger.debug( + "HTTP server_info probe failed for %s (trying gRPC): %r", + http_worker_url, + e, + ) + result: _WorkerPayloadResult = {"status": "not_ready", "worker": None} + if result["status"] == "ready": + return result + return await _get_grpc_worker(job_model, worker_url=grpc_worker_url, runtime_type=runtime_type) async def _build_target_workers( run_model: RunModel, run_spec: RunSpec, replica_groups: list[ReplicaGroup], + *, + connection_mode: Optional[_ConnectionMode] = None, + runtime_type: Optional[_RuntimeType] = None, ) -> List[_TargetWorker]: - payloads: List[_TargetWorker] = [] + workers: List[_TargetWorker] = [] config = run_spec.configuration if not isinstance(config, ServiceConfiguration): - return payloads + return workers for group in replica_groups: if group.router is not None: @@ -305,20 +580,24 @@ async def _build_target_workers( continue job_spec = get_job_spec(job) port = get_service_port(job_spec, config) - worker_url = f"http://{ip}:{port}" - result = await _get_worker_payload(job, worker_url) - if result["status"] == "ready" and result["payload"]: - p = result["payload"] - entry: _TargetWorker = { - "url": p["url"], - "worker_type": p.get("worker_type", "regular"), - } - if p.get("bootstrap_port") is not None: - entry["bootstrap_port"] = p["bootstrap_port"] - payloads.append(entry) + http_worker_url = f"http://{ip}:{port}" + grpc_worker_url = f"grpc://{ip}:{port}" + result = await _get_worker( + job, + http_worker_url=http_worker_url, + grpc_worker_url=grpc_worker_url, + connection_mode=connection_mode, + runtime_type=runtime_type, + ) + if result["status"] == "ready" and result["worker"]: + workers.append(result["worker"]) elif result["status"] == "not_ready": - logger.debug("Worker %s not ready", worker_url) - return payloads + logger.debug( + "Worker not ready http=%s grpc=%s", + http_worker_url, + grpc_worker_url, + ) + return workers async def sync_router_workers_for_run_model(run_model: RunModel) -> None: @@ -331,13 +610,28 @@ async def sync_router_workers_for_run_model(run_model: RunModel) -> None: if router_group is None: return - target_workers = await _build_target_workers(run_model, run_spec, replica_groups) router_job = _get_router_job(run_model, router_group) if router_job is None: return try: async with get_service_replica_client(router_job) as client: - await _update_workers_in_router_replica(client, target_workers) + current_workers = await _get_router_workers(client) + # connection_mode can be grpc or http, runtime_type can be sglang or vllm. + connection_mode = _get_connection_mode_from_workers(current_workers) + runtime_type = _get_runtime_type_from_workers(current_workers) + # Empty current_workers on first sync is expected. First syncprobes both connection_mode and + # runtime_type. Subsequent syncs don't need to probe again because connection_mode and runtime_type + # is already set in current_workers. + target_workers = await _build_target_workers( + run_model, + run_spec, + replica_groups, + connection_mode=connection_mode, + runtime_type=runtime_type, + ) + await _update_workers_in_router_replica( + client, target_workers, current_workers=current_workers + ) except SSHError as e: logger.warning( "%s: failed to sync workers with router: %r", diff --git a/src/tests/_internal/server/services/runs/test_router_worker_sync.py b/src/tests/_internal/server/services/runs/test_router_worker_sync.py new file mode 100644 index 000000000..2cf027563 --- /dev/null +++ b/src/tests/_internal/server/services/runs/test_router_worker_sync.py @@ -0,0 +1,231 @@ +from contextlib import asynccontextmanager, contextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from dstack._internal.server.services.runs.router_worker_sync import ( + _get_connection_mode_from_workers, + _get_grpc_worker, + _get_runtime_type_from_workers, + _get_worker, + _grpc_server_info_to_worker, +) + + +class TestGetConnectionModeFromWorkers: + def test_grpc(self): + current = [{"connection_mode": "grpc"}] + assert _get_connection_mode_from_workers(current) == "grpc" + + def test_http(self): + current = [{"connection_mode": "http"}] + assert _get_connection_mode_from_workers(current) == "http" + + def test_mixed(self): + current = [{"connection_mode": "grpc"}, {"connection_mode": "http"}] + assert _get_connection_mode_from_workers(current) is None + + +class TestRuntimeTypeFromRouterWorkers: + def test_vllm_grpc_workers(self): + current = [{"connection_mode": "grpc", "runtime_type": "vllm"}] + assert _get_runtime_type_from_workers(current) == "vllm" + + def test_sglang_grpc_workers(self): + current = [{"connection_mode": "grpc", "runtime_type": "sglang"}] + assert _get_runtime_type_from_workers(current) == "sglang" + + def test_ignores_http_workers(self): + current = [{"connection_mode": "http", "runtime_type": "sglang"}] + assert _get_runtime_type_from_workers(current) is None + + def test_mixed_runtimes(self): + current = [ + {"connection_mode": "grpc", "runtime_type": "vllm"}, + {"connection_mode": "grpc", "runtime_type": "sglang"}, + ] + assert _get_runtime_type_from_workers(current) is None + + +class TestGrpcServerInfoToWorker: + def test_vllm_prefill(self): + response = MagicMock(kv_role="kv_producer", kv_connector="NixlConnector") + worker = _grpc_server_info_to_worker("grpc://10.0.0.1:50051", "vllm", response) + assert worker["worker_type"] == "prefill" + assert worker.get("runtime_type") == "vllm" + assert worker.get("kv_role") == "kv_producer" + + def test_sglang_prefill(self): + server_args = MagicMock() + response = MagicMock(server_args=server_args) + with patch( + "dstack._internal.server.services.runs.router_worker_sync.MessageToDict", + return_value={ + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 8998, + }, + ): + worker = _grpc_server_info_to_worker("grpc://10.0.0.1:8000", "sglang", response) + assert worker == { + "url": "grpc://10.0.0.1:8000", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "sglang", + "bootstrap_port": 8998, + } + + +@contextmanager +def _fake_vllm_grpc_proto(*, server_info: MagicMock): + stub = MagicMock() + stub.GetServerInfo = AsyncMock(return_value=server_info) + pb2 = MagicMock(GetServerInfoRequest=MagicMock(return_value="req")) + pb2_grpc = MagicMock(VllmEngineStub=MagicMock(return_value=stub)) + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync.vllm_engine_pb2", + pb2, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync.vllm_engine_pb2_grpc", + pb2_grpc, + ), + ): + yield + + +@contextmanager +def _fake_sglang_grpc_proto(*, server_info: MagicMock): + stub = MagicMock() + stub.GetServerInfo = AsyncMock(return_value=server_info) + pb2 = MagicMock(GetServerInfoRequest=MagicMock(return_value="req")) + pb2_grpc = MagicMock(SglangSchedulerStub=MagicMock(return_value=stub)) + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync.sglang_scheduler_pb2", + pb2, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync.sglang_scheduler_pb2_grpc", + pb2_grpc, + ), + ): + yield + + +@pytest.mark.asyncio +async def test_get_grpc_worker_ready(): + job = MagicMock() + channel = MagicMock() + + @asynccontextmanager + async def _fake_grpc_client(_job): + yield channel + + server_info = MagicMock(kv_role="kv_producer", kv_connector="NixlConnector") + + with ( + _fake_vllm_grpc_proto(server_info=server_info), + patch( + "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_grpc_client", + _fake_grpc_client, + ), + ): + result = await _get_grpc_worker( + job, + worker_url="grpc://10.0.0.1:50051", + runtime_type="vllm", + ) + + assert result["status"] == "ready" + assert result["worker"] == { + "url": "grpc://10.0.0.1:50051", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "vllm", + "kv_connector": "NixlConnector", + "kv_role": "kv_producer", + } + + +@pytest.mark.asyncio +async def test_get_grpc_worker_not_ready_on_error(): + job = MagicMock() + + @asynccontextmanager + async def _failing_client(_job): + raise OSError("ssh failed") + yield # pragma: no cover + + with patch( + "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_grpc_client", + _failing_client, + ): + result = await _get_grpc_worker(job, worker_url="grpc://10.0.0.1:50051") + + assert result == {"status": "not_ready", "worker": None} + + +@pytest.mark.asyncio +async def test_get_grpc_worker_sglang_bootstrap(): + job = MagicMock() + channel = MagicMock() + sglang_server_info = MagicMock(server_args=MagicMock()) + + @asynccontextmanager + async def _fake_grpc_client(_job): + yield channel + + with ( + _fake_sglang_grpc_proto(server_info=sglang_server_info), + patch( + "dstack._internal.server.services.runs.router_worker_sync.MessageToDict", + return_value={ + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 8998, + }, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync" + ".get_service_replica_grpc_client", + _fake_grpc_client, + ), + ): + result = await _get_grpc_worker(job, worker_url="grpc://10.0.0.1:8000") + + assert result["status"] == "ready" + assert result["worker"] == { + "url": "grpc://10.0.0.1:8000", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "sglang", + "bootstrap_port": 8998, + } + + +@pytest.mark.asyncio +async def test_get_worker_grpc_preference_skips_http(): + job = MagicMock() + grpc_not_ready = {"status": "not_ready", "worker": None} + + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_grpc_worker", + new_callable=AsyncMock, + return_value=grpc_not_ready, + ) as grpc_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_http_worker", + new_callable=AsyncMock, + ) as http_mock, + ): + result = await _get_worker( + job, + http_worker_url="http://10.0.0.1:8000", + grpc_worker_url="grpc://10.0.0.1:8000", + connection_mode="grpc", + ) + + assert result == grpc_not_ready + grpc_mock.assert_awaited_once() + http_mock.assert_not_awaited()