Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def run_job(
user=run.user,
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
volumes=volumes,
reservation=run.run_spec.configuration.reservation,
reservation=job.job_spec.requirements.reservation,
tags=run.run_spec.merged_profile.tags,
)
instance_offer = instance_offer.copy()
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
replica_group_excludes["nvcc"] = True
if all(g.privileged is None for g in replicas):
replica_group_excludes["privileged"] = True
if all(g.spot_policy is None for g in replicas):
replica_group_excludes["spot_policy"] = True
if all(g.reservation is None for g in replicas):
replica_group_excludes["reservation"] = True
if replica_group_excludes:
configuration_excludes["replicas"] = {"__all__": replica_group_excludes}

Expand Down
31 changes: 30 additions & 1 deletion src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dstack._internal.core.models.profiles import (
ProfileParams,
ProfileParamsConfig,
SpotPolicy,
parse_duration,
parse_off_duration,
)
Expand Down Expand Up @@ -836,6 +837,24 @@ class ReplicaGroup(CoreModel):
ResourcesSpec,
Field(description="The resources requirements for replicas in this group"),
] = ResourcesSpec()
spot_policy: Annotated[
Optional[SpotPolicy],
Field(
description=(
"The policy for provisioning spot or on-demand instances for replicas in this group:"
f" {list_enum_values_for_annotation(SpotPolicy)}"
)
),
] = None
reservation: Annotated[
Optional[str],
Field(
description=(
"The existing reservation to use for replicas in this group."
" Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations"
)
),
] = None

commands: Annotated[
CommandsList,
Expand Down Expand Up @@ -1144,7 +1163,7 @@ def validate_top_level_properties_with_replica_groups(cls, values):
@root_validator()
def validate_no_mixed_service_and_group_container_fields(cls, values):
"""
When replicas is a list (image, docker, privileged) may be set
When replicas is a list, certain fields may be set
at the service level OR in replica groups, never both. Mixing is
rejected — including partial mixing, where only some groups set a
field the service also sets — because it leaves precedence ambiguous.
Expand Down Expand Up @@ -1179,6 +1198,16 @@ def validate_no_mixed_service_and_group_container_fields(cls, values):
values.get("nvcc") is True,
lambda g: g.nvcc is not None,
),
(
"spot_policy",
values.get("spot_policy") is not None,
lambda g: g.spot_policy is not None,
),
(
"reservation",
values.get("reservation") is not None,
lambda g: g.reservation is not None,
),
]

for field, service_set, group_set in checks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def _default_max_duration(self) -> Optional[int]:
def _spot_policy(self) -> SpotPolicy:
pass

def _reservation(self) -> Optional[str]:
return self.run_spec.merged_profile.reservation

@abstractmethod
def _ports(self) -> List[PortMapping]:
pass
Expand Down Expand Up @@ -334,7 +337,7 @@ def _requirements(self, jobs_per_replica: int) -> Requirements:
resources=resources,
max_price=self.run_spec.merged_profile.max_price,
spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
reservation=self.run_spec.merged_profile.reservation,
reservation=self._reservation(),
multinode=jobs_per_replica > 1,
backend_options=self.run_spec.merged_profile.backend_options,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,16 @@ def _default_max_duration(self) -> Optional[int]:
return None

def _spot_policy(self) -> SpotPolicy:
group = self._current_replica_group()
if group is not None and group.spot_policy is not None:
return group.spot_policy
return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND

def _reservation(self) -> Optional[str]:
group = self._current_replica_group()
if group is not None and group.reservation is not None:
return group.reservation
return super()._reservation()

def _ports(self) -> List[PortMapping]:
return []
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/offers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def get_offers_by_requirements(
backend_types = BACKENDS_WITH_INSTANCE_VOLUMES_SUPPORT
backend_types = [b for b in backend_types if b in BACKENDS_WITH_INSTANCE_VOLUMES_SUPPORT]

if profile.reservation is not None:
if requirements.reservation is not None:
if backend_types is None:
backend_types = BACKENDS_WITH_RESERVATION_SUPPORT
backend_types = [b for b in backend_types if b in BACKENDS_WITH_RESERVATION_SUPPORT]
Expand Down
40 changes: 40 additions & 0 deletions src/tests/_internal/core/models/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,46 @@ def test_replica_group_router_forbids_service_level_router(self):
):
parse_run_configuration(conf)

def test_spot_policy_set_at_both_service_and_group_rejected(self):
with pytest.raises(
ConfigurationError,
match="`spot_policy` is set at both",
):
parse_run_configuration(
{
"type": "service",
"port": 8000,
"spot_policy": "spot",
"replicas": [
{
"count": 1,
"commands": ["x"],
"spot_policy": "on-demand",
},
],
}
)

def test_reservation_set_at_both_service_and_group_rejected(self):
with pytest.raises(
ConfigurationError,
match="`reservation` is set at both",
):
parse_run_configuration(
{
"type": "service",
"port": 8000,
"image": "x",
"reservation": "svc-res",
"replicas": [
{
"count": 1,
"reservation": "grp-res",
},
],
}
)

@pytest.mark.parametrize("shell", [None, "sh", "bash", "/usr/bin/zsh"])
def test_shell_valid(self, shell: Optional[str]):
conf = {
Expand Down
64 changes: 64 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,70 @@ def offers_by_requirements(requirements: Requirements):
assert gpu_job_plan["offers"][0]["instance"]["resources"]["gpus"] != []
assert cpu_job_plan["offers"][0]["instance"]["resources"]["gpus"] == []

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_service_reservation_group_filters_backends_by_reservation_support(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
fleet_spec = get_fleet_spec()
fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None)
await create_fleet(session=session, project=project, spec=fleet_spec)
repo = await create_repo(session=session, project_id=project.id)
run_spec = get_run_spec(
repo_id=repo.name,
configuration=ServiceConfiguration(
port=8080,
gateway=False,
image="nginx",
replicas=[
ReplicaGroup(
name="reserved-group",
count=Range[int](min=1, max=1),
reservation="my-reservation-id",
),
ReplicaGroup(
count=Range[int](min=1, max=1),
name="unreserved-group",
),
],
),
)
body = {"run_spec": json.loads(run_spec.json())}

with patch("dstack._internal.server.services.backends.get_project_backends") as m:
aws_backend_mock = Mock()
aws_backend_mock.TYPE = BackendType.AWS
aws_backend_mock.compute.return_value.get_offers.return_value = [
get_instance_offer_with_availability(backend=BackendType.AWS, price=2)
]
verda_backend_mock = Mock()
verda_backend_mock.TYPE = BackendType.VERDA
verda_backend_mock.compute.return_value.get_offers.return_value = [
get_instance_offer_with_availability(backend=BackendType.VERDA, price=1)
]
m.return_value = [aws_backend_mock, verda_backend_mock]

response = await client.post(
f"/api/project/{project.name}/runs/get_plan",
headers=get_auth_headers(user.token),
json=body,
)
assert response.status_code == 200, response.json()
reserved_job_plan, unreserved_job_plan = response.json()["job_plans"]

# Verda offer not included for `reserved-group`, since Verda does not support reservations
assert reserved_job_plan["offers"][0]["backend"] == "aws"
assert len(reserved_job_plan["offers"]) == 1

assert unreserved_job_plan["offers"][0]["backend"] == "verda"
assert unreserved_job_plan["offers"][1]["backend"] == "aws"
assert len(unreserved_job_plan["offers"]) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_run_plan_docker_true(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ReplicaGroup,
ServiceConfiguration,
)
from dstack._internal.core.models.profiles import SpotPolicy
from dstack._internal.core.models.resources import Range
from dstack._internal.core.models.services import OpenAIChatModel
from dstack._internal.server.services.docker import ImageConfig
Expand Down Expand Up @@ -118,7 +119,7 @@ def _make_run_spec(replicas, **service_kwargs):
@pytest.mark.usefixtures("image_config_mock")
class TestPerGroupOverrides:
"""Verifies that ServiceJobConfigurator picks up per-replica-group
image-source fields (image, docker, python, nvcc, privileged)."""
image-source fields (image, docker, python, nvcc, privileged, etc)."""

async def test_image_name_uses_group_image(self):
run_spec = _make_run_spec(
Expand Down Expand Up @@ -331,3 +332,103 @@ async def test_user_does_not_lookup_for_group_docker(self, monkeypatch: pytest.M
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
await configurator._user()
mock_get_image_config.assert_not_called()

async def test_spot_policy_uses_group_value(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="a",
count=Range(min=1, max=1),
commands=["x"],
spot_policy=SpotPolicy.SPOT,
)
],
)
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
assert configurator._spot_policy() == SpotPolicy.SPOT

async def test_spot_policy_defaults_to_ondemand_when_group_unset(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="a",
count=Range(min=1, max=1),
commands=["x"],
)
],
)
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
assert configurator._spot_policy() == SpotPolicy.ONDEMAND

