From 691421e064f5bb6370cb452224d23d689b82c218 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 3 Jun 2026 15:51:01 +0500 Subject: [PATCH 01/25] Prototype InstanceConnectionPool --- .../_internal/server/services/runner/pool.py | 203 ++++++++++++++++++ .../_internal/server/services/runner/ssh.py | 84 ++------ 2 files changed, 218 insertions(+), 69 deletions(-) create mode 100644 src/dstack/_internal/server/services/runner/pool.py diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py new file mode 100644 index 0000000000..1ea8682624 --- /dev/null +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -0,0 +1,203 @@ +import threading +from dataclasses import dataclass +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Collection, Optional, Union + +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import SSHError +from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + SSHTunnel, + UnixSocket, +) +from dstack._internal.server.settings import SERVER_DIR_PATH +from dstack._internal.utils.path import FileContent + +# A host private key or pair of (host private key, optional proxy jump private key) +PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] + +CONNECTIONS_DIR = SERVER_DIR_PATH / "instance-connections" + +DEFAULT_PORTS_TO_FORWARD = [DSTACK_SHIM_HTTP_PORT, DSTACK_RUNNER_HTTP_PORT] + + +@dataclass(frozen=True) +class InstanceConnectionKey: + hostname: str + port: int + ports_to_forward: tuple[int, ...] 
+ + @staticmethod + def from_jpd( + jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] + ) -> "InstanceConnectionKey": + assert jpd.hostname is not None and jpd.ssh_port is not None + container_to_host_port_map = InstanceConnection._get_container_to_host_port_map(jpd, jrd) + return InstanceConnectionKey( + hostname=jpd.hostname, + port=jpd.ssh_port, + ports_to_forward=tuple(container_to_host_port_map.values()), + ) + + +class InstanceConnectionPool: + def __init__(self): + self._connections: dict[InstanceConnectionKey, InstanceConnection] + self._access_locks: dict[InstanceConnectionKey, threading.Lock] + self._access_locks_lock = threading.Lock() + + def get_or_open( + self, + ssh_private_key: PrivateKeyOrPair, + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> Optional["InstanceConnection"]: + key = InstanceConnectionKey.from_jpd(jpd, jrd) + lock = self._get_access_lock(key) + with lock: + conn = self._connections.get(key) + if conn is not None: + return conn + conn = InstanceConnection(ssh_private_key, jpd, jrd) + try: + conn.open() + except SSHError: + # error logged in tunnel + return None + self._connections[key] = conn + return conn + + def drop(self, key: InstanceConnectionKey) -> None: + lock = self._get_access_lock(key) + with lock: + # close? + try: + self._connections.pop(key) + self._access_locks.pop(key) + except KeyError: + pass + + def close_all(self) -> None: ... 
# graceful shutdown + + def _get_access_lock(self, key: InstanceConnectionKey) -> threading.Lock: + with self._access_locks_lock: + lock = self._access_locks.get(key) + if lock is not None: + return lock + lock = threading.Lock() + self._access_locks[key] = lock + return lock + + +instance_connection_pool = InstanceConnectionPool() + + +class InstanceConnection: + def __init__( + self, + ssh_private_key: PrivateKeyOrPair, + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> None: + self._key = InstanceConnectionKey.from_jpd(jpd, jrd) + self._connection_dir = ( + CONNECTIONS_DIR + / f"{self._key.hostname}:{self._key.port}" + / str(self._key.ports_to_forward) + ) + # connection_dir can have a long path that won't be accepted by the ssh command, + # so we create a short temporary symlink + self._temp_dir, self._connection_symlink_dir = self._init_symlink_dir(self._connection_dir) + self._control_socket_path = self._connection_symlink_dir / "control.sock" + self._container_to_host_port_map = InstanceConnection._get_container_to_host_port_map( + jpd, jrd + ) + self._host_port_to_unix_socket_map = InstanceConnection._get_host_port_to_unix_socket_map( + connection_dir=self._connection_symlink_dir, + ports_to_forward=self._key.ports_to_forward, + ) + self._tunnel = SSHTunnel( + destination=f"{jpd.username}@{jpd.hostname}", + port=jpd.ssh_port, + identity=_get_identity(ssh_private_key, jpd), + control_sock_path=self._control_socket_path, + forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_unix_socket_map), + ssh_proxies=_get_proxies(ssh_private_key, jpd), + options={ + **SSH_DEFAULT_OPTIONS, + "ServerAliveInterval": "30", + "ControlPersist": "2m", + }, + batch_mode=True, + ) + + def open(self) -> None: + self._tunnel.open() + + def forwarded_path(self, container_port: int) -> Path: + return self._host_port_to_unix_socket_map[self._container_to_host_port_map[container_port]] + + def close(self) -> None: + self._tunnel.close() + + @property + def 
key(self) -> InstanceConnectionKey: + return self._key + + @staticmethod + def _init_symlink_dir(connection_dir: Path) -> tuple[TemporaryDirectory, Path]: + temp_dir = TemporaryDirectory() + symlink_dir = Path(temp_dir.name) / "connection" + symlink_dir.symlink_to(connection_dir, target_is_directory=True) + return temp_dir, symlink_dir + + @staticmethod + def _get_container_to_host_port_map( + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> dict[int, int]: + port_map = {port: port for port in DEFAULT_PORTS_TO_FORWARD} + if jrd is not None and jrd.ports is not None: + port_map.update(jrd.ports) + return port_map + + @staticmethod + def _get_host_port_to_unix_socket_map( + connection_dir: Path, + ports_to_forward: Collection[int], + ) -> dict[int, Path]: + return {port: connection_dir / str(port) for port in ports_to_forward} + + @staticmethod + def _get_forwarded_sockets(host_port_to_unix_socket_map: dict[int, Path]) -> list[SocketPair]: + return [ + SocketPair( + local=UnixSocket(path=path), + remote=IPSocket(host="localhost", port=port), + ) + for port, path in host_port_to_unix_socket_map.items() + ] + + +def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) -> FileContent: + if isinstance(ssh_private_key, tuple): + ssh_private_key, _ = ssh_private_key + return FileContent(ssh_private_key) + + +def _get_proxies( + ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData +) -> list[tuple[SSHConnectionParams, FileContent]]: + if not isinstance(ssh_private_key, tuple): + return [] + + _, ssh_proxy_private_key = ssh_private_key + if ssh_proxy_private_key is None or jpd.ssh_proxy is None: + return [] + proxy_identity = FileContent(ssh_proxy_private_key) + return [(jpd.ssh_proxy, proxy_identity)] diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index a4ef986862..98d5186181 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ 
b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,7 +1,4 @@ import functools -import socket -import time -from collections.abc import Iterable from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union import requests @@ -10,9 +7,8 @@ from dstack._internal.core.errors import DstackError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData -from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets +from dstack._internal.server.services.runner.pool import instance_connection_pool from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import FileContent logger = get_logger(__name__) P = ParamSpec("P") @@ -47,8 +43,8 @@ def decorator( @functools.wraps(func) def wrapper( ssh_private_key: PrivateKeyOrPair, - job_provisioning_data: JobProvisioningData, - job_runtime_data: Optional[JobRuntimeData], + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], *args: P.args, **kwargs: P.kwargs, ) -> Union[Literal[False], R]: @@ -56,74 +52,24 @@ def wrapper( Returns: is successful """ - # container:host mapping - container_ports_map = {port: port for port in ports} - if job_runtime_data is not None and job_runtime_data.ports is not None: - container_ports_map.update(job_runtime_data.ports) - - if job_provisioning_data.backend == BackendType.LOCAL: + if jpd.backend == BackendType.LOCAL: # without SSH + container_ports_map = {port: port for port in ports} return func(container_ports_map, *args, **kwargs) - if isinstance(ssh_private_key, str): - ssh_proxy_private_key = None - else: - ssh_private_key, ssh_proxy_private_key = ssh_private_key - identity = FileContent(ssh_private_key) - if ssh_proxy_private_key is not None: - proxy_identity = FileContent(ssh_proxy_private_key) - else: - proxy_identity = None - - ssh_proxies = [] - if job_provisioning_data.ssh_proxy is not None: - 
ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity)) - - for attempt in range(retries): - last = attempt == retries - 1 - # remote_host:local mapping - tunnel_ports_map = _reserve_ports(container_ports_map.values()) - runner_ports_map = { - container_port: tunnel_ports_map[host_port] - for container_port, host_port in container_ports_map.items() - } + for attempt in range(2): # cached, then one fresh reopen + conn = instance_connection_pool.get_or_open(ssh_private_key, jpd, jrd) + if conn is None: + return False # couldn't establish at all + sock_paths = {p: conn.forwarded_path(p) for p in ports} try: - with SSHTunnel( - destination=( - f"{job_provisioning_data.username}@{job_provisioning_data.hostname}" - ), - port=job_provisioning_data.ssh_port, - forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map), - identity=identity, - ssh_proxies=ssh_proxies, - batch_mode=True, - ): - return func(runner_ports_map, *args, **kwargs) - except SSHError: - pass # error is logged in the tunnel - except (DstackError, requests.RequestException) as e: - if last: - logger.debug( - "Cannot connect to %s's API: %s", job_provisioning_data.hostname, e - ) - if not last: - time.sleep(retry_interval) + return func(sock_paths, *args, **kwargs) + except (SSHError, requests.ConnectionError): + instance_connection_pool.drop(conn.key) # dead ssh connection, re-open + except (DstackError, requests.RequestException): + return False # reached runner, app-level fail; don't re-open ssh connection return False return wrapper return decorator - - -def _reserve_ports(ports: Iterable[int]) -> dict[int, int]: - sockets = [] - try: - for port in ports: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("localhost", 0)) # Bind to a free port provided by the host - sockets.append((port, s)) - return {port: s.getsockname()[1] for port, s in sockets} - finally: - for _, s in sockets: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.close() From 
aaad4be36ba739fb3d47010aca3d5769f69e7dcb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 3 Jun 2026 16:08:17 +0500 Subject: [PATCH 02/25] Make runner and shim client work over uds --- pyproject.toml | 1 + .../server/services/runner/client.py | 71 ++++++++++++++----- .../_internal/server/services/runner/ssh.py | 4 +- 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4f09349ced..2a4e8620d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ server = [ "aiorwlock", "aiocache", "httpx>=0.28.0", + "requests-unixsocket>=0.4.1", "jinja2", "watchfiles", "sqlalchemy[asyncio]>=2.0.0", diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 6a1c541856..096e8c75ba 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,11 +1,14 @@ +import urllib.parse import uuid from collections.abc import Generator from http import HTTPStatus +from pathlib import Path from typing import BinaryIO, Dict, List, Literal, Optional, TypeVar, Union, overload import packaging.version import requests import requests.exceptions +import requests_unixsocket from typing_extensions import Self from dstack._internal.core.errors import DstackError @@ -42,6 +45,7 @@ ) from dstack._internal.utils.common import get_or_error from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.path import PathLike REQUEST_TIMEOUT = 9 UPLOAD_CODE_REQUEST_TIMEOUT = 60 @@ -59,12 +63,20 @@ class RunnerClient: def __init__( self, - port: int, + port: Optional[int] = None, hostname: str = "localhost", + uds: Optional[PathLike] = None, ): - self.secure = False - self.hostname = hostname - self.port = port + self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) + + @classmethod + def from_address(cls, address: Union[int, Path]) -> Self: + """ + Builds a 
client from a TCP port (`int`) or a Unix domain socket path (`Path`). + """ + if isinstance(address, int): + return cls(port=address) + return cls(uds=address) def get_version_string(self) -> str: if not self._negotiated: @@ -90,7 +102,7 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: return healthcheck_response def get_metrics(self) -> Optional[MetricsResponse]: - resp = requests.get(self._url("/api/metrics"), timeout=REQUEST_TIMEOUT) + resp = self._session.get(self._url("/api/metrics"), timeout=REQUEST_TIMEOUT) if resp.status_code == 404: return None resp.raise_for_status() @@ -134,7 +146,7 @@ def submit_job( log_quota_hour=quota if quota > 0 else None, run_spec=run.run_spec, ) - resp = requests.post( + resp = self._session.post( # use .json() to encode enums self._url("/api/submit"), data=body.json(), @@ -144,7 +156,7 @@ def submit_job( resp.raise_for_status() def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]): - resp = requests.post( + resp = self._session.post( self._url("/api/upload_archive"), files={"archive": (str(id), file)}, timeout=UPLOAD_CODE_REQUEST_TIMEOUT, @@ -152,13 +164,13 @@ def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]): resp.raise_for_status() def upload_code(self, file: Union[BinaryIO, bytes]): - resp = requests.post( + resp = self._session.post( self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT ) resp.raise_for_status() def run_job(self) -> Optional[JobInfoResponse]: - resp = requests.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT) + resp = self._session.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() if not _is_json_response(resp): # Old runner or runner failed to get job info @@ -166,21 +178,21 @@ def run_job(self) -> Optional[JobInfoResponse]: return JobInfoResponse.__response__.parse_obj(resp.json()) def pull(self, timestamp: int) -> PullResponse: - resp = requests.get( + resp = self._session.get( self._url("/api/pull"), 
params={"timestamp": timestamp}, timeout=REQUEST_TIMEOUT ) resp.raise_for_status() return PullResponse.__response__.parse_obj(resp.json()) def stop(self): - resp = requests.post(self._url("/api/stop"), timeout=REQUEST_TIMEOUT) + resp = self._session.post(self._url("/api/stop"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() def _url(self, path: str) -> str: - return f"{'https' if self.secure else 'http'}://{self.hostname}:{self.port}/{path.lstrip('/')}" + return f"{self._base_url}/{path.lstrip('/')}" def _healthcheck(self) -> HealthcheckResponse: - resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) + resp = self._session.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() return HealthcheckResponse.__response__.parse_obj(resp.json()) @@ -302,11 +314,20 @@ class ShimClient: def __init__( self, - port: int, + port: Optional[int] = None, hostname: str = "localhost", + uds: Optional[PathLike] = None, ): - self._session = requests.Session() - self._base_url = f"http://{hostname}:{port}" + self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) + + @classmethod + def from_address(cls, address: Union[int, Path]) -> Self: + """ + Builds a client from a TCP port (`int`) or a Unix domain socket path (`Path`). + """ + if isinstance(address, int): + return cls(port=address) + return cls(uds=address) # Methods shared by all API versions @@ -626,6 +647,24 @@ def _get_restart_safe_task_statuses(self) -> list[TaskStatus]: return [TaskStatus.TERMINATED] +def _make_session_and_base_url( + port: Optional[int], hostname: str, uds: Optional[PathLike] +) -> tuple[requests.Session, str]: + """ + Builds a session and base URL for HTTP over TCP (`port`) or over + a Unix domain socket (`uds`). Exactly one of the two must be specified. 
+ """ + if (port is None) == (uds is None): + raise ValueError("Either port or uds must be specified, not both") + session = requests.Session() + if uds is not None: + base_url = f"http+unix://{urllib.parse.quote(str(uds), safe='')}" + session.mount("http+unix://", requests_unixsocket.UnixAdapter()) + else: + base_url = f"http://{hostname}:{port}" + return session, base_url + + def healthcheck_response_to_instance_check( response: HealthcheckResponse, instance_health_response: Optional[InstanceHealthResponse] = None, diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 98d5186181..a163176e25 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,4 +1,6 @@ import functools +from collections.abc import Mapping +from pathlib import Path from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union import requests @@ -35,7 +37,7 @@ def runner_ssh_tunnel( """ def decorator( - func: Callable[Concatenate[Dict[int, int], P], R], + func: Callable[Concatenate[Mapping[int, int | Path], P], R], ) -> Callable[ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], Union[Literal[False], R], From 086b288a5c0d53e562d99ef236c2c07aff0b4178 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 3 Jun 2026 16:26:58 +0500 Subject: [PATCH 03/25] Update runner and shim client call sites --- .../pipeline_tasks/instances/check.py | 5 +++-- .../background/pipeline_tasks/jobs_running.py | 21 ++++++++++--------- .../pipeline_tasks/jobs_terminating.py | 5 +++-- .../background/scheduled_tasks/metrics.py | 7 ++++--- .../scheduled_tasks/prometheus_metrics.py | 7 +++++-- .../server/services/jobs/__init__.py | 4 ++-- .../server/services/runner/client.py | 7 +++++-- .../_internal/server/services/runner/pool.py | 4 ++-- .../_internal/server/services/runner/ssh.py | 12 +++++------ 9 files changed, 40 insertions(+), 32 
deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index d23d536cd1..d87c3ea613 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -1,5 +1,6 @@ import logging import uuid +from collections.abc import Mapping from datetime import timedelta from typing import Optional @@ -375,13 +376,13 @@ async def _get_backend_for_provisioning_wait( @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _check_instance_inner( - ports: dict[int, int], + addresses: Mapping[int, runner_client.LocalAddress], *, instance: InstanceModel, check_instance_health: bool = False, ) -> InstanceCheck: instance_health_response: Optional[InstanceHealthResponse] = None - shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = runner_client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) method = shim_client.healthcheck try: healthcheck_response = method(unmask_exceptions=True) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 068add9a63..3319d0d728 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1,6 +1,7 @@ import asyncio import enum import uuid +from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Dict, Iterable, Literal, Optional, Sequence, Union @@ -1310,7 +1311,7 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _process_provisioning_with_shim( - ports: Dict[int, int], + addresses: Mapping[int, 
client.LocalAddress], run: Run, job_model: JobModel, jrd: Optional[JobRuntimeData], @@ -1322,7 +1323,7 @@ def _process_provisioning_with_shim( ssh_key: Optional[str], ) -> bool: job_spec = get_job_spec(job_model) - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) resp = shim_client.healthcheck() if resp is None: @@ -1436,8 +1437,8 @@ class _SyncShimPullingStateResult: @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) -def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) +def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability: + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: return _RunnerAvailability.UNAVAILABLE return _RunnerAvailability.AVAILABLE @@ -1445,11 +1446,11 @@ def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) def _sync_shim_pulling_state( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], job_model: JobModel, jrd: Optional[JobRuntimeData] = None, ) -> Union[_SyncShimPullingStateResult, Literal[False]]: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) image_pull_progress: Optional[ImagePullProgress] = None if shim_client.is_api_v2_supported(): task = shim_client.get_task(job_model.id) @@ -1527,7 +1528,7 @@ class _SubmitJobToRunnerResult: @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) def _submit_job_to_runner( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], run: Run, job_model: JobModel, job: Job, @@ -1552,7 +1553,7 @@ def _submit_job_to_runner( else: instance_env = None - runner_client = 
client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: return _SubmitJobToRunnerResult(success=success_if_not_available) @@ -1597,11 +1598,11 @@ class _ProcessRunningResult: @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) def _process_running( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], run_model: RunModel, job_model: JobModel, ) -> Union[_ProcessRunningResult, Literal[False]]: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) timestamp = job_model.runner_timestamp or 0 resp = runner_client.pull(timestamp) logs_services.write_logs( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index e15c24db57..6a69f466df 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -1,5 +1,6 @@ import asyncio import uuid +from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Optional, Sequence, TypedDict @@ -853,8 +854,8 @@ async def _stop_container( @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) +def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool: + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) resp = shim_client.healthcheck() if resp is None: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py 
b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py index f75c5f3eae..352383df33 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py @@ -1,6 +1,7 @@ import asyncio import json -from typing import Dict, List, Optional +from collections.abc import Mapping +from typing import List, Optional from sqlalchemy import Delete, delete, select from sqlalchemy.orm import joinedload @@ -166,7 +167,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint] @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) def _pull_runner_metrics( - ports: Dict[int, int], + addresses: Mapping[int, client.LocalAddress], ) -> Optional[MetricsResponse]: - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) return runner_client.get_metrics() diff --git a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py index 5b039fe2ec..fba9e495ff 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py @@ -1,4 +1,5 @@ import uuid +from collections.abc import Mapping from datetime import datetime, timedelta from typing import Optional @@ -145,6 +146,8 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]: @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) -def _pull_job_metrics(ports: dict[int, int], task_id: uuid.UUID) -> Optional[str]: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) +def _pull_job_metrics( + addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID +) -> Optional[str]: + shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) return 
shim_client.get_task_metrics(task_id) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 5dc0699113..5e6455184d 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -349,11 +349,11 @@ async def stop_runner(job_model: JobModel, instance_model: InstanceModel): @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) def _stop_runner( - ports: dict[int, int], + addresses: Mapping[int, client.LocalAddress], job_model: JobModel, ): logger.debug("%s: stopping runner", fmt(job_model)) - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) try: runner_client.stop() except requests.RequestException: diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 096e8c75ba..7ccc2b1af7 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -52,6 +52,9 @@ logger = get_logger(__name__) +LocalAddress = Union[int, Path] +"""A local TCP port or a Unix domain socket path the client connects to.""" + class RunnerClient: # `/api/upload_code` call is not required if there is no code @@ -70,7 +73,7 @@ def __init__( self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) @classmethod - def from_address(cls, address: Union[int, Path]) -> Self: + def from_address(cls, address: LocalAddress) -> Self: """ Builds a client from a TCP port (`int`) or a Unix domain socket path (`Path`). 
""" @@ -321,7 +324,7 @@ def __init__( self._session, self._base_url = _make_session_and_base_url(port, hostname, uds) @classmethod - def from_address(cls, address: Union[int, Path]) -> Self: + def from_address(cls, address: LocalAddress) -> Self: """ Builds a client from a TCP port (`int`) or a Unix domain socket path (`Path`). """ diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 1ea8682624..ee8ebda640 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -47,8 +47,8 @@ def from_jpd( class InstanceConnectionPool: def __init__(self): - self._connections: dict[InstanceConnectionKey, InstanceConnection] - self._access_locks: dict[InstanceConnectionKey, threading.Lock] + self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} + self._access_locks: dict[InstanceConnectionKey, threading.Lock] = {} self._access_locks_lock = threading.Lock() def get_or_open( diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index a163176e25..96796253a8 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,7 +1,6 @@ import functools from collections.abc import Mapping -from pathlib import Path -from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union +from typing import Callable, List, Literal, Optional, TypeVar, Union import requests from typing_extensions import Concatenate, ParamSpec @@ -9,20 +8,19 @@ from dstack._internal.core.errors import DstackError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData -from dstack._internal.server.services.runner.pool import instance_connection_pool +from dstack._internal.server.services.runner.client import LocalAddress +from 
dstack._internal.server.services.runner.pool import PrivateKeyOrPair, instance_connection_pool from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) P = ParamSpec("P") R = TypeVar("R") -# A host private key or pair of (host private key, optional proxy jump private key) -PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] def runner_ssh_tunnel( ports: List[int], retries: int = 3, retry_interval: float = 1 ) -> Callable[ - [Callable[Concatenate[Dict[int, int], P], R]], + [Callable[Concatenate[Mapping[int, LocalAddress], P], R]], Callable[ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], Union[Literal[False], R], @@ -37,7 +35,7 @@ def runner_ssh_tunnel( """ def decorator( - func: Callable[Concatenate[Mapping[int, int | Path], P], R], + func: Callable[Concatenate[Mapping[int, LocalAddress], P], R], ) -> Callable[ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], Union[Literal[False], R], From 0fee0e3ff1b802b4da81df2cb48dd7dbed02b001 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 3 Jun 2026 16:51:05 +0500 Subject: [PATCH 04/25] Pool fixes --- .../_internal/server/services/runner/pool.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index ee8ebda640..7940f3d062 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -18,8 +18,8 @@ from dstack._internal.server.settings import SERVER_DIR_PATH from dstack._internal.utils.path import FileContent -# A host private key or pair of (host private key, optional proxy jump private key) PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] +"""A host private key or pair of (host private key, optional proxy jump private key)""" CONNECTIONS_DIR = SERVER_DIR_PATH / "instance-connections" @@ -75,12 +75,11 @@ def 
get_or_open( def drop(self, key: InstanceConnectionKey) -> None: lock = self._get_access_lock(key) with lock: - # close? try: - self._connections.pop(key) - self._access_locks.pop(key) + conn = self._connections.pop(key) except KeyError: - pass + return + conn.close() def close_all(self) -> None: ... # graceful shutdown @@ -110,6 +109,7 @@ def __init__( / f"{self._key.hostname}:{self._key.port}" / str(self._key.ports_to_forward) ) + self._connection_dir.mkdir(parents=True, exist_ok=True) # connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink self._temp_dir, self._connection_symlink_dir = self._init_symlink_dir(self._connection_dir) @@ -117,7 +117,7 @@ def __init__( self._container_to_host_port_map = InstanceConnection._get_container_to_host_port_map( jpd, jrd ) - self._host_port_to_unix_socket_map = InstanceConnection._get_host_port_to_unix_socket_map( + self._host_port_to_uds_map = InstanceConnection._get_host_port_to_uds_map( connection_dir=self._connection_symlink_dir, ports_to_forward=self._key.ports_to_forward, ) @@ -126,7 +126,7 @@ def __init__( port=jpd.ssh_port, identity=_get_identity(ssh_private_key, jpd), control_sock_path=self._control_socket_path, - forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_unix_socket_map), + forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_uds_map), ssh_proxies=_get_proxies(ssh_private_key, jpd), options={ **SSH_DEFAULT_OPTIONS, @@ -140,7 +140,7 @@ def open(self) -> None: self._tunnel.open() def forwarded_path(self, container_port: int) -> Path: - return self._host_port_to_unix_socket_map[self._container_to_host_port_map[container_port]] + return self._host_port_to_uds_map[self._container_to_host_port_map[container_port]] def close(self) -> None: self._tunnel.close() @@ -167,20 +167,20 @@ def _get_container_to_host_port_map( return port_map @staticmethod - def _get_host_port_to_unix_socket_map( + def _get_host_port_to_uds_map( 
connection_dir: Path, ports_to_forward: Collection[int], ) -> dict[int, Path]: return {port: connection_dir / str(port) for port in ports_to_forward} @staticmethod - def _get_forwarded_sockets(host_port_to_unix_socket_map: dict[int, Path]) -> list[SocketPair]: + def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[SocketPair]: return [ SocketPair( local=UnixSocket(path=path), remote=IPSocket(host="localhost", port=port), ) - for port, path in host_port_to_unix_socket_map.items() + for port, path in host_port_to_uds_map.items() ] @@ -193,11 +193,16 @@ def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) - def _get_proxies( ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData ) -> list[tuple[SSHConnectionParams, FileContent]]: - if not isinstance(ssh_private_key, tuple): + if jpd.ssh_proxy is None: return [] - _, ssh_proxy_private_key = ssh_private_key - if ssh_proxy_private_key is None or jpd.ssh_proxy is None: - return [] + if isinstance(ssh_private_key, str): + ssh_proxy_private_key = ssh_private_key + else: + ssh_proxy_private_key = ssh_private_key[1] + if ssh_proxy_private_key is None: + # In case proxy key is None, fallback to main key (k8s case). 
+ ssh_proxy_private_key = ssh_private_key[0] + proxy_identity = FileContent(ssh_proxy_private_key) return [(jpd.ssh_proxy, proxy_identity)] From bb3ca6add4384c65d8b5f862f93d577a7a0c001c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 3 Jun 2026 16:55:06 +0500 Subject: [PATCH 05/25] Revert args rename --- src/dstack/_internal/server/services/runner/ssh.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 96796253a8..ebd503df14 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -43,8 +43,8 @@ def decorator( @functools.wraps(func) def wrapper( ssh_private_key: PrivateKeyOrPair, - jpd: JobProvisioningData, - jrd: Optional[JobRuntimeData], + job_provisioning_data: JobProvisioningData, + job_runtime_data: Optional[JobRuntimeData], *args: P.args, **kwargs: P.kwargs, ) -> Union[Literal[False], R]: @@ -52,13 +52,17 @@ def wrapper( Returns: is successful """ - if jpd.backend == BackendType.LOCAL: + if job_provisioning_data.backend == BackendType.LOCAL: # without SSH container_ports_map = {port: port for port in ports} return func(container_ports_map, *args, **kwargs) for attempt in range(2): # cached, then one fresh reopen - conn = instance_connection_pool.get_or_open(ssh_private_key, jpd, jrd) + conn = instance_connection_pool.get_or_open( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ) if conn is None: return False # couldn't establish at all sock_paths = {p: conn.forwarded_path(p) for p in ports} From cd80683665932117990241784f366cdca2a7f6da Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 10:51:01 +0500 Subject: [PATCH 06/25] Skip pool for container backends --- .../server/services/gateways/connection.py | 17 +++----- .../_internal/server/services/runner/pool.py | 42 +++++++++++-------- 
.../_internal/server/services/runner/ssh.py | 33 +++++++++++++-- src/dstack/_internal/utils/path.py | 10 +++++ 4 files changed, 68 insertions(+), 34 deletions(-) diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py index b8df322a1d..dada5bea64 100644 --- a/src/dstack/_internal/server/services/gateways/connection.py +++ b/src/dstack/_internal/server/services/gateways/connection.py @@ -1,9 +1,7 @@ import contextlib import shutil import uuid -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import AsyncIterator, Optional, Tuple +from typing import AsyncIterator, Optional import aiorwlock @@ -22,7 +20,7 @@ from dstack._internal.server.services.gateways.client import GatewayClient from dstack._internal.server.settings import SERVER_DIR_PATH from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import FileContent +from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir logger = get_logger(__name__) CONNECTIONS_DIR = SERVER_DIR_PATH / "gateway-connections" @@ -47,7 +45,9 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): self.connection_dir = CONNECTIONS_DIR / ip_address # connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink - self.temp_dir, self.connection_symlink_dir = self._init_symlink_dir(self.connection_dir) + self.temp_dir, self.connection_symlink_dir = make_tmp_symlink_to_dir( + self.connection_dir, "connection" + ) self.gateway_socket_path = self.connection_symlink_dir / "gateway.sock" self.tunnel = SSHTunnel( destination=f"ubuntu@{ip_address}", @@ -69,13 +69,6 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): self.tunnel_id = uuid.uuid4() self._client = GatewayClient(uds=str(self.gateway_socket_path)) - @staticmethod - def _init_symlink_dir(connection_dir: Path) -> Tuple[TemporaryDirectory, 
Path]: - temp_dir = TemporaryDirectory() - symlink_dir = Path(temp_dir.name) / "connection" - symlink_dir.symlink_to(connection_dir, target_is_directory=True) - return temp_dir, symlink_dir - async def check_or_restart(self) -> bool: async with self._lock.writer_lock: if not await self.tunnel.acheck(): diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 7940f3d062..66d18201f7 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -16,7 +16,7 @@ UnixSocket, ) from dstack._internal.server.settings import SERVER_DIR_PATH -from dstack._internal.utils.path import FileContent +from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] """A host private key or pair of (host private key, optional proxy jump private key)""" @@ -45,6 +45,8 @@ def from_jpd( ) +# InstanceConnectionPool has sync interface because runner/shim clients and all the callers are sync. +# TODO: Consider moving all of them to async for consistency with other pools/clients. 
class InstanceConnectionPool: def __init__(self): self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} @@ -102,23 +104,19 @@ def __init__( ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData, jrd: Optional[JobRuntimeData], + ephemeral: bool = False, ) -> None: self._key = InstanceConnectionKey.from_jpd(jpd, jrd) - self._connection_dir = ( - CONNECTIONS_DIR - / f"{self._key.hostname}:{self._key.port}" - / str(self._key.ports_to_forward) + self._ephemeral = ephemeral + self._temp_dir, self._effective_conn_dir = InstanceConnection._resolve_conn_dir( + self._key, ephemeral ) - self._connection_dir.mkdir(parents=True, exist_ok=True) - # connection_dir can have a long path that won't be accepted by the ssh command, - # so we create a short temporary symlink - self._temp_dir, self._connection_symlink_dir = self._init_symlink_dir(self._connection_dir) - self._control_socket_path = self._connection_symlink_dir / "control.sock" + self._control_socket_path = self._effective_conn_dir / "control.sock" self._container_to_host_port_map = InstanceConnection._get_container_to_host_port_map( jpd, jrd ) self._host_port_to_uds_map = InstanceConnection._get_host_port_to_uds_map( - connection_dir=self._connection_symlink_dir, + conn_dir=self._effective_conn_dir, ports_to_forward=self._key.ports_to_forward, ) self._tunnel = SSHTunnel( @@ -150,11 +148,19 @@ def key(self) -> InstanceConnectionKey: return self._key @staticmethod - def _init_symlink_dir(connection_dir: Path) -> tuple[TemporaryDirectory, Path]: - temp_dir = TemporaryDirectory() - symlink_dir = Path(temp_dir.name) / "connection" - symlink_dir.symlink_to(connection_dir, target_is_directory=True) - return temp_dir, symlink_dir + def _resolve_conn_dir( + key: InstanceConnectionKey, ephemeral: bool + ) -> tuple[TemporaryDirectory, Path]: + if ephemeral: + temp_dir = TemporaryDirectory() + return temp_dir, Path(temp_dir.name) / "connection" + + conn_dir = CONNECTIONS_DIR / f"{key.hostname}:{key.port}" / 
str(key.ports_to_forward) + conn_dir.mkdir(parents=True, exist_ok=True) + # Connection_dir can have a long path that won't be accepted by the ssh command, + # so we create a short temporary symlink. + temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir(conn_dir, "connection") + return temp_dir, conn_symlink_dir @staticmethod def _get_container_to_host_port_map( @@ -168,10 +174,10 @@ def _get_container_to_host_port_map( @staticmethod def _get_host_port_to_uds_map( - connection_dir: Path, + conn_dir: Path, ports_to_forward: Collection[int], ) -> dict[int, Path]: - return {port: connection_dir / str(port) for port in ports_to_forward} + return {port: conn_dir / str(port) for port in ports_to_forward} @staticmethod def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[SocketPair]: diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index ebd503df14..db7a604912 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -9,7 +9,11 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData from dstack._internal.server.services.runner.client import LocalAddress -from dstack._internal.server.services.runner.pool import PrivateKeyOrPair, instance_connection_pool +from dstack._internal.server.services.runner.pool import ( + InstanceConnection, + PrivateKeyOrPair, + instance_connection_pool, +) from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -27,7 +31,7 @@ def runner_ssh_tunnel( ], ]: """ - A decorator that opens an SSH tunnel to the runner. + A decorator that opens an SSH tunnel to the runner instance for port forwarding. NOTE: connections from dstack-server to running jobs are expected to be short. 
The runner uses a heuristic to differentiate dstack-server connections from @@ -57,6 +61,28 @@ def wrapper( container_ports_map = {port: port for port in ports} return func(container_ports_map, *args, **kwargs) + if not job_provisioning_data.dockerized: + # Connections from dstack-server to runner's sshd are expected to be short + # as the `inactivity_duration` feature distinguishes user and server connections based on duration. + # Do not re-use SSH connections for container-based backends. + # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. + conn = InstanceConnection( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ephemeral=True, + ) + try: + conn.open() + except SSHError: + return False + try: + return func({p: conn.forwarded_path(p) for p in ports}, *args, **kwargs) + except (DstackError, requests.RequestException): + return False + finally: + conn.close() + for attempt in range(2): # cached, then one fresh reopen conn = instance_connection_pool.get_or_open( ssh_private_key=ssh_private_key, @@ -65,9 +91,8 @@ def wrapper( ) if conn is None: return False # couldn't establish at all - sock_paths = {p: conn.forwarded_path(p) for p in ports} try: - return func(sock_paths, *args, **kwargs) + return func({p: conn.forwarded_path(p) for p in ports}, *args, **kwargs) except (SSHError, requests.ConnectionError): instance_connection_pool.drop(conn.key) # dead ssh connection, re-open except (DstackError, requests.RequestException): diff --git a/src/dstack/_internal/utils/path.py b/src/dstack/_internal/utils/path.py index 18e0b7c812..07b8fdd664 100644 --- a/src/dstack/_internal/utils/path.py +++ b/src/dstack/_internal/utils/path.py @@ -1,6 +1,7 @@ import os from dataclasses import dataclass from pathlib import Path, PurePath, PurePosixPath +from tempfile import TemporaryDirectory from typing import Union PathLike = Union[str, os.PathLike] @@ -55,3 +56,12 @@ def is_absolute_posix_path(path: 
PathLike) -> bool: if str(path).startswith("~"): return True return PurePosixPath(path).is_absolute() + + +def make_tmp_symlink_to_dir( + dirpath: PathLike, symlink_dirname: str +) -> tuple[TemporaryDirectory, Path]: + temp_dir = TemporaryDirectory() + symlink_dir = Path(temp_dir.name) / symlink_dirname + symlink_dir.symlink_to(dirpath, target_is_directory=True) + return temp_dir, symlink_dir From 4fd472285862609d1b9015724b65b9be5417029d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 11:13:22 +0500 Subject: [PATCH 07/25] Use dstack tmp dir --- src/dstack/_internal/server/services/runner/pool.py | 11 ++++++++--- src/dstack/_internal/server/settings.py | 5 +++++ src/dstack/_internal/utils/path.py | 6 +++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 66d18201f7..fded8acb68 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -15,7 +15,7 @@ SSHTunnel, UnixSocket, ) -from dstack._internal.server.settings import SERVER_DIR_PATH +from dstack._internal.server.settings import SERVER_DIR_PATH, SERVER_TMP_PATH from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] @@ -153,13 +153,18 @@ def _resolve_conn_dir( ) -> tuple[TemporaryDirectory, Path]: if ephemeral: temp_dir = TemporaryDirectory() - return temp_dir, Path(temp_dir.name) / "connection" + return temp_dir, Path(str(temp_dir)) conn_dir = CONNECTIONS_DIR / f"{key.hostname}:{key.port}" / str(key.ports_to_forward) conn_dir.mkdir(parents=True, exist_ok=True) # Connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink. 
- temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir(conn_dir, "connection") + temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir( + dirpath=conn_dir, + symlink_dirname="connection", + # Using dstack's own tmp dir to avoid age-based tmp cleanup. + base_dir=SERVER_TMP_PATH, + ) return temp_dir, conn_symlink_dir @staticmethod diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index f90aee339d..8b1c4bb11e 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -20,6 +20,11 @@ SERVER_DATA_DIR_PATH = SERVER_DIR_PATH / "data" SERVER_DATA_DIR_PATH.mkdir(parents=True, exist_ok=True) + +SERVER_TMP_PATH = SERVER_DIR_PATH / "tmp" +"""SERVER_TMP_PATH can be used as dstack's own /tmp when age-based cleaning for /tmp is not desirable""" +SERVER_TMP_PATH.mkdir(parents=True, exist_ok=True) + DATABASE_URL = os.getenv( "DSTACK_DATABASE_URL", f"sqlite+aiosqlite:///{str(SERVER_DATA_DIR_PATH.absolute())}/sqlite.db" ) diff --git a/src/dstack/_internal/utils/path.py b/src/dstack/_internal/utils/path.py index 07b8fdd664..5131831cf8 100644 --- a/src/dstack/_internal/utils/path.py +++ b/src/dstack/_internal/utils/path.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path, PurePath, PurePosixPath from tempfile import TemporaryDirectory -from typing import Union +from typing import Optional, Union PathLike = Union[str, os.PathLike] @@ -59,9 +59,9 @@ def is_absolute_posix_path(path: PathLike) -> bool: def make_tmp_symlink_to_dir( - dirpath: PathLike, symlink_dirname: str + dirpath: PathLike, symlink_dirname: str, base_dir: Optional[PathLike] = None ) -> tuple[TemporaryDirectory, Path]: - temp_dir = TemporaryDirectory() + temp_dir = TemporaryDirectory(dir=base_dir) symlink_dir = Path(temp_dir.name) / symlink_dirname symlink_dir.symlink_to(dirpath, target_is_directory=True) return temp_dir, symlink_dir From 8794616a1826471dde8a070571659d33688e9088 Mon Sep 17 00:00:00 2001 From: 
Victor Skvortsov Date: Thu, 4 Jun 2026 11:53:14 +0500 Subject: [PATCH 08/25] Drop ports from runner_ssh_tunnel --- .../pipeline_tasks/instances/check.py | 2 +- .../background/pipeline_tasks/jobs_running.py | 10 +-- .../pipeline_tasks/jobs_terminating.py | 2 +- .../background/scheduled_tasks/metrics.py | 2 +- .../scheduled_tasks/prometheus_metrics.py | 2 +- .../server/services/jobs/__init__.py | 2 +- .../_internal/server/services/runner/pool.py | 53 +++++++++------ .../_internal/server/services/runner/ssh.py | 64 +++++++++++-------- 8 files changed, 83 insertions(+), 54 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index d87c3ea613..b0ab28829a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -374,7 +374,7 @@ async def _get_backend_for_provisioning_wait( ) -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +@runner_ssh_tunnel(retries=1) def _check_instance_inner( addresses: Mapping[int, runner_client.LocalAddress], *, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 3319d0d728..00e9dff92d 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1309,7 +1309,7 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo return False -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +@runner_ssh_tunnel(retries=1) def _process_provisioning_with_shim( addresses: Mapping[int, client.LocalAddress], run: Run, @@ -1436,7 +1436,7 @@ class _SyncShimPullingStateResult: image_pull_progress: Optional[ImagePullProgress] = None -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], 
retries=1) +@runner_ssh_tunnel(retries=1) def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability: runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: @@ -1444,7 +1444,7 @@ def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _R return _RunnerAvailability.AVAILABLE -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +@runner_ssh_tunnel() def _sync_shim_pulling_state( addresses: Mapping[int, client.LocalAddress], job_model: JobModel, @@ -1526,7 +1526,7 @@ class _SubmitJobToRunnerResult: job_runtime_data: Optional[JobRuntimeData] = None -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +@runner_ssh_tunnel(retries=1) def _submit_job_to_runner( addresses: Mapping[int, client.LocalAddress], run: Run, @@ -1596,7 +1596,7 @@ class _ProcessRunningResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +@runner_ssh_tunnel() def _process_running( addresses: Mapping[int, client.LocalAddress], run_model: RunModel, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 6a69f466df..e2d19e341f 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -853,7 +853,7 @@ async def _stop_container( return True -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +@runner_ssh_tunnel() def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool: shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py index 352383df33..4540ddc1bd 100644 --- 
a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py @@ -165,7 +165,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint] ) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +@runner_ssh_tunnel(retries=1) def _pull_runner_metrics( addresses: Mapping[int, client.LocalAddress], ) -> Optional[MetricsResponse]: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py index fba9e495ff..7acd85b8f4 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py @@ -145,7 +145,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]: return res -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +@runner_ssh_tunnel(retries=1) def _pull_job_metrics( addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID ) -> Optional[str]: diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 5e6455184d..54d882a736 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -347,7 +347,7 @@ async def stop_runner(job_model: JobModel, instance_model: InstanceModel): logger.debug("%s: failed to stop runner", fmt(job_model)) -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +@runner_ssh_tunnel() def _stop_runner( addresses: Mapping[int, client.LocalAddress], job_model: JobModel, diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index fded8acb68..4a6c376290 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -23,8 +23,6 @@ 
CONNECTIONS_DIR = SERVER_DIR_PATH / "instance-connections" -DEFAULT_PORTS_TO_FORWARD = [DSTACK_SHIM_HTTP_PORT, DSTACK_RUNNER_HTTP_PORT] - @dataclass(frozen=True) class InstanceConnectionKey: @@ -37,7 +35,7 @@ def from_jpd( jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] ) -> "InstanceConnectionKey": assert jpd.hostname is not None and jpd.ssh_port is not None - container_to_host_port_map = InstanceConnection._get_container_to_host_port_map(jpd, jrd) + container_to_host_port_map = get_container_to_host_port_map(jpd, jrd) return InstanceConnectionKey( hostname=jpd.hostname, port=jpd.ssh_port, @@ -48,6 +46,11 @@ def from_jpd( # InstanceConnectionPool has sync interface because runner/shim clients and all the callers are sync. # TODO: Consider moving all of them to async for consistency with other pools/clients. class InstanceConnectionPool: + """ + A pool of SSH connections to instances' host sshd (VM-based) + or runner sshd (container-based) for forwarding shim and runner ports. + """ + def __init__(self): self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} self._access_locks: dict[InstanceConnectionKey, threading.Lock] = {} @@ -106,15 +109,20 @@ def __init__( jrd: Optional[JobRuntimeData], ephemeral: bool = False, ) -> None: + """ + An SSH connection to instance's host sshd (VM-based) + or runner sshd (container-based) for forwarding shim and runner ports. + + Args: + ephemeral: Creates a unique tmp dir for the uds. Use when connection re-use is not needed. 
+ """ self._key = InstanceConnectionKey.from_jpd(jpd, jrd) self._ephemeral = ephemeral self._temp_dir, self._effective_conn_dir = InstanceConnection._resolve_conn_dir( self._key, ephemeral ) self._control_socket_path = self._effective_conn_dir / "control.sock" - self._container_to_host_port_map = InstanceConnection._get_container_to_host_port_map( - jpd, jrd - ) + self._container_to_host_port_map = get_container_to_host_port_map(jpd, jrd) self._host_port_to_uds_map = InstanceConnection._get_host_port_to_uds_map( conn_dir=self._effective_conn_dir, ports_to_forward=self._key.ports_to_forward, @@ -137,8 +145,12 @@ def __init__( def open(self) -> None: self._tunnel.open() - def forwarded_path(self, container_port: int) -> Path: - return self._host_port_to_uds_map[self._container_to_host_port_map[container_port]] + def forwarded_paths(self) -> dict[int, Path]: + """Returns a mapping from container port to the local UDS path.""" + return { + container_port: self._host_port_to_uds_map[host_port] + for container_port, host_port in self._container_to_host_port_map.items() + } def close(self) -> None: self._tunnel.close() @@ -153,7 +165,7 @@ def _resolve_conn_dir( ) -> tuple[TemporaryDirectory, Path]: if ephemeral: temp_dir = TemporaryDirectory() - return temp_dir, Path(str(temp_dir)) + return temp_dir, Path(temp_dir.name) conn_dir = CONNECTIONS_DIR / f"{key.hostname}:{key.port}" / str(key.ports_to_forward) conn_dir.mkdir(parents=True, exist_ok=True) @@ -167,16 +179,6 @@ def _resolve_conn_dir( ) return temp_dir, conn_symlink_dir - @staticmethod - def _get_container_to_host_port_map( - jpd: JobProvisioningData, - jrd: Optional[JobRuntimeData], - ) -> dict[int, int]: - port_map = {port: port for port in DEFAULT_PORTS_TO_FORWARD} - if jrd is not None and jrd.ports is not None: - port_map.update(jrd.ports) - return port_map - @staticmethod def _get_host_port_to_uds_map( conn_dir: Path, @@ -195,6 +197,19 @@ def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> 
list[Socket ] +def get_container_to_host_port_map( + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], +) -> dict[int, int]: + runner_host_port = DSTACK_RUNNER_HTTP_PORT + if jrd is not None and jrd.ports is not None: + runner_host_port = jrd.ports.get(DSTACK_RUNNER_HTTP_PORT, runner_host_port) + port_map = {DSTACK_RUNNER_HTTP_PORT: runner_host_port} + if jpd.dockerized: + port_map[DSTACK_SHIM_HTTP_PORT] = DSTACK_SHIM_HTTP_PORT + return port_map + + def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) -> FileContent: if isinstance(ssh_private_key, tuple): ssh_private_key, _ = ssh_private_key diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index db7a604912..609c7fbe50 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,6 +1,7 @@ import functools +import time from collections.abc import Mapping -from typing import Callable, List, Literal, Optional, TypeVar, Union +from typing import Callable, Literal, Optional, TypeVar, Union import requests from typing_extensions import Concatenate, ParamSpec @@ -12,6 +13,7 @@ from dstack._internal.server.services.runner.pool import ( InstanceConnection, PrivateKeyOrPair, + get_container_to_host_port_map, instance_connection_pool, ) from dstack._internal.utils.logging import get_logger @@ -22,7 +24,7 @@ def runner_ssh_tunnel( - ports: List[int], retries: int = 3, retry_interval: float = 1 + retries: int = 3, retry_interval: float = 1 ) -> Callable[ [Callable[Concatenate[Mapping[int, LocalAddress], P], R]], Callable[ @@ -33,9 +35,15 @@ def runner_ssh_tunnel( """ A decorator that opens an SSH tunnel to the runner instance for port forwarding. - NOTE: connections from dstack-server to running jobs are expected to be short. - The runner uses a heuristic to differentiate dstack-server connections from - client connections based on their duration. 
See `ConnectionTracker` for details. + Forwarded ports: + * VM-based backends: forward the shim and runner ports. + * Container-based backends: forward only the runner port. + * `jrd.ports` may remap the runner port (blocks case). + + Always forwards the same ports for the given instance/job so that connection is reused across all calls. + In case of blocks, each job uses a separate connection as the runner host port differs. + + `retries` and `retry_interval` apply only if connection pooling is not used. """ def decorator( @@ -58,32 +66,38 @@ def wrapper( """ if job_provisioning_data.backend == BackendType.LOCAL: # without SSH - container_ports_map = {port: port for port in ports} - return func(container_ports_map, *args, **kwargs) + port_map = get_container_to_host_port_map(job_provisioning_data, job_runtime_data) + return func(port_map, *args, **kwargs) if not job_provisioning_data.dockerized: # Connections from dstack-server to runner's sshd are expected to be short # as the `inactivity_duration` feature distinguishes user and server connections based on duration. # Do not re-use SSH connections for container-based backends. # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. 
- conn = InstanceConnection( - ssh_private_key=ssh_private_key, - jpd=job_provisioning_data, - jrd=job_runtime_data, - ephemeral=True, - ) - try: - conn.open() - except SSHError: - return False - try: - return func({p: conn.forwarded_path(p) for p in ports}, *args, **kwargs) - except (DstackError, requests.RequestException): - return False - finally: - conn.close() + for attempt in range(retries): + if attempt > 0: + time.sleep(retry_interval) + conn = InstanceConnection( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ephemeral=True, + ) + try: + conn.open() + except SSHError: + continue + try: + return func(conn.forwarded_paths(), *args, **kwargs) + except (SSHError, requests.ConnectionError): + continue # connection-level failure, retry with a fresh connection + except (DstackError, requests.RequestException): + return False + finally: + conn.close() + return False - for attempt in range(2): # cached, then one fresh reopen + for _ in range(2): # cached, then one fresh reopen conn = instance_connection_pool.get_or_open( ssh_private_key=ssh_private_key, jpd=job_provisioning_data, @@ -92,7 +106,7 @@ def wrapper( if conn is None: return False # couldn't establish at all try: - return func({p: conn.forwarded_path(p) for p in ports}, *args, **kwargs) + return func(conn.forwarded_paths(), *args, **kwargs) except (SSHError, requests.ConnectionError): instance_connection_pool.drop(conn.key) # dead ssh connection, re-open except (DstackError, requests.RequestException): From eef2608a62ca2e6277cec8d408d07a05d236fc6c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 11:57:44 +0500 Subject: [PATCH 09/25] Refactor methods --- .../_internal/server/services/runner/pool.py | 80 ++++++++++--------- .../_internal/server/services/runner/ssh.py | 5 +- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py 
index 4a6c376290..b38ee9653c 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -35,7 +35,7 @@ def from_jpd( jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] ) -> "InstanceConnectionKey": assert jpd.hostname is not None and jpd.ssh_port is not None - container_to_host_port_map = get_container_to_host_port_map(jpd, jrd) + container_to_host_port_map = InstanceConnection.get_container_to_host_port_map(jpd, jrd) return InstanceConnectionKey( hostname=jpd.hostname, port=jpd.ssh_port, @@ -122,7 +122,9 @@ def __init__( self._key, ephemeral ) self._control_socket_path = self._effective_conn_dir / "control.sock" - self._container_to_host_port_map = get_container_to_host_port_map(jpd, jrd) + self._container_to_host_port_map = InstanceConnection.get_container_to_host_port_map( + jpd, jrd + ) self._host_port_to_uds_map = InstanceConnection._get_host_port_to_uds_map( conn_dir=self._effective_conn_dir, ports_to_forward=self._key.ports_to_forward, @@ -130,10 +132,10 @@ def __init__( self._tunnel = SSHTunnel( destination=f"{jpd.username}@{jpd.hostname}", port=jpd.ssh_port, - identity=_get_identity(ssh_private_key, jpd), + identity=InstanceConnection._get_identity(ssh_private_key, jpd), control_sock_path=self._control_socket_path, forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_uds_map), - ssh_proxies=_get_proxies(ssh_private_key, jpd), + ssh_proxies=InstanceConnection._get_proxies(ssh_private_key, jpd), options={ **SSH_DEFAULT_OPTIONS, "ServerAliveInterval": "30", @@ -159,6 +161,19 @@ def close(self) -> None: def key(self) -> InstanceConnectionKey: return self._key + @staticmethod + def get_container_to_host_port_map( + jpd: JobProvisioningData, + jrd: Optional[JobRuntimeData], + ) -> dict[int, int]: + runner_host_port = DSTACK_RUNNER_HTTP_PORT + if jrd is not None and jrd.ports is not None: + runner_host_port = jrd.ports.get(DSTACK_RUNNER_HTTP_PORT, runner_host_port) + port_map = 
{DSTACK_RUNNER_HTTP_PORT: runner_host_port} + if jpd.dockerized: + port_map[DSTACK_SHIM_HTTP_PORT] = DSTACK_SHIM_HTTP_PORT + return port_map + @staticmethod def _resolve_conn_dir( key: InstanceConnectionKey, ephemeral: bool @@ -196,39 +211,26 @@ def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[Socket for port, path in host_port_to_uds_map.items() ] + @staticmethod + def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) -> FileContent: + if isinstance(ssh_private_key, tuple): + ssh_private_key, _ = ssh_private_key + return FileContent(ssh_private_key) -def get_container_to_host_port_map( - jpd: JobProvisioningData, - jrd: Optional[JobRuntimeData], -) -> dict[int, int]: - runner_host_port = DSTACK_RUNNER_HTTP_PORT - if jrd is not None and jrd.ports is not None: - runner_host_port = jrd.ports.get(DSTACK_RUNNER_HTTP_PORT, runner_host_port) - port_map = {DSTACK_RUNNER_HTTP_PORT: runner_host_port} - if jpd.dockerized: - port_map[DSTACK_SHIM_HTTP_PORT] = DSTACK_SHIM_HTTP_PORT - return port_map - - -def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) -> FileContent: - if isinstance(ssh_private_key, tuple): - ssh_private_key, _ = ssh_private_key - return FileContent(ssh_private_key) - - -def _get_proxies( - ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData -) -> list[tuple[SSHConnectionParams, FileContent]]: - if jpd.ssh_proxy is None: - return [] - - if isinstance(ssh_private_key, str): - ssh_proxy_private_key = ssh_private_key - else: - ssh_proxy_private_key = ssh_private_key[1] - if ssh_proxy_private_key is None: - # In case proxy key is None, fallback to main key (k8s case). 
- ssh_proxy_private_key = ssh_private_key[0] - - proxy_identity = FileContent(ssh_proxy_private_key) - return [(jpd.ssh_proxy, proxy_identity)] + @staticmethod + def _get_proxies( + ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData + ) -> list[tuple[SSHConnectionParams, FileContent]]: + if jpd.ssh_proxy is None: + return [] + + if isinstance(ssh_private_key, str): + ssh_proxy_private_key = ssh_private_key + else: + ssh_proxy_private_key = ssh_private_key[1] + if ssh_proxy_private_key is None: + # In case proxy key is None, fallback to main key (k8s case). + ssh_proxy_private_key = ssh_private_key[0] + + proxy_identity = FileContent(ssh_proxy_private_key) + return [(jpd.ssh_proxy, proxy_identity)] diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 609c7fbe50..50848fb225 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -13,7 +13,6 @@ from dstack._internal.server.services.runner.pool import ( InstanceConnection, PrivateKeyOrPair, - get_container_to_host_port_map, instance_connection_pool, ) from dstack._internal.utils.logging import get_logger @@ -66,7 +65,9 @@ def wrapper( """ if job_provisioning_data.backend == BackendType.LOCAL: # without SSH - port_map = get_container_to_host_port_map(job_provisioning_data, job_runtime_data) + port_map = InstanceConnection.get_container_to_host_port_map( + job_provisioning_data, job_runtime_data + ) return func(port_map, *args, **kwargs) if not job_provisioning_data.dockerized: From 8558adf20b5c9d6ed994e2ce802c29c3c1bf4639 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 12:16:21 +0500 Subject: [PATCH 10/25] Implement InstanceConnectionPool.close_all --- src/dstack/_internal/server/app.py | 3 +++ .../_internal/server/services/runner/pool.py | 24 +++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git 
a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 7de6e74059..ef8c9e6b54 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -57,6 +57,7 @@ from dstack._internal.server.services.projects import get_or_create_default_project from dstack._internal.server.services.proxy.deps import ServerProxyDependencyInjector from dstack._internal.server.services.proxy.routers import service_proxy +from dstack._internal.server.services.runner.pool import instance_connection_pool from dstack._internal.server.services.storage import init_default_storage from dstack._internal.server.services.users import get_or_create_admin_user from dstack._internal.server.settings import ( @@ -75,6 +76,7 @@ get_client_version, get_server_client_error_details, ) +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import check_required_ssh_version @@ -209,6 +211,7 @@ async def lifespan(app: FastAPI): await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() + await run_async(instance_connection_pool.close_all) await get_db().engine.dispose() # Let checked-out DB connections close as dispose() only closes checked-in connections await asyncio.sleep(3) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index b38ee9653c..e1d0ae8b9f 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -16,8 +16,11 @@ UnixSocket, ) from dstack._internal.server.settings import SERVER_DIR_PATH, SERVER_TMP_PATH +from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir +logger = get_logger(__name__) + PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] 
"""A host private key or pair of (host private key, optional proxy jump private key)""" @@ -55,6 +58,7 @@ def __init__(self): self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} self._access_locks: dict[InstanceConnectionKey, threading.Lock] = {} self._access_locks_lock = threading.Lock() + self._closed = False def get_or_open( self, @@ -65,6 +69,8 @@ def get_or_open( key = InstanceConnectionKey.from_jpd(jpd, jrd) lock = self._get_access_lock(key) with lock: + if self._closed: + return None conn = self._connections.get(key) if conn is not None: return conn @@ -84,9 +90,23 @@ def drop(self, key: InstanceConnectionKey) -> None: conn = self._connections.pop(key) except KeyError: return - conn.close() + try: + conn.close() + except Exception: + logger.exception("Failed to close instance connection %s", key) - def close_all(self) -> None: ... # graceful shutdown + def close_all(self) -> None: + """ + Closes all connections and prevents new ones from being opened. + Safe to call concurrently with in-flight `get_or_open` calls. + `get_or_open` returns `None` after `close_all`. 
+ """ + with self._access_locks_lock: + self._closed = True + keys = list(self._access_locks.keys()) + logger.debug("Closing %d instance connection(s)", len(keys)) + for key in keys: + self.drop(key) def _get_access_lock(self, key: InstanceConnectionKey) -> threading.Lock: with self._access_locks_lock: From 1a0dd6c2a541f54ea6172e7249e2451aa7ad3586 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 14:40:28 +0500 Subject: [PATCH 11/25] Add DSTACK_SERVER_SSH_POOL_DISABLED --- mkdocs/docs/reference/env.md | 1 + .../_internal/server/services/runner/pool.py | 23 +++++++++++++++---- .../_internal/server/services/runner/ssh.py | 3 ++- src/dstack/_internal/server/settings.py | 2 ++ 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mkdocs/docs/reference/env.md b/mkdocs/docs/reference/env.md index 1b81629109..c88ccce4c5 100644 --- a/mkdocs/docs/reference/env.md +++ b/mkdocs/docs/reference/env.md @@ -141,6 +141,7 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. 
Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server. +- `DSTACK_SERVER_SSH_POOL_DISABLED`{ #DSTACK_SERVER_SSH_POOL_DISABLED } – Disables the reuse of server-instance SSH connections. By default, SHH connections are reused, and each active instance may consume ~2-10MB of server RAM. Set this to save RAM at the expense of opening an SSH connection on every server-instance communication. ??? info "Internal environment variables" The following environment variables are intended for development purposes: diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index e1d0ae8b9f..3c438b67b9 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -52,6 +52,10 @@ class InstanceConnectionPool: """ A pool of SSH connections to instances' host sshd (VM-based) or runner sshd (container-based) for forwarding shim and runner ports. + + NOTE: The pool does not currently intended for arbitrary ports forwarding, only for shim and runner ports. + E.g. it cannot be used to forward services ports for probes or router-worker communication. + This simplified model allows forwarding the same ports for the given host:port and reusing the connection across all calls. """ def __init__(self): @@ -98,8 +102,8 @@ def drop(self, key: InstanceConnectionKey) -> None: def close_all(self) -> None: """ Closes all connections and prevents new ones from being opened. - Safe to call concurrently with in-flight `get_or_open` calls. - `get_or_open` returns `None` after `close_all`. + Safe to call concurrently with in-flight `get_or_open()` calls. + `get_or_open()` will return `None` after `close_all()`. 
""" with self._access_locks_lock: self._closed = True @@ -133,8 +137,13 @@ def __init__( An SSH connection to instance's host sshd (VM-based) or runner sshd (container-based) for forwarding shim and runner ports. + The same control socket is used for all connections to the same hostname:port, + unless jrd overrides the runner port mapped on host (blocks case). + In case of blocks, each job establishes a separate connection with a different runner port forwarded. + TODO: Re-use the same SSH connection for all blocks via `-O forward`/`-O cancel`. + Args: - ephemeral: Creates a unique tmp dir for the uds. Use when connection re-use is not needed. + ephemeral: Creates a unique tmp dir for the UDS. Use when connection re-use is not needed. """ self._key = InstanceConnectionKey.from_jpd(jpd, jrd) self._ephemeral = ephemeral @@ -159,6 +168,8 @@ def __init__( options={ **SSH_DEFAULT_OPTIONS, "ServerAliveInterval": "30", + # Set ControlPersist to auto-close orphaned background ssh process + # in case dstack server shutdown is not graceful. "ControlPersist": "2m", }, batch_mode=True, @@ -202,7 +213,11 @@ def _resolve_conn_dir( temp_dir = TemporaryDirectory() return temp_dir, Path(temp_dir.name) - conn_dir = CONNECTIONS_DIR / f"{key.hostname}:{key.port}" / str(key.ports_to_forward) + conn_dir = ( + CONNECTIONS_DIR + / f"{key.hostname}:{key.port}" + / ",".join(map(str, key.ports_to_forward)) + ) conn_dir.mkdir(parents=True, exist_ok=True) # Connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink. 
diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 50848fb225..2e41b2db2a 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import DstackError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import JobProvisioningData, JobRuntimeData +from dstack._internal.server import settings from dstack._internal.server.services.runner.client import LocalAddress from dstack._internal.server.services.runner.pool import ( InstanceConnection, @@ -70,7 +71,7 @@ def wrapper( ) return func(port_map, *args, **kwargs) - if not job_provisioning_data.dockerized: + if settings.SERVER_SSH_POOL_DISABLED or not job_provisioning_data.dockerized: # Connections from dstack-server to runner's sshd are expected to be short # as the `inactivity_duration` feature distinguishes user and server connections based on duration. # Do not re-use SSH connections for container-based backends. 
diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 8b1c4bb11e..0e81fdb774 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -153,6 +153,8 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) +SERVER_SSH_POOL_DISABLED = os.getenv("DSTACK_SERVER_SSH_POOL_DISABLED", False) + # Development settings SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None From 0f9a9517677a7885855baa7b4308fd174a81bed3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 15:16:01 +0500 Subject: [PATCH 12/25] Check SSH connection health --- .../_internal/core/services/ssh/tunnel.py | 6 +++ .../_internal/server/services/runner/pool.py | 37 ++++++++++++++++++- .../_internal/server/services/runner/ssh.py | 7 +++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 9fede91111..f4d6a17f70 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -252,6 +252,12 @@ async def aclose(self) -> None: proc.stdout, ) + def check(self) -> bool: + proc = subprocess.run( + self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return proc.returncode == 0 + async def acheck(self) -> bool: proc = await asyncio.create_subprocess_exec( *self.check_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 3c438b67b9..7ff3642c5e 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -1,4 +1,5 @@ import threading +import time from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory @@ -26,6 +27,9 @@ CONNECTIONS_DIR = 
SERVER_DIR_PATH / "instance-connections" +MIN_ALIVE_CHECK_INTERVAL = 30 +"""How often (at most) `InstanceConnection.is_alive()` runs `ssh -O check`, in seconds.""" + @dataclass(frozen=True) class InstanceConnectionKey: @@ -77,7 +81,15 @@ def get_or_open( return None conn = self._connections.get(key) if conn is not None: - return conn + if conn.is_alive(): + return conn + # The master process is gone — evict and reopen. + logger.debug("Instance connection %s is dead, reopening", key) + self._connections.pop(key) + try: + conn.close() + except Exception: + logger.exception("Failed to close instance connection %s", key) conn = InstanceConnection(ssh_private_key, jpd, jrd) try: conn.open() @@ -147,6 +159,7 @@ def __init__( """ self._key = InstanceConnectionKey.from_jpd(jpd, jrd) self._ephemeral = ephemeral + self._last_verified_at: float = 0.0 self._temp_dir, self._effective_conn_dir = InstanceConnection._resolve_conn_dir( self._key, ephemeral ) @@ -177,6 +190,28 @@ def __init__( def open(self) -> None: self._tunnel.open() + self._last_verified_at = time.monotonic() + + def is_alive(self) -> bool: + """ + Verifies that the connection's SSH master process is alive: + + 1. The control socket exists (a stat). Catches cleanly exited masters (incl. ControlPersist). + 2. `ssh -O check`. Catches killed masters that left a stale socket file behind. + Rate-limited to once per `MIN_ALIVE_CHECK_INTERVAL`. + + Does not detect half-open TCP (ServerAliveInterval converts it into a clean exit) + or mid-request deaths (handled by the callers' drop-on-error pattern). 
+ """ + if not self._control_socket_path.exists(): + return False + now = time.monotonic() + if now - self._last_verified_at < MIN_ALIVE_CHECK_INTERVAL: + return True + if not self._tunnel.check(): + return False + self._last_verified_at = now + return True def forwarded_paths(self) -> dict[int, Path]: """Returns a mapping from container port to the local UDS path.""" diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 2e41b2db2a..f36b161b6e 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -99,7 +99,12 @@ def wrapper( conn.close() return False - for _ in range(2): # cached, then one fresh reopen + # First try a cached connection and, if it's dead, a new connection. + # Connections already cover against + # a) cleanly-existed master (ControlPersist reap); and + # b) stale control socket file left by killed master. + # but we still want a fast retry in case master dies mid-request. + for _ in range(2): conn = instance_connection_pool.get_or_open( ssh_private_key=ssh_private_key, jpd=job_provisioning_data, From 52eb4a5b4b26be65089f57a931bfe0e9dbf44322 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 15:32:32 +0500 Subject: [PATCH 13/25] Tweak ssh options --- src/dstack/_internal/server/services/runner/pool.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 7ff3642c5e..38811ddc57 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -74,6 +74,11 @@ def get_or_open( jpd: JobProvisioningData, jrd: Optional[JobRuntimeData], ) -> Optional["InstanceConnection"]: + """ + Starts a new SSH connection or returns an existing one. 
+ Existing connections are checked for health periodically + so that subsequent calls to `get_or_open()` eventually return a healthy connection. + """ key = InstanceConnectionKey.from_jpd(jpd, jrd) lock = self._get_access_lock(key) with lock: @@ -180,7 +185,9 @@ def __init__( ssh_proxies=InstanceConnection._get_proxies(ssh_private_key, jpd), options={ **SSH_DEFAULT_OPTIONS, - "ServerAliveInterval": "30", + # Auto-close half-opened connections (the instance not responding). + "ServerAliveInterval": "10", + "ServerAliveCountMax": "3", # Set ControlPersist to auto-close orphaned background ssh process # in case dstack server shutdown is not graceful. "ControlPersist": "2m", From 894b0158db0c831e2f4092baa589771ce4a28307 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 15:58:16 +0500 Subject: [PATCH 14/25] Drop retries from runner_ssh_tunnel --- .../pipeline_tasks/instances/check.py | 2 +- .../background/pipeline_tasks/jobs_running.py | 10 +- .../pipeline_tasks/jobs_terminating.py | 2 +- .../background/scheduled_tasks/metrics.py | 2 +- .../scheduled_tasks/prometheus_metrics.py | 2 +- .../server/services/jobs/__init__.py | 2 +- .../_internal/server/services/runner/pool.py | 7 +- .../_internal/server/services/runner/ssh.py | 145 ++++++++---------- src/dstack/_internal/server/settings.py | 1 + 9 files changed, 81 insertions(+), 92 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index b0ab28829a..486c83dbf6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -374,7 +374,7 @@ async def _get_backend_for_provisioning_wait( ) -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _check_instance_inner( addresses: Mapping[int, runner_client.LocalAddress], *, diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 00e9dff92d..98e5967cb8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1309,7 +1309,7 @@ def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> boo return False -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _process_provisioning_with_shim( addresses: Mapping[int, client.LocalAddress], run: Run, @@ -1436,7 +1436,7 @@ class _SyncShimPullingStateResult: image_pull_progress: Optional[ImagePullProgress] = None -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _RunnerAvailability: runner_client = client.RunnerClient.from_address(addresses[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: @@ -1444,7 +1444,7 @@ def _get_runner_availability(addresses: Mapping[int, client.LocalAddress]) -> _R return _RunnerAvailability.AVAILABLE -@runner_ssh_tunnel() +@runner_ssh_tunnel def _sync_shim_pulling_state( addresses: Mapping[int, client.LocalAddress], job_model: JobModel, @@ -1526,7 +1526,7 @@ class _SubmitJobToRunnerResult: job_runtime_data: Optional[JobRuntimeData] = None -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _submit_job_to_runner( addresses: Mapping[int, client.LocalAddress], run: Run, @@ -1596,7 +1596,7 @@ class _ProcessRunningResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) -@runner_ssh_tunnel() +@runner_ssh_tunnel def _process_running( addresses: Mapping[int, client.LocalAddress], run_model: RunModel, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index e2d19e341f..fe2e64ca4e 100644 --- 
a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -853,7 +853,7 @@ async def _stop_container( return True -@runner_ssh_tunnel() +@runner_ssh_tunnel def _shim_submit_stop(addresses: Mapping[int, client.LocalAddress], job_model: JobModel) -> bool: shim_client = client.ShimClient.from_address(addresses[DSTACK_SHIM_HTTP_PORT]) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py index 4540ddc1bd..1febe7fa52 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py @@ -165,7 +165,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint] ) -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _pull_runner_metrics( addresses: Mapping[int, client.LocalAddress], ) -> Optional[MetricsResponse]: diff --git a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py index 7acd85b8f4..96b8cb7742 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py @@ -145,7 +145,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[str]: return res -@runner_ssh_tunnel(retries=1) +@runner_ssh_tunnel def _pull_job_metrics( addresses: Mapping[int, client.LocalAddress], task_id: uuid.UUID ) -> Optional[str]: diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 54d882a736..2d149ab77b 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -347,7 +347,7 @@ async def stop_runner(job_model: JobModel, 
instance_model: InstanceModel): logger.debug("%s: failed to stop runner", fmt(job_model)) -@runner_ssh_tunnel() +@runner_ssh_tunnel def _stop_runner( addresses: Mapping[int, client.LocalAddress], job_model: JobModel, diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 38811ddc57..d29f2d0ac9 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -16,7 +16,11 @@ SSHTunnel, UnixSocket, ) -from dstack._internal.server.settings import SERVER_DIR_PATH, SERVER_TMP_PATH +from dstack._internal.server.settings import ( + SERVER_DIR_PATH, + SERVER_SSH_CONNECT_TIMEOUT, + SERVER_TMP_PATH, +) from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir @@ -185,6 +189,7 @@ def __init__( ssh_proxies=InstanceConnection._get_proxies(ssh_private_key, jpd), options={ **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(SERVER_SSH_CONNECT_TIMEOUT), # Auto-close half-opened connections (the instance not responding). 
"ServerAliveInterval": "10", "ServerAliveCountMax": "3", diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index f36b161b6e..c57d23564f 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -1,5 +1,4 @@ import functools -import time from collections.abc import Mapping from typing import Callable, Literal, Optional, TypeVar, Union @@ -24,13 +23,10 @@ def runner_ssh_tunnel( - retries: int = 3, retry_interval: float = 1 + func: Callable[Concatenate[Mapping[int, LocalAddress], P], R], ) -> Callable[ - [Callable[Concatenate[Mapping[int, LocalAddress], P], R]], - Callable[ - Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[Literal[False], R], - ], + Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], + Union[Literal[False], R], ]: """ A decorator that opens an SSH tunnel to the runner instance for port forwarding. @@ -43,83 +39,70 @@ def runner_ssh_tunnel( Always forwards the same ports for the given instance/job so that connection is reused across all calls. In case of blocks, each job uses a separate connection as the runner host port differs. - `retries` and `retry_interval` apply only if connection pooling is not used. + There are no retries: a transient transport failure fails the call, + and the callers must retry. In high-latency setups, tune `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`. 
""" - def decorator( - func: Callable[Concatenate[Mapping[int, LocalAddress], P], R], - ) -> Callable[ - Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[Literal[False], R], - ]: - @functools.wraps(func) - def wrapper( - ssh_private_key: PrivateKeyOrPair, - job_provisioning_data: JobProvisioningData, - job_runtime_data: Optional[JobRuntimeData], - *args: P.args, - **kwargs: P.kwargs, - ) -> Union[Literal[False], R]: - """ - Returns: - is successful - """ - if job_provisioning_data.backend == BackendType.LOCAL: - # without SSH - port_map = InstanceConnection.get_container_to_host_port_map( - job_provisioning_data, job_runtime_data - ) - return func(port_map, *args, **kwargs) + @functools.wraps(func) + def wrapper( + ssh_private_key: PrivateKeyOrPair, + job_provisioning_data: JobProvisioningData, + job_runtime_data: Optional[JobRuntimeData], + *args: P.args, + **kwargs: P.kwargs, + ) -> Union[Literal[False], R]: + """ + Returns: + is successful + """ + if job_provisioning_data.backend == BackendType.LOCAL: + # without SSH + port_map = InstanceConnection.get_container_to_host_port_map( + job_provisioning_data, job_runtime_data + ) + return func(port_map, *args, **kwargs) - if settings.SERVER_SSH_POOL_DISABLED or not job_provisioning_data.dockerized: - # Connections from dstack-server to runner's sshd are expected to be short - # as the `inactivity_duration` feature distinguishes user and server connections based on duration. - # Do not re-use SSH connections for container-based backends. - # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. 
- for attempt in range(retries): - if attempt > 0: - time.sleep(retry_interval) - conn = InstanceConnection( - ssh_private_key=ssh_private_key, - jpd=job_provisioning_data, - jrd=job_runtime_data, - ephemeral=True, - ) - try: - conn.open() - except SSHError: - continue - try: - return func(conn.forwarded_paths(), *args, **kwargs) - except (SSHError, requests.ConnectionError): - continue # connection-level failure, retry with a fresh connection - except (DstackError, requests.RequestException): - return False - finally: - conn.close() + if settings.SERVER_SSH_POOL_DISABLED or not job_provisioning_data.dockerized: + # Connections from dstack-server to runner's sshd are expected to be short + # as the `inactivity_duration` feature distinguishes user and server connections based on duration. + # Do not re-use SSH connections for container-based backends. + # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. + conn = InstanceConnection( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ephemeral=True, + ) + try: + conn.open() + except SSHError: return False + try: + return func(conn.forwarded_paths(), *args, **kwargs) + except (DstackError, requests.RequestException): + return False + finally: + conn.close() - # First try a cached connection and, if it's dead, a new connection. - # Connections already cover against - # a) cleanly-existed master (ControlPersist reap); and - # b) stale control socket file left by killed master. - # but we still want a fast retry in case master dies mid-request. 
- for _ in range(2): - conn = instance_connection_pool.get_or_open( - ssh_private_key=ssh_private_key, - jpd=job_provisioning_data, - jrd=job_runtime_data, - ) - if conn is None: - return False # couldn't establish at all - try: - return func(conn.forwarded_paths(), *args, **kwargs) - except (SSHError, requests.ConnectionError): - instance_connection_pool.drop(conn.key) # dead ssh connection, re-open - except (DstackError, requests.RequestException): - return False # reached runner, app-level fail; don't re-open ssh connection - return False - - return wrapper + # First try a cached connection and, if it's dead, a new connection. + # Connections already cover against + # a) cleanly-existed master (ControlPersist reap); and + # b) stale control socket file left by killed master. + # but we still want a fast retry in case master dies mid-request. + for _ in range(2): + conn = instance_connection_pool.get_or_open( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ) + if conn is None: + return False # couldn't establish at all + try: + return func(conn.forwarded_paths(), *args, **kwargs) + except (SSHError, requests.ConnectionError): + instance_connection_pool.drop(conn.key) # dead ssh connection, re-open + except (DstackError, requests.RequestException): + return False # reached runner, app-level fail; don't re-open ssh connection + return False - return decorator + return wrapper diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 0e81fdb774..ce54735579 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -154,6 +154,7 @@ ) SERVER_SSH_POOL_DISABLED = os.getenv("DSTACK_SERVER_SSH_POOL_DISABLED", False) +SERVER_SSH_CONNECT_TIMEOUT = int(os.getenv("DSTACK_SERVER_SSH_CONNECT_TIMEOUT", 3)) # Development settings From d271fcb7a3569445c1d261ccf24158f391598c79 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 4 Jun 2026 15:58:44 
+0500 Subject: [PATCH 15/25] Update env docs --- mkdocs/docs/reference/env.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mkdocs/docs/reference/env.md b/mkdocs/docs/reference/env.md index c88ccce4c5..76d6d58393 100644 --- a/mkdocs/docs/reference/env.md +++ b/mkdocs/docs/reference/env.md @@ -141,7 +141,8 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server. -- `DSTACK_SERVER_SSH_POOL_DISABLED`{ #DSTACK_SERVER_SSH_POOL_DISABLED } – Disables the reuse of server-instance SSH connections. By default, SHH connections are reused, and each active instance may consume ~2-10MB of server RAM. Set this to save RAM at the expense of opening an SSH connection on every server-instance communication. 
+- `DSTACK_SERVER_SSH_POOL_DISABLED`{ #DSTACK_SERVER_SSH_POOL_DISABLED } – Disables the reuse of server-instance SSH connections. By default, SSH connections are reused, and each active instance may consume ~2-10MB of server RAM. Set this to save RAM at the expense of opening an SSH connection on every server-instance communication. +- `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`{ #DSTACK_SERVER_SSH_CONNECT_TIMEOUT } – The SSH `ConnectTimeout` for server-instance connections, in seconds. Defaults to `3`. Increase if there are high-latency links between the server and instances. ??? info "Internal environment variables" The following environment variables are intended for development purposes: From 3effce4cc8b02d35a5ab52bfacd1c6a40702deee Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 10:50:45 +0500 Subject: [PATCH 16/25] Clean up locks with WeakValueDictionary --- .../_internal/server/services/runner/pool.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index d29f2d0ac9..8f440ae407 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -4,6 +4,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from typing import Collection, Optional, Union +from weakref import WeakValueDictionary from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import SSHError @@ -68,7 +69,11 @@ class InstanceConnectionPool: def __init__(self): self._connections: dict[InstanceConnectionKey, InstanceConnection] = {} - self._access_locks: dict[InstanceConnectionKey, threading.Lock] = {} + # Use `WeakValueDictionary` for automatic GC of unused locks and avoid manual refcounting. + # A lock is expected to exist only while a thread holds a strong reference to it. 
+ self._access_locks: WeakValueDictionary[InstanceConnectionKey, threading.Lock] = ( + WeakValueDictionary({}) + ) self._access_locks_lock = threading.Lock() self._closed = False @@ -128,7 +133,9 @@ def close_all(self) -> None: """ with self._access_locks_lock: self._closed = True - keys = list(self._access_locks.keys()) + # self._connections holds cached connections, and + # self._access_locks may hold mid-open connections not yet cached. + keys = set(self._connections) | set(self._access_locks.keys()) logger.debug("Closing %d instance connection(s)", len(keys)) for key in keys: self.drop(key) @@ -262,8 +269,7 @@ def _resolve_conn_dir( conn_dir = ( CONNECTIONS_DIR - / f"{key.hostname}:{key.port}" - / ",".join(map(str, key.ports_to_forward)) + / f"{key.hostname}:{key.port},{','.join(map(str, key.ports_to_forward))}" ) conn_dir.mkdir(parents=True, exist_ok=True) # Connection_dir can have a long path that won't be accepted by the ssh command, @@ -281,7 +287,7 @@ def _get_host_port_to_uds_map( conn_dir: Path, ports_to_forward: Collection[int], ) -> dict[int, Path]: - return {port: conn_dir / str(port) for port in ports_to_forward} + return {port: conn_dir / f"{port}.sock" for port in ports_to_forward} @staticmethod def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[SocketPair]: From ceb093fd021847e187b7f33318ec52ab9cd1788e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 11:00:30 +0500 Subject: [PATCH 17/25] Clean up control socket --- src/dstack/_internal/server/services/runner/pool.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 8f440ae407..05678c07da 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -241,6 +241,9 @@ def forwarded_paths(self) -> dict[int, Path]: def close(self) -> None: self._tunnel.close() + # If the master was 
killed without cleaning up its control socket, + # remove the socket so that the master can re-open. + self._control_socket_path.unlink(missing_ok=True) @property def key(self) -> InstanceConnectionKey: From 7dec8d4c3a3132ed7bfe6e9c4656b11699ab3262 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 11:33:00 +0500 Subject: [PATCH 18/25] Survive tmp cleanup --- .../_internal/server/services/runner/pool.py | 32 ++++++++++++++----- src/dstack/_internal/server/settings.py | 4 --- src/dstack/_internal/utils/path.py | 6 ++-- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 05678c07da..3626898afe 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -1,3 +1,4 @@ +import os import threading import time from dataclasses import dataclass @@ -20,7 +21,6 @@ from dstack._internal.server.settings import ( SERVER_DIR_PATH, SERVER_SSH_CONNECT_TIMEOUT, - SERVER_TMP_PATH, ) from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import FileContent, make_tmp_symlink_to_dir @@ -176,10 +176,11 @@ def __init__( self._key = InstanceConnectionKey.from_jpd(jpd, jrd) self._ephemeral = ephemeral self._last_verified_at: float = 0.0 - self._temp_dir, self._effective_conn_dir = InstanceConnection._resolve_conn_dir( - self._key, ephemeral + self._temp_dir, self._effective_conn_dir, self._real_conn_dir = ( + InstanceConnection._resolve_conn_dir(self._key, ephemeral) ) self._control_socket_path = self._effective_conn_dir / "control.sock" + self._real_control_socket_path = self._real_conn_dir / "control.sock" self._container_to_host_port_map = InstanceConnection.get_container_to_host_port_map( jpd, jrd ) @@ -208,6 +209,12 @@ def __init__( ) def open(self) -> None: + # A control socket left by a killed master or by a master that exited after + # its tmp
symlink was deleted prevents ssh from becoming a mux master + # ("ControlSocket ... already exists, disabling multiplexing"). + # Remove it unless it's served by a live master (then open() attaches to it). + if self._real_control_socket_path.exists() and not self._tunnel.check(): + self._real_control_socket_path.unlink(missing_ok=True) self._tunnel.open() self._last_verified_at = time.monotonic() @@ -229,6 +236,11 @@ def is_alive(self) -> bool: return True if not self._tunnel.check(): return False + # Keep the symlink fresh so that age-based /tmp cleanup is less likely to remove it. + try: + os.utime(self._effective_conn_dir, follow_symlinks=False) + except OSError: + pass self._last_verified_at = now return True @@ -265,10 +277,14 @@ def get_container_to_host_port_map( @staticmethod def _resolve_conn_dir( key: InstanceConnectionKey, ephemeral: bool - ) -> tuple[TemporaryDirectory, Path]: + ) -> tuple[TemporaryDirectory, Path, Path]: + """ + Returns (temp dir to retain, dir to be used by ssh, real conn dir). + """ if ephemeral: temp_dir = TemporaryDirectory() - return temp_dir, Path(temp_dir.name) + path = Path(temp_dir.name) + return temp_dir, path, path conn_dir = ( CONNECTIONS_DIR @@ -277,13 +293,13 @@ def _resolve_conn_dir( conn_dir.mkdir(parents=True, exist_ok=True) # Connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink. + # The symlink may be removed by age-based /tmp cleanup while the connection is still alive. + # The connection will be reopened with a fresh symlink, attaching to the still-running master. temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir( dirpath=conn_dir, symlink_dirname="connection", - # Using dstack's own tmp dir to avoid age-based tmp cleanup. 
- base_dir=SERVER_TMP_PATH, ) - return temp_dir, conn_symlink_dir + return temp_dir, conn_symlink_dir, conn_dir @staticmethod def _get_host_port_to_uds_map( diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index ce54735579..d4bc5829e9 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -21,10 +21,6 @@ SERVER_DATA_DIR_PATH = SERVER_DIR_PATH / "data" SERVER_DATA_DIR_PATH.mkdir(parents=True, exist_ok=True) -SERVER_TMP_PATH = SERVER_DIR_PATH / "tmp" -"""SERVER_TMP_PATH can be used as dstack's own /tmp when age-based cleaning for /tmp is not desirable""" -SERVER_TMP_PATH.mkdir(parents=True, exist_ok=True) - DATABASE_URL = os.getenv( "DSTACK_DATABASE_URL", f"sqlite+aiosqlite:///{str(SERVER_DATA_DIR_PATH.absolute())}/sqlite.db" ) diff --git a/src/dstack/_internal/utils/path.py b/src/dstack/_internal/utils/path.py index 5131831cf8..07b8fdd664 100644 --- a/src/dstack/_internal/utils/path.py +++ b/src/dstack/_internal/utils/path.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path, PurePath, PurePosixPath from tempfile import TemporaryDirectory -from typing import Optional, Union +from typing import Union PathLike = Union[str, os.PathLike] @@ -59,9 +59,9 @@ def is_absolute_posix_path(path: PathLike) -> bool: def make_tmp_symlink_to_dir( - dirpath: PathLike, symlink_dirname: str, base_dir: Optional[PathLike] = None + dirpath: PathLike, symlink_dirname: str ) -> tuple[TemporaryDirectory, Path]: - temp_dir = TemporaryDirectory(dir=base_dir) + temp_dir = TemporaryDirectory() symlink_dir = Path(temp_dir.name) / symlink_dirname symlink_dir.symlink_to(dirpath, target_is_directory=True) return temp_dir, symlink_dir From 6167749b4bef0b46d95185da632a8c959a59fe37 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 12:39:22 +0500 Subject: [PATCH 19/25] Make ssh pool opt-in --- mkdocs/docs/reference/env.md | 1 - src/dstack/_internal/server/app.py | 1 
+ .../_internal/server/services/runner/pool.py | 23 +++++++++++++++---- .../_internal/server/services/runner/ssh.py | 2 +- src/dstack/_internal/server/settings.py | 2 +- 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mkdocs/docs/reference/env.md b/mkdocs/docs/reference/env.md index 76d6d58393..086ea80ad8 100644 --- a/mkdocs/docs/reference/env.md +++ b/mkdocs/docs/reference/env.md @@ -141,7 +141,6 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY } – A default Docker registry to use for job images that do not specify an explicit registry. E.g., if set to `registry.example`, then `image: ubuntu` becomes equivalent to `image: registry.example/ubuntu`. **Note**: This setting should only be used for configuring registries that act as a pull-through cache for Docker Hub. The default `dstack` images are also pulled from the configured registry. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_USERNAME } – Username for authenticating with the default Docker registry. See `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`. - `DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD`{ #DSTACK_SERVER_DEFAULT_DOCKER_REGISTRY_PASSWORD } – Password for authenticating with the default Docker registry. Applied only when the image has no explicit registry and the run configuration does not specify `registry_auth`. **Note**: The value may be visible to anyone who can SSH into instances managed by `dstack`, which usually includes all users of that `dstack` server. -- `DSTACK_SERVER_SSH_POOL_DISABLED`{ #DSTACK_SERVER_SSH_POOL_DISABLED } – Disables the reuse of server-instance SSH connections. By default, SSH connections are reused, and each active instance may consume ~2-10MB of server RAM. Set this to save RAM at the expense of opening an SSH connection on every server-instance communication. 
- `DSTACK_SERVER_SSH_CONNECT_TIMEOUT`{ #DSTACK_SERVER_SSH_CONNECT_TIMEOUT } – The SSH `ConnectTimeout` for server-instance connections, in seconds. Defaults to `3`. Increase if there are high-latency links between the server and instances. ??? info "Internal environment variables" diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index ef8c9e6b54..f28da6296f 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -169,6 +169,7 @@ async def lifespan(app: FastAPI): ) if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() + await run_async(instance_connection_pool.startup_cleanup) scheduler = None pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 3626898afe..e372402850 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -1,4 +1,5 @@ import os +import shutil import threading import time from dataclasses import dataclass @@ -65,6 +66,9 @@ class InstanceConnectionPool: NOTE: The pool does not currently intended for arbitrary ports forwarding, only for shim and runner ports. E.g. it cannot be used to forward services ports for probes or router-worker communication. This simplified model allows forwarding the same ports for the given host:port and reusing the connection across all calls. + + Incompatible with multiple server processes sharing the same server dir: + connection dirs and control sockets are assumed to be owned by a single process. """ def __init__(self): @@ -125,6 +129,14 @@ def drop(self, key: InstanceConnectionKey) -> None: except Exception: logger.exception("Failed to close instance connection %s", key) + def startup_cleanup(self) -> None: + """ + Removes connection dirs left by a previous server process (e.g. 
after SIGKILL). + Must be called on server startup before the pool is used. + Leftover live masters are reaped by `ControlPersist`. + """ + shutil.rmtree(CONNECTIONS_DIR, ignore_errors=True) + def close_all(self) -> None: """ Closes all connections and prevents new ones from being opened. @@ -253,9 +265,12 @@ def forwarded_paths(self) -> dict[int, Path]: def close(self) -> None: self._tunnel.close() - # If the master was killed without cleaning up its control socket, - # remove the socket so that the master can re-open. - self._control_socket_path.unlink(missing_ok=True) + # Remove a stale control.sock left by a killed master, forwarded UDS files + # (ssh does not unlink them on exit), and the dir itself, so that + # CONNECTIONS_DIR does not accumulate dirs of gone instances. + # A master that survives close() because it is unreachable via a deleted + # symlink is reaped by ControlPersist. + shutil.rmtree(self._real_conn_dir, ignore_errors=True) @property def key(self) -> InstanceConnectionKey: @@ -294,7 +309,7 @@ def _resolve_conn_dir( # Connection_dir can have a long path that won't be accepted by the ssh command, # so we create a short temporary symlink. # The symlink may be removed by age-based /tmp cleanup while the connection is still alive. - # The connection will be reopened with a fresh symlink, attaching to the still-running master. + # The connection dir will be removed and the connection is re-opened. 
temp_dir, conn_symlink_dir = make_tmp_symlink_to_dir( dirpath=conn_dir, symlink_dirname="connection", diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index c57d23564f..cafb6af5d7 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -62,7 +62,7 @@ def wrapper( ) return func(port_map, *args, **kwargs) - if settings.SERVER_SSH_POOL_DISABLED or not job_provisioning_data.dockerized: + if not settings.SERVER_SSH_POOL_ENABLED or not job_provisioning_data.dockerized: # Connections from dstack-server to runner's sshd are expected to be short # as the `inactivity_duration` feature distinguishes user and server connections based on duration. # Do not re-use SSH connections for container-based backends. diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index d4bc5829e9..a095f433f9 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -149,7 +149,7 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) -SERVER_SSH_POOL_DISABLED = os.getenv("DSTACK_SERVER_SSH_POOL_DISABLED", False) +SERVER_SSH_POOL_ENABLED = os.getenv("DSTACK_SERVER_SSH_POOL_ENABLED") is not None SERVER_SSH_CONNECT_TIMEOUT = int(os.getenv("DSTACK_SERVER_SSH_CONNECT_TIMEOUT", 3)) # Development settings From fda0c80c7c81dd43eda9123490b8810dbf5afc01 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 13:20:19 +0500 Subject: [PATCH 20/25] Fix tests --- .../_internal/server/services/runner/pool.py | 2 +- .../_internal/server/services/runner/ssh.py | 12 ++-- .../test_instances/test_check.py | 2 +- .../pipeline_tasks/test_running_jobs.py | 64 +++++++++---------- .../pipeline_tasks/test_terminating_jobs.py | 6 +- .../scheduled_tasks/test_metrics.py | 4 +- .../test_prometheus_metrics.py | 6 +- 7 files changed, 50 insertions(+), 46 deletions(-) diff 
--git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index e372402850..9aebc8fc72 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -108,8 +108,8 @@ def get_or_open( conn.close() except Exception: logger.exception("Failed to close instance connection %s", key) - conn = InstanceConnection(ssh_private_key, jpd, jrd) try: + conn = InstanceConnection(ssh_private_key, jpd, jrd) conn.open() except SSHError: # error logged in tunnel diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index cafb6af5d7..4ddeab7812 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -67,13 +67,13 @@ def wrapper( # as the `inactivity_duration` feature distinguishes user and server connections based on duration. # Do not re-use SSH connections for container-based backends. # TODO: Drop `inactivity_duration` dependence on connection duration and re-use connections. 
- conn = InstanceConnection( - ssh_private_key=ssh_private_key, - jpd=job_provisioning_data, - jrd=job_runtime_data, - ephemeral=True, - ) try: + conn = InstanceConnection( + ssh_private_key=ssh_private_key, + jpd=job_provisioning_data, + jrd=job_runtime_data, + ephemeral=True, + ) conn.open() except SSHError: return False diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py index b555556881..33e57df016 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -543,7 +543,7 @@ def shim_client_mock( mock.list_tasks.return_value = TaskListResponse(tasks=[]) mock.is_safe_to_restart.return_value = False monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", + "dstack._internal.server.services.runner.client.ShimClient.from_address", Mock(return_value=mock), ) return mock diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 35a129e0f6..e308b89ce8 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -36,6 +36,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint, VolumeStatus +from dstack._internal.core.services.ssh.tunnel import SSHTunnel from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.jobs_running import ( ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS, @@ -61,7 +62,6 @@ TaskStatus, ) from dstack._internal.server.services.runner.client import RunnerClient, ShimClient -from dstack._internal.server.services.runner.ssh import SSHTunnel from 
dstack._internal.server.services.runs.replicas import RouterEnvStatus from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( @@ -116,7 +116,7 @@ def worker() -> JobRunningWorker: @pytest.fixture def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: mock = MagicMock(spec_set=SSHTunnel) - monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", mock) + monkeypatch.setattr("dstack._internal.server.services.runner.pool.SSHTunnel", mock) return mock @@ -126,7 +126,8 @@ def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: mock.healthcheck.return_value = HealthcheckResponse(service="dstack-shim", version="latest") mock.get_task.return_value.image_pull_progress = None monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) + "dstack._internal.server.services.runner.client.ShimClient.from_address", + Mock(return_value=mock), ) return mock @@ -138,7 +139,8 @@ def runner_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: service="dstack-runner", version="0.0.1.dev2" ) monkeypatch.setattr( - "dstack._internal.server.services.runner.client.RunnerClient", Mock(return_value=mock) + "dstack._internal.server.services.runner.client.RunnerClient.from_address", + Mock(return_value=mock), ) return mock @@ -481,9 +483,9 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, patch( "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", @@ -561,7 +563,7 @@ async def test_runs_provisioning_job( before_processed_at = 
job.last_processed_at with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch.object(RunnerClient, "_healthcheck") as healthcheck_mock, patch.object(RunnerClient, "submit_job") as submit_job_mock, patch.object(RunnerClient, "upload_code") as upload_code_mock, @@ -1067,14 +1069,13 @@ async def test_pulling_shim_failed( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, ): from dstack._internal.core.errors import SSHError ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) events = await list_events(session) @@ -1084,15 +1085,14 @@ async def test_pulling_shim_failed( assert events[0].message == "Job became unreachable" with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, freeze_time(job.disconnected_at + timedelta(minutes=5)), ): from dstack._internal.core.errors import SSHError ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) assert job.status == JobStatus.TERMINATING @@ -1168,11 +1168,12 @@ async def test_provisioning_shim_force_stop_if_already_running_api_v1( instance_assigned=True, ) monkeypatch.setattr( - "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock()) + "dstack._internal.server.services.runner.pool.SSHTunnel", + Mock(return_value=MagicMock()), ) 
shim_client_mock = Mock() monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", + "dstack._internal.server.services.runner.client.ShimClient.from_address", Mock(return_value=shim_client_mock), ) shim_client_mock.healthcheck.return_value = HealthcheckResponse( @@ -1243,9 +1244,9 @@ async def test_master_job_waits_for_workers( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel"), patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1342,9 +1343,9 @@ async def test_updates_running_job( with ( patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1365,9 +1366,9 @@ async def test_updates_running_job( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1411,12 +1412,11 @@ async def test_running_job_disconnect_retries_then_terminates( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), 
+ patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, ): ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) events = await list_events(session) @@ -1426,13 +1426,12 @@ async def test_running_job_disconnect_retries_then_terminates( assert events[0].message == "Job became unreachable" with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, - patch("dstack._internal.server.services.runner.ssh.time.sleep"), + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, freeze_time(job.disconnected_at + timedelta(minutes=5)), ): ssh_tunnel_cls.side_effect = SSHError await _process_job(session, worker, job) - assert ssh_tunnel_cls.call_count == 3 + assert ssh_tunnel_cls.call_count == 1 await session.refresh(job) assert job.status == JobStatus.TERMINATING @@ -1537,9 +1536,9 @@ async def test_inactivity_duration( instance_assigned=True, ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -1649,9 +1648,9 @@ async def test_gpu_utilization( ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as ssh_tunnel_cls, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + "dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as runner_client_cls, ): runner_client_mock = runner_client_cls.return_value @@ -2127,7 +2126,8 @@ async def 
test_does_not_terminate_job_when_instance_access_is_valid( session=session, run=run, status=job_status, - job_provisioning_data=get_job_provisioning_data(dockerized=False), + # dockerized=True so that the shim port is forwarded for the PULLING case + job_provisioning_data=get_job_provisioning_data(dockerized=True), instance=instance, instance_assigned=True, ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py index 6bc3a433ba..5ec769519b 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -384,8 +384,10 @@ async def test_terminates_job( await session.commit() with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock, + patch( + "dstack._internal.server.services.runner.client.ShimClient.from_address" + ) as ShimClientMock, ): shim_client_mock = ShimClientMock.return_value await worker.process(_job_to_pipeline_item(job)) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py index df52dd88e2..1e3900a449 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py @@ -64,9 +64,9 @@ async def test_collects_metrics(self, test_db, session: AsyncSession): instance=instance, ) with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, + patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock, patch( - "dstack._internal.server.services.runner.client.RunnerClient" + 
"dstack._internal.server.services.runner.client.RunnerClient.from_address" ) as RunnerClientMock, ): runner_client_mock = RunnerClientMock.return_value diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py index 80961d5c10..0775723b4d 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py @@ -73,12 +73,14 @@ async def job(self, request: pytest.FixtureRequest, session: AsyncSession) -> Jo @pytest.fixture def ssh_tunnel_mock(self) -> Generator[Mock, None, None]: - with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock: + with patch("dstack._internal.server.services.runner.pool.SSHTunnel") as SSHTunnelMock: yield SSHTunnelMock @pytest.fixture def shim_client_mock(self) -> Generator[Mock, None, None]: - with patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock: + with patch( + "dstack._internal.server.services.runner.client.ShimClient.from_address" + ) as ShimClientMock: yield ShimClientMock.return_value @freeze_time(datetime(2023, 1, 2, 3, 5, 20, tzinfo=timezone.utc)) From 9e97814aecd47bf53eab52c6b1abdecc44885eb6 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 13:38:32 +0500 Subject: [PATCH 21/25] Minor fixes --- src/dstack/_internal/server/services/runner/pool.py | 8 ++++---- src/dstack/_internal/server/services/runner/ssh.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 9aebc8fc72..8e9827e00f 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -63,7 +63,7 @@ class InstanceConnectionPool: A pool of SSH connections to instances' host 
sshd (VM-based) or runner sshd (container-based) for forwarding shim and runner ports. - NOTE: The pool does not currently intended for arbitrary ports forwarding, only for shim and runner ports. + NOTE: The pool is not currently intended for arbitrary ports forwarding, only for shim and runner ports. E.g. it cannot be used to forward services ports for probes or router-worker communication. This simplified model allows forwarding the same ports for the given host:port and reusing the connection across all calls. @@ -76,7 +76,7 @@ def __init__(self): # Use `WeakValueDictionary` for automatic GC of unused locks and avoid manual refcounting. # A lock is expected to exist only while a thread holds a strong reference to it. self._access_locks: WeakValueDictionary[InstanceConnectionKey, threading.Lock] = ( - WeakValueDictionary({}) + WeakValueDictionary() ) self._access_locks_lock = threading.Lock() self._closed = False @@ -203,7 +203,7 @@ def __init__( self._tunnel = SSHTunnel( destination=f"{jpd.username}@{jpd.hostname}", port=jpd.ssh_port, - identity=InstanceConnection._get_identity(ssh_private_key, jpd), + identity=InstanceConnection._get_identity(ssh_private_key), control_sock_path=self._control_socket_path, forwarded_sockets=self._get_forwarded_sockets(self._host_port_to_uds_map), ssh_proxies=InstanceConnection._get_proxies(ssh_private_key, jpd), @@ -334,7 +334,7 @@ def _get_forwarded_sockets(host_port_to_uds_map: dict[int, Path]) -> list[Socket ] @staticmethod - def _get_identity(ssh_private_key: PrivateKeyOrPair, jpd: JobProvisioningData) -> FileContent: + def _get_identity(ssh_private_key: PrivateKeyOrPair) -> FileContent: if isinstance(ssh_private_key, tuple): ssh_private_key, _ = ssh_private_key return FileContent(ssh_private_key) diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 4ddeab7812..dadc7c32b1 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ 
b/src/dstack/_internal/server/services/runner/ssh.py @@ -15,9 +15,7 @@ PrivateKeyOrPair, instance_connection_pool, ) -from dstack._internal.utils.logging import get_logger -logger = get_logger(__name__) P = ParamSpec("P") R = TypeVar("R") @@ -86,7 +84,7 @@ def wrapper( # First try a cached connection and, if it's dead, a new connection. # Connections already cover against - # a) cleanly-existed master (ControlPersist reap); and + # a) cleanly-exited master (ControlPersist reap); and # b) stale control socket file left by killed master. # but we still want a fast retry in case master dies mid-request. for _ in range(2): From b5af3f4795ffdf1f0f7fd42cf7f3f95eae12fee8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 13:57:20 +0500 Subject: [PATCH 22/25] Minor fixes --- .../_internal/server/services/runner/pool.py | 19 +++++++++++-------- .../_internal/server/services/runner/ssh.py | 1 + src/dstack/_internal/server/settings.py | 2 ++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index 8e9827e00f..d31d92e3f0 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -66,6 +66,7 @@ class InstanceConnectionPool: NOTE: The pool is not currently intended for arbitrary ports forwarding, only for shim and runner ports. E.g. it cannot be used to forward services ports for probes or router-worker communication. This simplified model allows forwarding the same ports for the given host:port and reusing the connection across all calls. + TODO: Generalize to support arbitrary ports forwarding incl. job's ports. Incompatible with multiple server processes sharing the same server dir: connection dirs and control sockets are assumed to be owned by a single process. 
@@ -166,6 +167,16 @@ def _get_access_lock(self, key: InstanceConnectionKey) -> threading.Lock: class InstanceConnection: + """ + An SSH connection to instance's host sshd (VM-based) + or runner sshd (container-based) for forwarding shim and runner ports. + + The same control socket is used for all connections to the same hostname:port, + unless jrd overrides the runner port mapped on host (blocks case). + In case of blocks, each job establishes a separate connection with a different runner port forwarded. + TODO: Re-use the same SSH connection for all blocks via `-O forward`/`-O cancel`. + """ + def __init__( self, ssh_private_key: PrivateKeyOrPair, @@ -174,14 +185,6 @@ def __init__( ephemeral: bool = False, ) -> None: """ - An SSH connection to instance's host sshd (VM-based) - or runner sshd (container-based) for forwarding shim and runner ports. - - The same control socket is used for all connections to the same hostname:port, - unless jrd overrides the runner port mapped on host (blocks case). - In case of blocks, each job establishes a separate connection with a different runner port forwarded. - TODO: Re-use the same SSH connection for all blocks via `-O forward`/`-O cancel`. - Args: ephemeral: Creates a unique tmp dir for the UDS. Use when connection re-use is not needed. """ diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index dadc7c32b1..b1430fba6f 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -86,6 +86,7 @@ def wrapper( # Connections already cover against # a) cleanly-exited master (ControlPersist reap); and # b) stale control socket file left by killed master. + # (Because we cannot rely solely on connection errors from `func` – it may swallow the errors.) # but we still want a fast retry in case master dies mid-request. 
for _ in range(2): conn = instance_connection_pool.get_or_open( diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index a095f433f9..07443569e5 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -149,6 +149,8 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) +# TODO: Replace DSTACK_SERVER_SSH_POOL_ENABLED with DSTACK_SERVER_SSH_POOL_DISABLE +# as pool becomes opt-out and document the env var. SERVER_SSH_POOL_ENABLED = os.getenv("DSTACK_SERVER_SSH_POOL_ENABLED") is not None SERVER_SSH_CONNECT_TIMEOUT = int(os.getenv("DSTACK_SERVER_SSH_CONNECT_TIMEOUT", 3)) From df496cc40e7534997117a4f2a2c1d3ae3b919a70 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 14:02:48 +0500 Subject: [PATCH 23/25] Fix typo --- src/dstack/_internal/server/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 07443569e5..2845687e23 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -149,7 +149,7 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) -# TODO: Replace DSTACK_SERVER_SSH_POOL_ENABLED with DSTACK_SERVER_SSH_POOL_DISABLE +# TODO: Replace DSTACK_SERVER_SSH_POOL_ENABLED with DSTACK_SERVER_SSH_POOL_DISABLED # as pool becomes opt-out and document the env var. 
SERVER_SSH_POOL_ENABLED = os.getenv("DSTACK_SERVER_SSH_POOL_ENABLED") is not None SERVER_SSH_CONNECT_TIMEOUT = int(os.getenv("DSTACK_SERVER_SSH_CONNECT_TIMEOUT", 3)) From caa6eb87151bef5af74abde8ef4bcaed237ae509 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 5 Jun 2026 14:26:34 +0500 Subject: [PATCH 24/25] Run startup/teardown cleanup only if pool enabled --- src/dstack/_internal/server/app.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index f28da6296f..8b9e044776 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -169,7 +169,8 @@ async def lifespan(app: FastAPI): ) if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() - await run_async(instance_connection_pool.startup_cleanup) + if settings.SERVER_SSH_POOL_ENABLED: + await run_async(instance_connection_pool.startup_cleanup) scheduler = None pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: @@ -212,7 +213,8 @@ async def lifespan(app: FastAPI): await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() - await run_async(instance_connection_pool.close_all) + if settings.SERVER_SSH_POOL_ENABLED: + await run_async(instance_connection_pool.close_all) await get_db().engine.dispose() # Let checked-out DB connections close as dispose() only closes checked-in connections await asyncio.sleep(3) From 86f4970ee82947c8367f5a8a03f274310af18d81 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 8 Jun 2026 10:25:39 +0500 Subject: [PATCH 25/25] Drop connection on instance termination --- .../background/pipeline_tasks/instances/termination.py | 7 +++++++ src/dstack/_internal/server/services/runner/pool.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py index eb1f3c8a39..a4bf6d3294 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py @@ -11,6 +11,10 @@ from dstack._internal.server.models import InstanceModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.instances import get_instance_provisioning_data +from dstack._internal.server.services.runner.pool import ( + InstanceConnectionKey, + instance_connection_pool, +) from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -77,6 +81,9 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: exc_info=not isinstance(exc, BackendError), ) + if job_provisioning_data is not None: + instance_connection_pool.drop(InstanceConnectionKey.from_jpd(job_provisioning_data)) + result.instance_update_map["deleted"] = True result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER result.instance_update_map["finished_at"] = NOW_PLACEHOLDER diff --git a/src/dstack/_internal/server/services/runner/pool.py b/src/dstack/_internal/server/services/runner/pool.py index d31d92e3f0..b91d1d3125 100644 --- a/src/dstack/_internal/server/services/runner/pool.py +++ b/src/dstack/_internal/server/services/runner/pool.py @@ -45,7 +45,7 @@ class InstanceConnectionKey: @staticmethod def from_jpd( - jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] + jpd: JobProvisioningData, jrd: Optional[JobRuntimeData] = None ) -> "InstanceConnectionKey": assert jpd.hostname is not None and jpd.ssh_port is not None container_to_host_port_map = InstanceConnection.get_container_to_host_port_map(jpd, jrd)