diff --git a/mkdocs/snippets/kubernetes/dstack-backend-role.yaml b/mkdocs/snippets/kubernetes/dstack-backend-role.yaml index acaa438be..7cc37fbc1 100644 --- a/mkdocs/snippets/kubernetes/dstack-backend-role.yaml +++ b/mkdocs/snippets/kubernetes/dstack-backend-role.yaml @@ -6,7 +6,7 @@ metadata: rules: - apiGroups: [""] resources: ["pods"] - verbs: ["get", "create", "delete"] + verbs: ["get", "watch", "create", "delete"] - apiGroups: [""] resources: ["services"] verbs: ["get", "create", "delete"] diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index f43f3e134..062e458b4 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -3,6 +3,7 @@ import subprocess import tempfile import time +from contextlib import ExitStack from decimal import Decimal from enum import Enum from typing import List, Optional @@ -63,6 +64,8 @@ get_api_from_kubeconfig_dict, kubeconfig_data_to_kubeconfig_dict, kubeconfig_dict_to_kubeconfig, + try_delete_object_if_exists, + watch_events, ) from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ComputeError, ProvisioningError @@ -80,7 +83,14 @@ from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec from dstack._internal.core.models.routers import AnyGatewayRouterConfig -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run +from dstack._internal.core.models.runs import ( + Job, + JobProvisioningData, + JobSpec, + Requirements, + Run, + RunSpec, +) from dstack._internal.core.models.volumes import ( InstanceMountPoint, KubernetesVolumeConfiguration, @@ -97,6 +107,8 @@ JUMP_POD_SSH_PORT = 22 JUMP_POD_USER = "root" +JOB_POD_SCHEDULING_TIMEOUT = 10 + class Operator(str, Enum): EXISTS = "Exists" @@ -168,227 +180,117 @@ def run_job( volumes: list[Volume], placement_group: Optional[PlacementGroup], ) -> JobProvisioningData: - instance_name = generate_unique_instance_name_for_job(run, job) - assert run.run_spec.ssh_key_pub is not None - commands = get_docker_commands( - [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] - ) + api = self.api + namespace = self.namespace + # There is one jump pod per project that is used as an ssh proxy jump to connect # to all job pods of the same project. # The service is created here and configured later in update_provisioning_data() jump_pod_name = f"dstack-{run.project_name}-ssh-jump-pod" jump_pod_service_name = _get_pod_service_name(jump_pod_name) _create_jump_pod_service_if_not_exists( - api=self.api, - namespace=self.namespace, + api=api, + namespace=namespace, jump_pod_name=jump_pod_name, jump_pod_service_name=jump_pod_service_name, jump_pod_port=self.proxy_jump.port, project_ssh_public_key=project_ssh_public_key.strip(), ) - image_pull_secrets: Optional[list[client.V1LocalObjectReference]] = None - if job.job_spec.registry_auth is not None: - registry_auth_secret_name = _get_registry_auth_secret_name(instance_name) - dockerconfigjson = build_dockerconfigjson( - image_name=job.job_spec.image_name, - username=job.job_spec.registry_auth.username, - password=job.job_spec.registry_auth.password, - ) - registry_auth_secret = client.V1Secret( - metadata=client.V1ObjectMeta(name=registry_auth_secret_name), - type="kubernetes.io/dockerconfigjson", - string_data={".dockerconfigjson": dockerconfigjson}, - ) - self.api.create_namespaced_secret( - namespace=self.namespace, - body=registry_auth_secret, - ) - image_pull_secrets = [client.V1LocalObjectReference(name=registry_auth_secret_name)] - - resources_requests: dict[str, str] = {} - resources_limits: dict[str, str] = {} - node_affinity: Optional[client.V1NodeAffinity] = None - tolerations: list[client.V1Toleration] = [] - volumes_: list[client.V1Volume] = [] - volume_mounts: list[client.V1VolumeMount] = [] - - resources_spec = job.job_spec.requirements.resources - assert isinstance(resources_spec.cpu, CPUSpec) - if (cpu_min := resources_spec.cpu.count.min) is not None: - resources_requests["cpu"] = str(cpu_min) - if (cpu_max := resources_spec.cpu.count.max) is not None: - resources_limits["cpu"] = str(cpu_max) - if (gpu_spec := resources_spec.gpu) is not None: - if (gpu_request := get_gpu_request_from_gpu_spec(gpu_spec)) > 0: - gpu_resource, node_affinity, node_taint = _get_pod_spec_parameters_for_gpu( - self.api, gpu_spec + pod_name = generate_unique_instance_name_for_job(run, job) + registry_auth_secret_name: Optional[str] = None + with ExitStack() as exit_stack: + if job.job_spec.registry_auth is not None: + registry_auth_secret_name = _get_registry_auth_secret_name(pod_name) + _create_registry_auth_secret( + api=api, + namespace=namespace, + secret_name=registry_auth_secret_name, + image_name=job.job_spec.image_name, + username=job.job_spec.registry_auth.username, + password=job.job_spec.registry_auth.password, ) - logger.debug("Requesting GPU resource: %s=%d", gpu_resource, gpu_request) - resources_requests[gpu_resource] = str(gpu_request) - # Limit must be set (GPU resources cannot be overcommitted) - # and must be equal to request. - resources_limits[gpu_resource] = str(gpu_request) - # It should be NoSchedule, but we also add NoExecute toleration just in case. - for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]: - tolerations.append( - client.V1Toleration( - key=node_taint, operator=Operator.EXISTS, effect=effect - ) - ) - if (memory_min := resources_spec.memory.min) is not None: - resources_requests["memory"] = format_memory(memory_min) - if (memory_max := resources_spec.memory.max) is not None: - resources_limits["memory"] = format_memory(memory_max) - if (disk_spec := resources_spec.disk) is not None: - if (disk_min := disk_spec.size.min) is not None: - resources_requests["ephemeral-storage"] = format_memory(disk_min) - if (disk_max := disk_spec.size.max) is not None: - resources_limits["ephemeral-storage"] = format_memory(disk_max) - if (shm_size := resources_spec.shm_size) is not None: - shm_volume_name = "dev-shm" - volumes_.append( - client.V1Volume( - name=shm_volume_name, - empty_dir=client.V1EmptyDirVolumeSource( - medium="Memory", - size_limit=format_memory(shm_size), - ), + exit_stack.callback( + try_delete_object_if_exists, + api.delete_namespaced_secret, + namespace=namespace, + name=registry_auth_secret_name, + description="registry auth secret", + should_delete_manually_if_failed=True, ) + + assert run.run_spec.ssh_key_pub is not None + authorized_keys = [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] + _create_job_pod( + api=api, + namespace=namespace, + pod_name=pod_name, + registry_auth_secret_name=registry_auth_secret_name, + run_spec=run.run_spec, + job_spec=job.job_spec, + volumes=volumes, + authorized_keys=authorized_keys, ) - volume_mounts.append( - client.V1VolumeMount( - name=shm_volume_name, - mount_path="/dev/shm", - ) + exit_stack.callback( + try_delete_object_if_exists, + api.delete_namespaced_pod, + namespace=namespace, + name=pod_name, + description="pod", + should_delete_manually_if_failed=True, ) - - volume_name_path_map: dict[str, str] = {} - mount_points = job.job_spec.volumes - if mount_points is None: - # Legacy JobSpec without volumes - mount_points = run.run_spec.configuration.volumes - for mount_point in mount_points: - if isinstance(mount_point, VolumeMountPoint): - if isinstance(mount_point.name, str): - volume_names = [mount_point.name] - else: - volume_names = mount_point.name - for volume_name in volume_names: - volume_name_path_map[volume_name] = mount_point.path - elif isinstance(mount_point, InstanceMountPoint): - # "Must be a DNS_LABEL and unique within the pod" - volume_name = generate_unique_name( - prefix="host-path", max_length=OBJECT_NAME_MAX_LENGTH - ) - volumes_.append( - client.V1Volume( - name=volume_name, - host_path=client.V1HostPathVolumeSource( - path=mount_point.instance_path, - type="DirectoryOrCreate", - ), - ), + is_pod_scheduled_or_finished, pod_phase = _wait_for_pod_scheduled_or_finished( + api=api, + namespace=namespace, + pod_name=pod_name, + timeout_seconds=JOB_POD_SCHEDULING_TIMEOUT, + ) + if not is_pod_scheduled_or_finished: + reason, message = _get_unscheduled_pod_reason_message( + api=api, + namespace=namespace, + pod_name=pod_name, ) - volume_mounts.append( - client.V1VolumeMount( - name=volume_name, - mount_path=mount_point.path, - ) + raise ComputeError( + f"Pod {pod_name} was not scheduled:" + f" {reason or 'unknown reason'}: {message or 'no message'}" ) - else: - assert False, f"unexpected mount point: {mount_point!r}" - for volume in volumes: - assert isinstance(volume.configuration, KubernetesVolumeConfiguration) - pvc_name = volume.volume_id - assert pvc_name is not None, f"missing PVC name: {volume!r}" - mount_path = volume_name_path_map.get(volume.name) - assert mount_path is not None, f"missing mount path: {volume!r}" - volume_name = generate_unique_name(prefix="pvc", max_length=OBJECT_NAME_MAX_LENGTH) - volumes_.append( - client.V1Volume( - name=volume_name, - persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( - claim_name=pvc_name, + if pod_phase is not None and pod_phase.is_finished(): + raise ComputeError(f"Pod {pod_name} already finished: {pod_phase}") + + pod_service_name = _get_pod_service_name(pod_name) + api.create_namespaced_service( + namespace=namespace, + body=client.V1Service( + metadata=client.V1ObjectMeta(name=pod_service_name), + spec=client.V1ServiceSpec( + type="ClusterIP", + selector={"app.kubernetes.io/name": pod_name}, + ports=[client.V1ServicePort(port=DSTACK_RUNNER_SSH_PORT)], ), ), ) - volume_mounts.append( - client.V1VolumeMount( - name=volume_name, - mount_path=mount_path, - read_only=volume.configuration.read_only, - recursive_read_only="IfPossible" if volume.configuration.read_only else None, - ) + exit_stack.callback( + try_delete_object_if_exists, + api.delete_namespaced_service, + namespace=namespace, + name=pod_service_name, + description="pod service", + should_delete_manually_if_failed=True, ) - pod = client.V1Pod( - metadata=client.V1ObjectMeta( - name=instance_name, - labels={"app.kubernetes.io/name": instance_name}, - ), - spec=client.V1PodSpec( - containers=[ - client.V1Container( - name=f"{instance_name}-container", - image=job.job_spec.image_name, - command=["/bin/sh"], - args=["-c", " && ".join(commands)], - ports=[ - client.V1ContainerPort( - container_port=DSTACK_RUNNER_SSH_PORT, - ) - ], - security_context=client.V1SecurityContext( - run_as_user=0, - run_as_group=0, - privileged=job.job_spec.privileged, - capabilities=client.V1Capabilities( - add=[ - # Allow to increase hard resource limits, see getrlimit(2) - "SYS_RESOURCE", - ], - ), - ), - resources=client.V1ResourceRequirements( - requests=resources_requests, - limits=resources_limits, - ), - volume_mounts=volume_mounts, - ) - ], - image_pull_secrets=image_pull_secrets, - affinity=client.V1Affinity( - node_affinity=node_affinity, - ), - tolerations=tolerations, - volumes=volumes_, - ), - ) - self.api.create_namespaced_pod( - namespace=self.namespace, - body=pod, - ) - self.api.create_namespaced_service( - namespace=self.namespace, - body=client.V1Service( - metadata=client.V1ObjectMeta(name=_get_pod_service_name(instance_name)), - spec=client.V1ServiceSpec( - type="ClusterIP", - selector={"app.kubernetes.io/name": instance_name}, - ports=[client.V1ServicePort(port=DSTACK_RUNNER_SSH_PORT)], - ), - ), - ) + # Cancel all cleanup callbacks + exit_stack.pop_all() backend_data = KubernetesBackendData( jump_pod_name=jump_pod_name, jump_pod_service_name=jump_pod_service_name, user_ssh_public_key=run.run_spec.ssh_key_pub.strip(), ) + return JobProvisioningData( backend=instance_offer.backend, - instance_id=instance_name, + instance_id=pod_name, region=instance_offer.region, price=instance_offer.price, username="root", @@ -470,27 +372,30 @@ def update_provisioning_data( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): - call_api_method( - self.api.delete_namespaced_service, - expected=404, - name=_get_pod_service_name(instance_id), - namespace=self.namespace, - body=client.V1DeleteOptions(), - ) - call_api_method( - self.api.delete_namespaced_pod, - expected=404, - name=instance_id, - namespace=self.namespace, - body=client.V1DeleteOptions(), - ) - call_api_method( - self.api.delete_namespaced_secret, - expected=404, - name=_get_registry_auth_secret_name(instance_id), - namespace=self.namespace, - body=client.V1DeleteOptions(), - ) + api = self.api + namespace = self.namespace + deleted = [ + try_delete_object_if_exists( + api.delete_namespaced_service, + namespace=namespace, + name=_get_pod_service_name(instance_id), + description="pod service", + ), + try_delete_object_if_exists( + api.delete_namespaced_pod, + namespace=namespace, + name=instance_id, + description="pod", + ), + try_delete_object_if_exists( + api.delete_namespaced_secret, + namespace=namespace, + name=_get_registry_auth_secret_name(instance_id), + description="registry auth secret", + ), + ] + if not all(deleted): + raise ComputeError("Not all objects were deleted, check logs") def create_gateway( self, @@ -1052,6 +957,258 @@ def _get_jump_pod_commands(authorized_keys: list[str]) -> list[str]: return commands +def _create_registry_auth_secret( + api: client.CoreV1Api, + namespace: str, + secret_name: str, + image_name: str, + username: str, + password: str, +) -> None: + dockerconfigjson = build_dockerconfigjson( + image_name=image_name, + username=username, + password=password, + ) + secret = client.V1Secret( + metadata=client.V1ObjectMeta(name=secret_name), + type="kubernetes.io/dockerconfigjson", + string_data={".dockerconfigjson": dockerconfigjson}, + ) + api.create_namespaced_secret( + namespace=namespace, + body=secret, + ) + + +def _create_job_pod( + api: client.CoreV1Api, + namespace: str, + pod_name: str, + registry_auth_secret_name: Optional[str], + run_spec: RunSpec, + job_spec: JobSpec, + volumes: list[Volume], + authorized_keys: list[str], +) -> None: + resources_requests: dict[str, str] = {} + resources_limits: dict[str, str] = {} + node_affinity: Optional[client.V1NodeAffinity] = None + tolerations: list[client.V1Toleration] = [] + volumes_: list[client.V1Volume] = [] + volume_mounts: list[client.V1VolumeMount] = [] + + resources_spec = job_spec.requirements.resources + assert isinstance(resources_spec.cpu, CPUSpec) + if (cpu_min := resources_spec.cpu.count.min) is not None: + resources_requests["cpu"] = str(cpu_min) + if (cpu_max := resources_spec.cpu.count.max) is not None: + resources_limits["cpu"] = str(cpu_max) + if (gpu_spec := resources_spec.gpu) is not None: + if (gpu_request := get_gpu_request_from_gpu_spec(gpu_spec)) > 0: + gpu_resource, node_affinity, node_taint = _get_pod_spec_parameters_for_gpu( + api, gpu_spec + ) + logger.debug("Requesting GPU resource: %s=%d", gpu_resource, gpu_request) + resources_requests[gpu_resource] = str(gpu_request) + # Limit must be set (GPU resources cannot be overcommitted) + # and must be equal to request. + resources_limits[gpu_resource] = str(gpu_request) + # It should be NoSchedule, but we also add NoExecute toleration just in case. + for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]: + tolerations.append( + client.V1Toleration(key=node_taint, operator=Operator.EXISTS, effect=effect) + ) + if (memory_min := resources_spec.memory.min) is not None: + resources_requests["memory"] = format_memory(memory_min) + if (memory_max := resources_spec.memory.max) is not None: + resources_limits["memory"] = format_memory(memory_max) + if (disk_spec := resources_spec.disk) is not None: + if (disk_min := disk_spec.size.min) is not None: + resources_requests["ephemeral-storage"] = format_memory(disk_min) + if (disk_max := disk_spec.size.max) is not None: + resources_limits["ephemeral-storage"] = format_memory(disk_max) + if (shm_size := resources_spec.shm_size) is not None: + shm_volume_name = "dev-shm" + volumes_.append( + client.V1Volume( + name=shm_volume_name, + empty_dir=client.V1EmptyDirVolumeSource( + medium="Memory", + size_limit=format_memory(shm_size), + ), + ) + ) + volume_mounts.append( + client.V1VolumeMount( + name=shm_volume_name, + mount_path="/dev/shm", + ) + ) + + volume_name_path_map: dict[str, str] = {} + mount_points = job_spec.volumes + if mount_points is None: + # Legacy JobSpec without volumes + mount_points = run_spec.configuration.volumes + for mount_point in mount_points: + if isinstance(mount_point, VolumeMountPoint): + if isinstance(mount_point.name, str): + volume_names = [mount_point.name] + else: + volume_names = mount_point.name + for volume_name in volume_names: + volume_name_path_map[volume_name] = mount_point.path + elif isinstance(mount_point, InstanceMountPoint): + # "Must be a DNS_LABEL and unique within the pod" + volume_name = generate_unique_name( + prefix="host-path", max_length=OBJECT_NAME_MAX_LENGTH + ) + volumes_.append( + client.V1Volume( + name=volume_name, + host_path=client.V1HostPathVolumeSource( + path=mount_point.instance_path, + type="DirectoryOrCreate", + ), + ), + ) + volume_mounts.append( + client.V1VolumeMount( + name=volume_name, + mount_path=mount_point.path, + ) + ) + else: + assert False, f"unexpected mount point: {mount_point!r}" + for volume in volumes: + assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + pvc_name = volume.volume_id + assert pvc_name is not None, f"missing PVC name: {volume!r}" + mount_path = volume_name_path_map.get(volume.name) + assert mount_path is not None, f"missing mount path: {volume!r}" + volume_name = generate_unique_name(prefix="pvc", max_length=OBJECT_NAME_MAX_LENGTH) + volumes_.append( + client.V1Volume( + name=volume_name, + persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( + claim_name=pvc_name, + ), + ), + ) + volume_mounts.append( + client.V1VolumeMount( + name=volume_name, + mount_path=mount_path, + read_only=volume.configuration.read_only, + recursive_read_only="IfPossible" if volume.configuration.read_only else None, + ) + ) + + pod = client.V1Pod( + metadata=client.V1ObjectMeta( + name=pod_name, + labels={"app.kubernetes.io/name": pod_name}, + ), + spec=client.V1PodSpec( + containers=[ + client.V1Container( + name=f"{pod_name}-container", + image=job_spec.image_name, + command=["/bin/sh"], + args=["-c", " && ".join(get_docker_commands(authorized_keys))], + ports=[ + client.V1ContainerPort( + container_port=DSTACK_RUNNER_SSH_PORT, + ) + ], + security_context=client.V1SecurityContext( + run_as_user=0, + run_as_group=0, + privileged=job_spec.privileged, + capabilities=client.V1Capabilities( + add=[ + # Allow to increase hard resource limits, see getrlimit(2) + "SYS_RESOURCE", + ], + ), + ), + resources=client.V1ResourceRequirements( + requests=resources_requests, + limits=resources_limits, + ), + volume_mounts=volume_mounts, + ) + ], + image_pull_secrets=( + [client.V1LocalObjectReference(name=registry_auth_secret_name)] + if registry_auth_secret_name is not None + else None + ), + affinity=client.V1Affinity( + node_affinity=node_affinity, + ), + tolerations=tolerations, + volumes=volumes_, + ), + ) + api.create_namespaced_pod( + namespace=namespace, + body=pod, + ) + + +def _wait_for_pod_scheduled_or_finished( + api: client.CoreV1Api, + namespace: str, + pod_name: str, + timeout_seconds: int, +) -> tuple[bool, Optional[PodPhase]]: + # We wait until container_statuses is populated rather than checking the PodScheduled + # condition or spec.node_name. container_statuses is set by the kubelet only after it + # has accepted the bound pod and started creating containers, so it implies both that + # the scheduler confirmed capacity and that the assigned node is actually Ready and + # working on the pod. + pod_phase: Optional[PodPhase] = None + with watch_events( + api.list_namespaced_pod, + namespace=namespace, + field_selector=f"metadata.name={pod_name}", + timeout_seconds=timeout_seconds, + ) as event_iter: + for _, pod in event_iter: + pod_status = pod.status + if pod_status is None: + continue + if pod_status.phase is not None: + pod_phase = PodPhase(pod_status.phase) + else: + pod_phase = None + if pod_status.container_statuses is not None: + return True, pod_phase + if pod_phase is not None and pod_phase is not PodPhase.PENDING: + return True, pod_phase + return False, pod_phase + + +def _get_unscheduled_pod_reason_message( + api: client.CoreV1Api, + namespace: str, + pod_name: str, +) -> tuple[Optional[str], Optional[str]]: + pod = call_api_method( + api.read_namespaced_pod, + expected=404, + name=pod_name, + namespace=namespace, + ) + if pod is not None and pod.status is not None and pod.status.conditions: + for cond in pod.status.conditions: + if cond.type == "PodScheduled" and cond.status == "False": + return cond.reason, cond.message + return None, None + + def _wait_for_load_balancer_address( api: client.CoreV1Api, namespace: str, diff --git a/src/dstack/_internal/core/backends/kubernetes/utils.py b/src/dstack/_internal/core/backends/kubernetes/utils.py index 905d0b1b7..b1ea5c7cb 100644 --- a/src/dstack/_internal/core/backends/kubernetes/utils.py +++ b/src/dstack/_internal/core/backends/kubernetes/utils.py @@ -1,16 +1,25 @@ -from typing import Annotated, Callable, Optional, TypeVar, Union +from collections.abc import Generator +from contextlib import contextmanager +from typing import Annotated, Any, Callable, Generic, Literal, Optional, Protocol, TypeVar, Union import yaml -from kubernetes.client import CoreV1Api +from kubernetes.client import CoreV1Api, V1Status from kubernetes.client.exceptions import ApiException from kubernetes.config import ( # XXX: This function is missing in the stubs package new_client_from_config_dict, # pyright: ignore[reportAttributeAccessIssue] ) + +# XXX: The watch module is missing in the stubs package +from kubernetes.watch import Watch # pyright: ignore[reportMissingImports] from pydantic import Field -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypedDict +from urllib3.exceptions import HTTPError from dstack._internal.core.models.common import CoreModel +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) T = TypeVar("T") P = ParamSpec("P") @@ -97,3 +106,102 @@ def call_api_method( if e.status not in expected: raise return None + + +class NamespacedNameMethod(Protocol): + def __call__(self, name: str, namespace: str) -> Any: ... + + +def try_delete_object_if_exists( + method: NamespacedNameMethod, + *, + namespace: str, + name: str, + description: str, + should_delete_manually_if_failed: bool = False, +) -> bool: + try: + call_api_method( + method, + expected=404, + namespace=namespace, + name=name, + ) + except (HTTPError, ApiException) as e: + if should_delete_manually_if_failed: + logger.exception( + "Failed to delete %s %s in namespace %s. Please delete it manually", + description, + name, + namespace, + ) + else: + logger.warning( + "Failed to delete %s %s in namespace %s: %s: %s", + description, + name, + namespace, + e.__class__.__name__, + e, + ) + return False + return True + + +class ObjectList(Protocol[T]): + items: list[T] + + +@contextmanager +def watch_events( + method: Callable[P, ObjectList[T]], *args: P.args, **kwargs: P.kwargs +) -> Generator[Generator[tuple[str, T], None, None], None, None]: + watch = Watch() + gen = _watch_events_gen(watch.stream(method, *args, **kwargs)) + try: + yield gen + finally: + gen.close() + watch.stop() + + +class _StateEventDict(TypedDict, Generic[T]): + type: Literal["ADDED", "MODIFIED", "DELETED"] + object: T + + +class _BookmarkEventDict(TypedDict, Generic[T]): + type: Literal["BOOKMARK"] + # The object is a minimal instance of the watched resource's type -- same kind and apiVersion, + # but only metadata.resourceVersion is populated. Everything else is empty or zero-valued. + object: T + + +class _ErrorEventDict(TypedDict): + type: Literal["ERROR"] + object: V1Status + + +def _watch_events_gen( + gen: Generator[Union[_StateEventDict[T], _BookmarkEventDict[T], _ErrorEventDict], None, None], +) -> Generator[tuple[str, T], None, None]: + try: + for event in gen: + match event["type"]: + case "ADDED" | "MODIFIED" | "DELETED": + yield event["type"], event["object"] + case "BOOKMARK": + pass + case "ERROR": + status = event["object"] + logger.warning( + "Got ERROR event (status=%s reason=%s code=%s): %s", + status.status, + status.reason, + status.code, + status.message, + ) + case _: + logger.warning("Got unexpected event: %s", event) + finally: + gen.close()