diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 4b861ea1a..8d57f73d6 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -318,7 +318,7 @@ def run_job( user=run.user, ssh_keys=[SSHKey(public=project_ssh_public_key.strip())], volumes=volumes, - reservation=run.run_spec.configuration.reservation, + reservation=job.job_spec.requirements.reservation, tags=run.run_spec.merged_profile.tags, ) instance_offer = instance_offer.copy() diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index cbed73c5b..3fb92d921 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -122,6 +122,10 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: replica_group_excludes["nvcc"] = True if all(g.privileged is None for g in replicas): replica_group_excludes["privileged"] = True + if all(g.spot_policy is None for g in replicas): + replica_group_excludes["spot_policy"] = True + if all(g.reservation is None for g in replicas): + replica_group_excludes["reservation"] = True if replica_group_excludes: configuration_excludes["replicas"] = {"__all__": replica_group_excludes} diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index bacf34e5a..a4cdd99af 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -25,6 +25,7 @@ from dstack._internal.core.models.profiles import ( ProfileParams, ProfileParamsConfig, + SpotPolicy, parse_duration, parse_off_duration, ) @@ -836,6 +837,24 @@ class ReplicaGroup(CoreModel): ResourcesSpec, Field(description="The resources requirements for replicas in this group"), ] = ResourcesSpec() + spot_policy: Annotated[ + Optional[SpotPolicy], + Field( + description=( + "The policy for provisioning spot or on-demand instances for replicas in this group:" + f" {list_enum_values_for_annotation(SpotPolicy)}" + ) + ), + ] = None + reservation: Annotated[ + Optional[str], + Field( + description=( + "The existing reservation to use for replicas in this group." + " Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations" + ) + ), + ] = None commands: Annotated[ CommandsList, @@ -1144,7 +1163,7 @@ def validate_top_level_properties_with_replica_groups(cls, values): @root_validator() def validate_no_mixed_service_and_group_container_fields(cls, values): """ - When replicas is a list (image, docker, privileged) may be set + When replicas is a list, certain fields may be set at the service level OR in replica groups, never both. Mixing is rejected — including partial mixing, where only some groups set a field the service also sets — because it leaves precedence ambiguous. @@ -1179,6 +1198,16 @@ def validate_no_mixed_service_and_group_container_fields(cls, values): values.get("nvcc") is True, lambda g: g.nvcc is not None, ), + ( + "spot_policy", + values.get("spot_policy") is not None, + lambda g: g.spot_policy is not None, + ), + ( + "reservation", + values.get("reservation") is not None, + lambda g: g.reservation is not None, + ), ] for field, service_set, group_set in checks: diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 7eca97a91..a1137b9ce 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -127,6 +127,9 @@ def _default_max_duration(self) -> Optional[int]: def _spot_policy(self) -> SpotPolicy: pass + def _reservation(self) -> Optional[str]: + return self.run_spec.merged_profile.reservation + @abstractmethod def _ports(self) -> List[PortMapping]: pass @@ -334,7 +337,7 @@ def _requirements(self, jobs_per_replica: int) -> Requirements: resources=resources, max_price=self.run_spec.merged_profile.max_price, spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), - reservation=self.run_spec.merged_profile.reservation, + reservation=self._reservation(), multinode=jobs_per_replica > 1, backend_options=self.run_spec.merged_profile.backend_options, ) diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py index 7968c8ad7..45bc4c8f7 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/service.py +++ b/src/dstack/_internal/server/services/jobs/configurators/service.py @@ -113,7 +113,16 @@ def _default_max_duration(self) -> Optional[int]: return None def _spot_policy(self) -> SpotPolicy: + group = self._current_replica_group() + if group is not None and group.spot_policy is not None: + return group.spot_policy return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND + def _reservation(self) -> Optional[str]: + group = self._current_replica_group() + if group is not None and group.reservation is not None: + return group.reservation + return super()._reservation() + def _ports(self) -> List[PortMapping]: return [] diff --git a/src/dstack/_internal/server/services/offers.py b/src/dstack/_internal/server/services/offers.py index 569bc0724..6fd739f13 100644 --- a/src/dstack/_internal/server/services/offers.py +++ b/src/dstack/_internal/server/services/offers.py @@ -68,7 +68,7 @@ async def get_offers_by_requirements( backend_types = BACKENDS_WITH_INSTANCE_VOLUMES_SUPPORT backend_types = [b for b in backend_types if b in BACKENDS_WITH_INSTANCE_VOLUMES_SUPPORT] - if profile.reservation is not None: + if requirements.reservation is not None: if backend_types is None: backend_types = BACKENDS_WITH_RESERVATION_SUPPORT backend_types = [b for b in backend_types if b in BACKENDS_WITH_RESERVATION_SUPPORT] diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py index 5027e2973..a602101a0 100644 --- a/src/tests/_internal/core/models/test_configurations.py +++ b/src/tests/_internal/core/models/test_configurations.py @@ -142,6 +142,46 @@ def test_replica_group_router_forbids_service_level_router(self): ): parse_run_configuration(conf) + def test_spot_policy_set_at_both_service_and_group_rejected(self): + with pytest.raises( + ConfigurationError, + match="`spot_policy` is set at both", + ): + parse_run_configuration( + { + "type": "service", + "port": 8000, + "spot_policy": "spot", + "replicas": [ + { + "count": 1, + "commands": ["x"], + "spot_policy": "on-demand", + }, + ], + } + ) + + def test_reservation_set_at_both_service_and_group_rejected(self): + with pytest.raises( + ConfigurationError, + match="`reservation` is set at both", + ): + parse_run_configuration( + { + "type": "service", + "port": 8000, + "image": "x", + "reservation": "svc-res", + "replicas": [ + { + "count": 1, + "reservation": "grp-res", + }, + ], + } + ) + @pytest.mark.parametrize("shell", [None, "sh", "bash", "/usr/bin/zsh"]) def test_shell_valid(self, shell: Optional[str]): conf = { diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index a46ec93f8..a6a03ff0b 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1676,6 +1676,70 @@ def offers_by_requirements(requirements: Requirements): assert gpu_job_plan["offers"][0]["instance"]["resources"]["gpus"] != [] assert cpu_job_plan["offers"][0]["instance"]["resources"]["gpus"] == [] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_service_reservation_group_filters_backends_by_reservation_support( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None) + await create_fleet(session=session, project=project, spec=fleet_spec) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=ServiceConfiguration( + port=8080, + gateway=False, + image="nginx", + replicas=[ + ReplicaGroup( + name="reserved-group", + count=Range[int](min=1, max=1), + reservation="my-reservation-id", + ), + ReplicaGroup( + count=Range[int](min=1, max=1), + name="unreserved-group", + ), + ], + ), + ) + body = {"run_spec": json.loads(run_spec.json())} + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + aws_backend_mock = Mock() + aws_backend_mock.TYPE = BackendType.AWS + aws_backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, price=2) + ] + verda_backend_mock = Mock() + verda_backend_mock.TYPE = BackendType.VERDA + verda_backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.VERDA, price=1) + ] + m.return_value = [aws_backend_mock, verda_backend_mock] + + response = await client.post( + f"/api/project/{project.name}/runs/get_plan", + headers=get_auth_headers(user.token), + json=body, + ) + assert response.status_code == 200, response.json() + reserved_job_plan, unreserved_job_plan = response.json()["job_plans"] + + # Verda offer not included for `reserved-group`, since Verda does not support reservations + assert reserved_job_plan["offers"][0]["backend"] == "aws" + assert len(reserved_job_plan["offers"]) == 1 + + assert unreserved_job_plan["offers"][0]["backend"] == "verda" + assert unreserved_job_plan["offers"][1]["backend"] == "aws" + assert len(unreserved_job_plan["offers"]) == 2 + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_returns_run_plan_docker_true( diff --git a/src/tests/_internal/server/services/jobs/configurators/test_service.py b/src/tests/_internal/server/services/jobs/configurators/test_service.py index 9d5c9bdf7..a8410fcfa 100644 --- a/src/tests/_internal/server/services/jobs/configurators/test_service.py +++ b/src/tests/_internal/server/services/jobs/configurators/test_service.py @@ -10,6 +10,7 @@ ReplicaGroup, ServiceConfiguration, ) +from dstack._internal.core.models.profiles import SpotPolicy from dstack._internal.core.models.resources import Range from dstack._internal.core.models.services import OpenAIChatModel from dstack._internal.server.services.docker import ImageConfig @@ -118,7 +119,7 @@ def _make_run_spec(replicas, **service_kwargs): @pytest.mark.usefixtures("image_config_mock") class TestPerGroupOverrides: """Verifies that ServiceJobConfigurator picks up per-replica-group - image-source fields (image, docker, python, nvcc, privileged).""" + image-source fields (image, docker, python, nvcc, privileged, etc).""" async def test_image_name_uses_group_image(self): run_spec = _make_run_spec( @@ -331,3 +332,103 @@ async def test_user_does_not_lookup_for_group_docker(self, monkeypatch: pytest.M configurator = ServiceJobConfigurator(run_spec, replica_group_name="a") await configurator._user() mock_get_image_config.assert_not_called() + + async def test_spot_policy_uses_group_value(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="a", + count=Range(min=1, max=1), + commands=["x"], + spot_policy=SpotPolicy.SPOT, + ) + ], + ) + configurator = ServiceJobConfigurator(run_spec, replica_group_name="a") + assert configurator._spot_policy() == SpotPolicy.SPOT + + async def test_spot_policy_defaults_to_ondemand_when_group_unset(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="a", + count=Range(min=1, max=1), + commands=["x"], + ) + ], + ) + configurator = ServiceJobConfigurator(run_spec, replica_group_name="a") + assert configurator._spot_policy() == SpotPolicy.ONDEMAND + + async def test_different_groups_different_spot_policies(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="spot", + count=Range(min=1, max=1), + commands=["x"], + spot_policy=SpotPolicy.SPOT, + ), + ReplicaGroup( + name="od", + count=Range(min=1, max=1), + commands=["y"], + spot_policy=SpotPolicy.ONDEMAND, + ), + ], + ) + assert ( + ServiceJobConfigurator(run_spec, replica_group_name="spot")._spot_policy() + == SpotPolicy.SPOT + ) + assert ( + ServiceJobConfigurator(run_spec, replica_group_name="od")._spot_policy() + == SpotPolicy.ONDEMAND + ) + + async def test_reservation_uses_group_value(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="a", + count=Range(min=1, max=1), + commands=["x"], + reservation="my-reservation", + ) + ], + ) + configurator = ServiceJobConfigurator(run_spec, replica_group_name="a") + assert configurator._reservation() == "my-reservation" + + async def test_reservation_defaults_to_none_when_group_unset(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="a", + count=Range(min=1, max=1), + commands=["x"], + ) + ], + ) + configurator = ServiceJobConfigurator(run_spec, replica_group_name="a") + assert configurator._reservation() is None + + async def test_different_groups_different_reservations(self): + run_spec = _make_run_spec( + replicas=[ + ReplicaGroup( + name="a", + count=Range(min=1, max=1), + commands=["x"], + reservation="res-a", + ), + ReplicaGroup( + name="b", + count=Range(min=1, max=1), + commands=["y"], + reservation="res-b", + ), + ], + ) + assert ServiceJobConfigurator(run_spec, replica_group_name="a")._reservation() == "res-a" + assert ServiceJobConfigurator(run_spec, replica_group_name="b")._reservation() == "res-b"