Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"python-multipart>=0.0.16",
"filelock",
"psutil",
"gpuhunt==0.1.24",
"gpuhunt==0.1.25",
"argcomplete>=3.5.0",
"ignore-python>=0.2.0",
"orjson",
Expand Down
26 changes: 26 additions & 0 deletions src/dstack/_internal/core/backends/jarvislabs/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,32 @@ def add_ssh_key(self, public_key: str, key_name: str) -> None:
)
_raise_if_unsuccessful(resp, "Failed to add JarvisLabs SSH key")

def create_ssh_key(self, public_key: str, key_name: str) -> str:
self.add_ssh_key(public_key=public_key, key_name=key_name)
key_id = self.find_ssh_key_id(public_key=public_key, key_name=key_name)
if key_id is None:
raise BackendError("Failed to find created JarvisLabs SSH key")
return key_id

def find_ssh_key_id(self, public_key: str, key_name: str) -> Optional[str]:
normalized_key = _normalize_public_key(public_key)
for ssh_key in self.list_ssh_keys():
if str(ssh_key.get("key_name", "")) != key_name:
continue
if _normalize_public_key(str(ssh_key.get("ssh_key", ""))) != normalized_key:
continue
key_id = ssh_key.get("key_id")
if key_id is not None:
return str(key_id)
return None

def delete_ssh_key(self, key_id: str) -> None:
try:
resp = self._make_request("DELETE", f"ssh/{key_id}")
except JarvisLabsNotFoundError:
return
_raise_if_unsuccessful(resp, "Failed to delete JarvisLabs SSH key")

def add_ssh_key_if_needed(self, public_key: str) -> None:
normalized_key = _normalize_public_key(public_key)
for ssh_key in self.list_ssh_keys():
Expand Down
69 changes: 63 additions & 6 deletions src/dstack/_internal/core/backends/jarvislabs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import subprocess
import tempfile
from collections.abc import Iterable
from typing import List, Optional
from typing import List, Optional, cast

import gpuhunt
from gpuhunt.providers.jarvislabs import JarvisLabsProvider
from typing_extensions import NotRequired, TypedDict