async def test_different_groups_different_spot_policies(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="spot",
count=Range(min=1, max=1),
commands=["x"],
spot_policy=SpotPolicy.SPOT,
),
ReplicaGroup(
name="od",
count=Range(min=1, max=1),
commands=["y"],
spot_policy=SpotPolicy.ONDEMAND,
),
],
)
assert (
ServiceJobConfigurator(run_spec, replica_group_name="spot")._spot_policy()
== SpotPolicy.SPOT
)
assert (
ServiceJobConfigurator(run_spec, replica_group_name="od")._spot_policy()
== SpotPolicy.ONDEMAND
)

async def test_reservation_uses_group_value(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="a",
count=Range(min=1, max=1),
commands=["x"],
reservation="my-reservation",
)
],
)
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
assert configurator._reservation() == "my-reservation"

async def test_reservation_defaults_to_none_when_group_unset(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="a",
count=Range(min=1, max=1),
commands=["x"],
)
],
)
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
assert configurator._reservation() is None

async def test_different_groups_different_reservations(self):
run_spec = _make_run_spec(
replicas=[
ReplicaGroup(
name="a",
count=Range(min=1, max=1),
commands=["x"],
reservation="res-a",
),
ReplicaGroup(
name="b",
count=Range(min=1, max=1),
commands=["y"],
reservation="res-b",
),
],
)
assert ServiceJobConfigurator(run_spec, replica_group_name="a")._reservation() == "res-a"
assert ServiceJobConfigurator(run_spec, replica_group_name="b")._reservation() == "res-b"
Loading