from dstack._internal.core.backends.base.backend import Compute
from dstack._internal.core.backends.base.compute import (
Expand All @@ -25,6 +26,7 @@
from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig
from dstack._internal.core.errors import ProvisioningError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
Expand All @@ -47,6 +49,22 @@
SSH_LAUNCH_TIMEOUT_SECONDS = 60


class JarvisLabsOfferBackendData(TypedDict):
# Set by gpuhunt when normalized GPU identity differs from the JarvisLabs VM
# create token, e.g. "RTX-PRO6000" normalized to "RTXPRO6000".
gpu_type: NotRequired[str]


class JarvisLabsInstanceBackendData(CoreModel):
ssh_key_ids: Optional[List[str]] = None

@classmethod
def load(cls, raw: Optional[str]) -> "JarvisLabsInstanceBackendData":
if raw is None:
return cls()
return cls.__response__.parse_raw(raw)


class JarvisLabsCompute(
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
Expand Down Expand Up @@ -85,9 +103,19 @@ def create_instance(
instance_name = generate_unique_instance_name(
instance_config, max_length=MAX_INSTANCE_NAME_LEN
)
self.api_client.add_ssh_key_if_needed(instance_config.ssh_keys[0].public)
ssh_key_ids: List[str] = []
instance_id = None
try:
# TODO: JarvisLabs has a default 10 SSH key limit. Consider project-level
# key reuse if per-instance keys become a bottleneck.
for idx, ssh_public_key in enumerate(instance_config.get_public_keys()):
ssh_key_ids.append(
_create_ssh_key(
client=self.api_client,
name=f"{instance_name}-{idx}.key",
public_key=ssh_public_key,
)
)
if instance_offer.instance.resources.gpus:
instance_id = self.api_client.create_gpu_vm(
gpu_type=_get_jarvislabs_gpu_type(instance_offer),
Expand Down Expand Up @@ -117,6 +145,13 @@ def create_instance(
logger.exception(
"Could not destroy failed JarvisLabs instance %s", instance_id
)
try:
_delete_ssh_keys(self.api_client, ssh_key_ids)
except Exception:
logger.exception(
"Could not delete JarvisLabs SSH keys %s after provisioning failure",
ssh_key_ids,
)
raise
return JobProvisioningData(
backend=instance_offer.backend,
Expand All @@ -130,7 +165,7 @@ def create_instance(
ssh_port=22,
dockerized=True,
ssh_proxy=None,
backend_data=None,
backend_data=JarvisLabsInstanceBackendData(ssh_key_ids=ssh_key_ids).json(),
)

def update_provisioning_data(
Expand Down Expand Up @@ -172,17 +207,39 @@ def update_provisioning_data(
def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
):
backend_data_parsed = JarvisLabsInstanceBackendData.load(backend_data)
self.api_client.destroy_instance(machine_id=instance_id, region=region)
_delete_ssh_keys(self.api_client, backend_data_parsed.ssh_key_ids)


def _create_ssh_key(client: JarvisLabsAPIClient, name: str, public_key: str) -> str:
return client.create_ssh_key(public_key=public_key, key_name=name)


def _delete_ssh_keys(client: JarvisLabsAPIClient, ssh_key_ids: Optional[List[str]]) -> None:
if not ssh_key_ids:
return
for ssh_key_id in ssh_key_ids:
client.delete_ssh_key(ssh_key_id)


def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> str:
gpu_type = _get_jarvislabs_gpu_type_from_backend_data(instance_offer.backend_data)
if gpu_type is not None:
return gpu_type

gpu = instance_offer.instance.resources.gpus[0]
memory_gb = round(gpu.memory_mib / 1024)
if gpu.name == "A100" and memory_gb == 80:
return "A100-80GB"
return gpu.name


def _get_jarvislabs_gpu_type_from_backend_data(backend_data: dict) -> Optional[str]:
offer_backend_data = cast(JarvisLabsOfferBackendData, backend_data)
gpu_type = offer_backend_data.get("gpu_type")
if not isinstance(gpu_type, str) or not gpu_type:
return None
return gpu_type


def _get_disk_size_gb(instance_offer: InstanceOfferWithAvailability) -> int:
disk_size_gb = round(instance_offer.instance.resources.disk.size_mib / 1024)
return max(round(MIN_DISK_SIZE), disk_size_gb)
Expand Down
87 changes: 79 additions & 8 deletions src/tests/_internal/core/backends/jarvislabs/test_api_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import requests
from gpuhunt.providers.jarvislabs import API_URL

from dstack._internal.core.backends.jarvislabs.api_client import (
JarvisLabsAPIClient,
Expand All @@ -9,21 +10,21 @@


def test_validate_api_key_returns_false_on_unauthorized(requests_mock):
requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=401)
requests_mock.get(f"{API_URL}/users/user_info", status_code=401)

assert JarvisLabsAPIClient("bad").validate_api_key() is False


def test_get_user_info_raises_invalid_credentials_on_forbidden(requests_mock):
requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=403)
requests_mock.get(f"{API_URL}/users/user_info", status_code=403)

with pytest.raises(BackendInvalidCredentialsError):
JarvisLabsAPIClient("bad").get_user_info()


def test_make_request_wraps_request_errors(requests_mock):
requests_mock.get(
"https://backendprod.jarvislabs.net/users/user_info",
f"{API_URL}/users/user_info",
exc=requests.ConnectTimeout("timed out"),
)

Expand All @@ -32,7 +33,7 @@ def test_make_request_wraps_request_errors(requests_mock):


def test_get_user_info_rejects_non_json_success_response(requests_mock):
requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", text="ok")
requests_mock.get(f"{API_URL}/users/user_info", text="ok")

with pytest.raises(BackendError, match="Unexpected non-JSON JarvisLabs response"):
JarvisLabsAPIClient("token").get_user_info()
Expand All @@ -41,7 +42,7 @@ def test_get_user_info_rejects_non_json_success_response(requests_mock):
def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock):
public_key = "ssh-rsa AAAA test-comment"
requests_mock.get(
"https://backendprod.jarvislabs.net/ssh/",
f"{API_URL}/ssh/",
json=[{"ssh_key": "ssh-rsa AAAA another-comment", "key_name": "existing"}],
)

Expand All @@ -52,8 +53,8 @@ def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock):

def test_add_ssh_key_if_needed_adds_missing_key(requests_mock):
public_key = "ssh-rsa AAAA test-comment"
requests_mock.get("https://backendprod.jarvislabs.net/ssh/", json=[])
requests_mock.post("https://backendprod.jarvislabs.net/ssh/", json={"success": True})
requests_mock.get(f"{API_URL}/ssh/", json=[])
requests_mock.post(f"{API_URL}/ssh/", json={"success": True})

JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key)

Expand All @@ -63,6 +64,57 @@ def test_add_ssh_key_if_needed_adds_missing_key(requests_mock):
}


def test_create_ssh_key_adds_key_and_returns_created_key_id(requests_mock):
public_key = "ssh-rsa AAAA test-comment"
requests_mock.post(f"{API_URL}/ssh/", json={"success": True})
requests_mock.get(
f"{API_URL}/ssh/",
json=[
{
"ssh_key": "ssh-rsa AAAA another-comment",
"key_name": "dstack-test-0.key",
"key_id": "key-id",
}
],
)

key_id = JarvisLabsAPIClient("token").create_ssh_key(
public_key=public_key,
key_name="dstack-test-0.key",
)

assert key_id == "key-id"
assert requests_mock.request_history[0].json() == {
"ssh_key": public_key,
"key_name": "dstack-test-0.key",
}


def test_create_ssh_key_raises_if_created_key_id_is_missing(requests_mock):
requests_mock.post(f"{API_URL}/ssh/", json={"success": True})
requests_mock.get(f"{API_URL}/ssh/", json=[])

with pytest.raises(BackendError, match="Failed to find created JarvisLabs SSH key"):
JarvisLabsAPIClient("token").create_ssh_key(
public_key="ssh-rsa AAAA test-comment",
key_name="dstack-test-0.key",
)


def test_delete_ssh_key_deletes_key(requests_mock):
requests_mock.delete(f"{API_URL}/ssh/key-id", json={"success": True})

JarvisLabsAPIClient("token").delete_ssh_key("key-id")

assert requests_mock.last_request.method == "DELETE"


def test_delete_ssh_key_ignores_missing_key(requests_mock):
requests_mock.delete(f"{API_URL}/ssh/key-id", status_code=404, json={"detail": "not found"})

JarvisLabsAPIClient("token").delete_ssh_key("key-id")


def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock):
requests_mock.post(
"https://backendn.jarvislabs.net/templates/vm/create",
Expand Down Expand Up @@ -97,6 +149,25 @@ def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock):
}


def test_create_gpu_vm_posts_chennai_region_to_chennai_endpoint(requests_mock):
requests_mock.post(
"https://backendc.jarvislabs.net/templates/vm/create",
json={"machine_id": 123},
)

JarvisLabsAPIClient("token").create_gpu_vm(
gpu_type="RTX-PRO6000",
num_gpus=1,
is_spot=False,
storage=100,
region="india-chennai-01",
name="dstack-test",
)

assert requests_mock.last_request.json()["gpu_type"] == "RTX-PRO6000"
assert requests_mock.last_request.json()["region"] == "india-chennai-01"


def test_create_gpu_vm_rejects_unsupported_region(requests_mock):
with pytest.raises(BackendError, match="Unsupported JarvisLabs region"):
JarvisLabsAPIClient("token").create_gpu_vm(
Expand Down Expand Up @@ -158,7 +229,7 @@ def test_create_cpu_vm_posts_to_regional_cpu_vm_endpoint(requests_mock):

def test_destroy_instance_uses_cpu_vm_endpoint_for_cpu_vm(requests_mock):
requests_mock.get(
"https://backendprod.jarvislabs.net/users/fetch/456",
f"{API_URL}/users/fetch/456",
json={
"success": True,
"instance": {
Expand Down
Loading
Loading