diff --git a/pyproject.toml b/pyproject.toml index 2a4e8620d..f13d7da2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "python-multipart>=0.0.16", "filelock", "psutil", - "gpuhunt==0.1.24", + "gpuhunt==0.1.25", "argcomplete>=3.5.0", "ignore-python>=0.2.0", "orjson", diff --git a/src/dstack/_internal/core/backends/jarvislabs/api_client.py b/src/dstack/_internal/core/backends/jarvislabs/api_client.py index e1a2a022e..0a9a3b68d 100644 --- a/src/dstack/_internal/core/backends/jarvislabs/api_client.py +++ b/src/dstack/_internal/core/backends/jarvislabs/api_client.py @@ -51,6 +51,32 @@ def add_ssh_key(self, public_key: str, key_name: str) -> None: ) _raise_if_unsuccessful(resp, "Failed to add JarvisLabs SSH key") + def create_ssh_key(self, public_key: str, key_name: str) -> str: + self.add_ssh_key(public_key=public_key, key_name=key_name) + key_id = self.find_ssh_key_id(public_key=public_key, key_name=key_name) + if key_id is None: + raise BackendError("Failed to find created JarvisLabs SSH key") + return key_id + + def find_ssh_key_id(self, public_key: str, key_name: str) -> Optional[str]: + normalized_key = _normalize_public_key(public_key) + for ssh_key in self.list_ssh_keys(): + if str(ssh_key.get("key_name", "")) != key_name: + continue + if _normalize_public_key(str(ssh_key.get("ssh_key", ""))) != normalized_key: + continue + key_id = ssh_key.get("key_id") + if key_id is not None: + return str(key_id) + return None + + def delete_ssh_key(self, key_id: str) -> None: + try: + resp = self._make_request("DELETE", f"ssh/{key_id}") + except JarvisLabsNotFoundError: + return + _raise_if_unsuccessful(resp, "Failed to delete JarvisLabs SSH key") + def add_ssh_key_if_needed(self, public_key: str) -> None: normalized_key = _normalize_public_key(public_key) for ssh_key in self.list_ssh_keys(): diff --git a/src/dstack/_internal/core/backends/jarvislabs/compute.py b/src/dstack/_internal/core/backends/jarvislabs/compute.py index 2e79902b4..24e13f850 100644 --- a/src/dstack/_internal/core/backends/jarvislabs/compute.py +++ b/src/dstack/_internal/core/backends/jarvislabs/compute.py @@ -2,10 +2,11 @@ import subprocess import tempfile from collections.abc import Iterable -from typing import List, Optional +from typing import List, Optional, cast import gpuhunt from gpuhunt.providers.jarvislabs import JarvisLabsProvider +from typing_extensions import NotRequired, TypedDict from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( @@ -25,6 +26,7 @@ from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig from dstack._internal.core.errors import ProvisioningError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -47,6 +49,22 @@ SSH_LAUNCH_TIMEOUT_SECONDS = 60 +class JarvisLabsOfferBackendData(TypedDict): + # Set by gpuhunt when normalized GPU identity differs from the JarvisLabs VM + # create token, e.g. "RTX-PRO6000" normalized to "RTXPRO6000". + gpu_type: NotRequired[str] + + +class JarvisLabsInstanceBackendData(CoreModel): + ssh_key_ids: Optional[List[str]] = None + + @classmethod + def load(cls, raw: Optional[str]) -> "JarvisLabsInstanceBackendData": + if raw is None: + return cls() + return cls.__response__.parse_raw(raw) + + class JarvisLabsCompute( ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, @@ -85,9 +103,19 @@ def create_instance( instance_name = generate_unique_instance_name( instance_config, max_length=MAX_INSTANCE_NAME_LEN ) - self.api_client.add_ssh_key_if_needed(instance_config.ssh_keys[0].public) + ssh_key_ids: List[str] = [] instance_id = None try: + # TODO: JarvisLabs has a default 10 SSH key limit. Consider project-level + # key reuse if per-instance keys become a bottleneck. + for idx, ssh_public_key in enumerate(instance_config.get_public_keys()): + ssh_key_ids.append( + _create_ssh_key( + client=self.api_client, + name=f"{instance_name}-{idx}.key", + public_key=ssh_public_key, + ) + ) if instance_offer.instance.resources.gpus: instance_id = self.api_client.create_gpu_vm( gpu_type=_get_jarvislabs_gpu_type(instance_offer), @@ -117,6 +145,13 @@ def create_instance( logger.exception( "Could not destroy failed JarvisLabs instance %s", instance_id ) + try: + _delete_ssh_keys(self.api_client, ssh_key_ids) + except Exception: + logger.exception( + "Could not delete JarvisLabs SSH keys %s after provisioning failure", + ssh_key_ids, + ) raise return JobProvisioningData( backend=instance_offer.backend, @@ -130,7 +165,7 @@ def create_instance( ssh_port=22, dockerized=True, ssh_proxy=None, - backend_data=None, + backend_data=JarvisLabsInstanceBackendData(ssh_key_ids=ssh_key_ids).json(), ) def update_provisioning_data( @@ -172,17 +207,39 @@ def update_provisioning_data( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): + backend_data_parsed = JarvisLabsInstanceBackendData.load(backend_data) self.api_client.destroy_instance(machine_id=instance_id, region=region) + _delete_ssh_keys(self.api_client, backend_data_parsed.ssh_key_ids) + + +def _create_ssh_key(client: JarvisLabsAPIClient, name: str, public_key: str) -> str: + return client.create_ssh_key(public_key=public_key, key_name=name) + + +def _delete_ssh_keys(client: JarvisLabsAPIClient, ssh_key_ids: Optional[List[str]]) -> None: + if not ssh_key_ids: + return + for ssh_key_id in ssh_key_ids: + client.delete_ssh_key(ssh_key_id) def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> str: + gpu_type = _get_jarvislabs_gpu_type_from_backend_data(instance_offer.backend_data) + if gpu_type is not None: + return gpu_type + gpu = instance_offer.instance.resources.gpus[0] - memory_gb = round(gpu.memory_mib / 1024) - if gpu.name == "A100" and memory_gb == 80: - return "A100-80GB" return gpu.name +def _get_jarvislabs_gpu_type_from_backend_data(backend_data: dict) -> Optional[str]: + offer_backend_data = cast(JarvisLabsOfferBackendData, backend_data) + gpu_type = offer_backend_data.get("gpu_type") + if not isinstance(gpu_type, str) or not gpu_type: + return None + return gpu_type + + def _get_disk_size_gb(instance_offer: InstanceOfferWithAvailability) -> int: disk_size_gb = round(instance_offer.instance.resources.disk.size_mib / 1024) return max(round(MIN_DISK_SIZE), disk_size_gb) diff --git a/src/tests/_internal/core/backends/jarvislabs/test_api_client.py b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py index 09980d6ed..9fdf48724 100644 --- a/src/tests/_internal/core/backends/jarvislabs/test_api_client.py +++ b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py @@ -1,5 +1,6 @@ import pytest import requests +from gpuhunt.providers.jarvislabs import API_URL from dstack._internal.core.backends.jarvislabs.api_client import ( JarvisLabsAPIClient, @@ -9,13 +10,13 @@ def test_validate_api_key_returns_false_on_unauthorized(requests_mock): - requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=401) + requests_mock.get(f"{API_URL}/users/user_info", status_code=401) assert JarvisLabsAPIClient("bad").validate_api_key() is False def test_get_user_info_raises_invalid_credentials_on_forbidden(requests_mock): - requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=403) + requests_mock.get(f"{API_URL}/users/user_info", status_code=403) with pytest.raises(BackendInvalidCredentialsError): JarvisLabsAPIClient("bad").get_user_info() @@ -23,7 +24,7 @@ def test_get_user_info_raises_invalid_credentials_on_forbidden(requests_mock): def test_make_request_wraps_request_errors(requests_mock): requests_mock.get( - "https://backendprod.jarvislabs.net/users/user_info", + f"{API_URL}/users/user_info", exc=requests.ConnectTimeout("timed out"), ) @@ -32,7 +33,7 @@ def test_make_request_wraps_request_errors(requests_mock): def test_get_user_info_rejects_non_json_success_response(requests_mock): - requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", text="ok") + requests_mock.get(f"{API_URL}/users/user_info", text="ok") with pytest.raises(BackendError, match="Unexpected non-JSON JarvisLabs response"): JarvisLabsAPIClient("token").get_user_info() @@ -41,7 +42,7 @@ def test_get_user_info_rejects_non_json_success_response(requests_mock): def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock): public_key = "ssh-rsa AAAA test-comment" requests_mock.get( - "https://backendprod.jarvislabs.net/ssh/", + f"{API_URL}/ssh/", json=[{"ssh_key": "ssh-rsa AAAA another-comment", "key_name": "existing"}], ) @@ -52,8 +53,8 @@ def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock): def test_add_ssh_key_if_needed_adds_missing_key(requests_mock): public_key = "ssh-rsa AAAA test-comment" - requests_mock.get("https://backendprod.jarvislabs.net/ssh/", json=[]) - requests_mock.post("https://backendprod.jarvislabs.net/ssh/", json={"success": True}) + requests_mock.get(f"{API_URL}/ssh/", json=[]) + requests_mock.post(f"{API_URL}/ssh/", json={"success": True}) JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key) @@ -63,6 +64,57 @@ def test_add_ssh_key_if_needed_adds_missing_key(requests_mock): } +def test_create_ssh_key_adds_key_and_returns_created_key_id(requests_mock): + public_key = "ssh-rsa AAAA test-comment" + requests_mock.post(f"{API_URL}/ssh/", json={"success": True}) + requests_mock.get( + f"{API_URL}/ssh/", + json=[ + { + "ssh_key": "ssh-rsa AAAA another-comment", + "key_name": "dstack-test-0.key", + "key_id": "key-id", + } + ], + ) + + key_id = JarvisLabsAPIClient("token").create_ssh_key( + public_key=public_key, + key_name="dstack-test-0.key", + ) + + assert key_id == "key-id" + assert requests_mock.request_history[0].json() == { + "ssh_key": public_key, + "key_name": "dstack-test-0.key", + } + + +def test_create_ssh_key_raises_if_created_key_id_is_missing(requests_mock): + requests_mock.post(f"{API_URL}/ssh/", json={"success": True}) + requests_mock.get(f"{API_URL}/ssh/", json=[]) + + with pytest.raises(BackendError, match="Failed to find created JarvisLabs SSH key"): + JarvisLabsAPIClient("token").create_ssh_key( + public_key="ssh-rsa AAAA test-comment", + key_name="dstack-test-0.key", + ) + + +def test_delete_ssh_key_deletes_key(requests_mock): + requests_mock.delete(f"{API_URL}/ssh/key-id", json={"success": True}) + + JarvisLabsAPIClient("token").delete_ssh_key("key-id") + + assert requests_mock.last_request.method == "DELETE" + + +def test_delete_ssh_key_ignores_missing_key(requests_mock): + requests_mock.delete(f"{API_URL}/ssh/key-id", status_code=404, json={"detail": "not found"}) + + JarvisLabsAPIClient("token").delete_ssh_key("key-id") + + def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock): requests_mock.post( "https://backendn.jarvislabs.net/templates/vm/create", @@ -97,6 +149,25 @@ def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock): } +def test_create_gpu_vm_posts_chennai_region_to_chennai_endpoint(requests_mock): + requests_mock.post( + "https://backendc.jarvislabs.net/templates/vm/create", + json={"machine_id": 123}, + ) + + JarvisLabsAPIClient("token").create_gpu_vm( + gpu_type="RTX-PRO6000", + num_gpus=1, + is_spot=False, + storage=100, + region="india-chennai-01", + name="dstack-test", + ) + + assert requests_mock.last_request.json()["gpu_type"] == "RTX-PRO6000" + assert requests_mock.last_request.json()["region"] == "india-chennai-01" + + def test_create_gpu_vm_rejects_unsupported_region(requests_mock): with pytest.raises(BackendError, match="Unsupported JarvisLabs region"): JarvisLabsAPIClient("token").create_gpu_vm( @@ -158,7 +229,7 @@ def test_create_cpu_vm_posts_to_regional_cpu_vm_endpoint(requests_mock): def test_destroy_instance_uses_cpu_vm_endpoint_for_cpu_vm(requests_mock): requests_mock.get( - "https://backendprod.jarvislabs.net/users/fetch/456", + f"{API_URL}/users/fetch/456", json={ "success": True, "instance": { diff --git a/src/tests/_internal/core/backends/jarvislabs/test_compute.py b/src/tests/_internal/core/backends/jarvislabs/test_compute.py index 27be084e1..6ee60dfc2 100644 --- a/src/tests/_internal/core/backends/jarvislabs/test_compute.py +++ b/src/tests/_internal/core/backends/jarvislabs/test_compute.py @@ -1,16 +1,17 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest from dstack._internal.core.backends.jarvislabs.compute import ( CONFIGURABLE_DISK_SIZE, JarvisLabsCompute, + JarvisLabsInstanceBackendData, _get_disk_size_gb, _get_jarvislabs_gpu_type, _get_ssh_username, ) from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig, JarvisLabsCreds -from dstack._internal.core.errors import NoCapacityError, ProvisioningError +from dstack._internal.core.errors import BackendError, NoCapacityError, ProvisioningError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( Disk, @@ -32,16 +33,17 @@ def _compute() -> JarvisLabsCompute: JarvisLabsConfig(creds=JarvisLabsCreds(api_key="test"), regions=["india-noida-01"]) ) compute.api_client = MagicMock() + compute.api_client.create_ssh_key.return_value = "ssh-key-id" compute.api_client.get_instance_status.return_value = {"status": "Running"} return compute -def _instance_config() -> InstanceConfiguration: +def _instance_config(ssh_keys: list[SSHKey] | None = None) -> InstanceConfiguration: return InstanceConfiguration( project_name="test-project", instance_name="jarvislabs-test", user="test-user", - ssh_keys=[SSHKey(public="ssh-rsa AAAA test")], + ssh_keys=ssh_keys or [SSHKey(public="ssh-rsa AAAA test")], ) @@ -51,6 +53,7 @@ def _gpu_offer( gpu_memory_mib: int = 80 * 1024, disk_size_mib: int = 250 * 1024, spot: bool = False, + backend_data: dict | None = None, ) -> InstanceOfferWithAvailability: return InstanceOfferWithAvailability( backend=BackendType.JARVISLABS, @@ -66,6 +69,7 @@ def _gpu_offer( ), region="india-noida-01", price=1.49, + backend_data=backend_data or {}, availability=InstanceAvailability.AVAILABLE, ) @@ -99,10 +103,32 @@ def _cpu_catalog_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOffer: ) -def test_get_jarvislabs_gpu_type_reconstructs_a100_80gb(): - assert _get_jarvislabs_gpu_type(_gpu_offer()) == "A100-80GB" - assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_memory_mib=40 * 1024)) == "A100" +def test_get_jarvislabs_gpu_type_uses_backend_data_or_gpu_name(): + assert ( + _get_jarvislabs_gpu_type(_gpu_offer(backend_data={"gpu_type": "A100-80GB"})) == "A100-80GB" + ) + assert _get_jarvislabs_gpu_type(_gpu_offer()) == "A100" assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_name="H100")) == "H100" + assert ( + _get_jarvislabs_gpu_type( + _gpu_offer( + gpu_name="RTXPRO6000", + gpu_memory_mib=96 * 1024, + backend_data={"gpu_type": "RTX-PRO6000"}, + ) + ) + == "RTX-PRO6000" + ) + + +def test_get_jarvislabs_gpu_type_prefers_backend_data(): + offer = _gpu_offer( + gpu_name="RTXPRO6000", + gpu_memory_mib=96 * 1024, + backend_data={"gpu_type": "RTX PRO 6000"}, + ) + + assert _get_jarvislabs_gpu_type(offer) == "RTX PRO 6000" def test_get_disk_size_gb_clamps_to_jarvislabs_vm_minimum(): @@ -145,7 +171,7 @@ def test_get_offers_reuses_all_offers_cache_and_modifies_disk_size(): compute.get_all_offers_with_availability.assert_called_once() -def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): +def test_create_gpu_instance_creates_ssh_key_and_gpu_vm(): compute = _compute() compute.api_client.create_gpu_vm.return_value = "123" @@ -153,9 +179,14 @@ def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", return_value="dstack-test", ): - provisioning_data = compute.create_instance(_gpu_offer(), _instance_config(), None) + provisioning_data = compute.create_instance( + _gpu_offer(backend_data={"gpu_type": "A100-80GB"}), _instance_config(), None + ) - compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test") + compute.api_client.create_ssh_key.assert_called_once_with( + public_key="ssh-rsa AAAA test", + key_name="dstack-test-0.key", + ) compute.api_client.create_gpu_vm.assert_called_once_with( gpu_type="A100-80GB", num_gpus=1, @@ -167,7 +198,8 @@ def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): assert provisioning_data.instance_id == "123" assert provisioning_data.username == "ubuntu" assert provisioning_data.dockerized is True - assert provisioning_data.backend_data is None + backend_data = JarvisLabsInstanceBackendData.load(provisioning_data.backend_data) + assert backend_data.ssh_key_ids == ["ssh-key-id"] compute.api_client.get_instance_status.assert_not_called() @@ -179,7 +211,11 @@ def test_create_gpu_instance_passes_spot_flag(): "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", return_value="dstack-test", ): - compute.create_instance(_gpu_offer(spot=True), _instance_config(), None) + compute.create_instance( + _gpu_offer(spot=True, backend_data={"gpu_type": "A100-80GB"}), + _instance_config(), + None, + ) compute.api_client.create_gpu_vm.assert_called_once_with( gpu_type="A100-80GB", @@ -191,7 +227,32 @@ def test_create_gpu_instance_passes_spot_flag(): ) -def test_create_cpu_instance_registers_ssh_key_and_creates_cpu_vm(): +def test_create_rtx_pro_6000_instance_uses_jarvislabs_gpu_type_from_backend_data(): + compute = _compute() + compute.api_client.create_gpu_vm.return_value = "123" + offer = _gpu_offer( + gpu_name="RTXPRO6000", + gpu_memory_mib=96 * 1024, + backend_data={"gpu_type": "RTX-PRO6000"}, + ) + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-test", + ): + compute.create_instance(offer, _instance_config(), None) + + compute.api_client.create_gpu_vm.assert_called_once_with( + gpu_type="RTX-PRO6000", + num_gpus=1, + is_spot=False, + storage=250, + region="india-noida-01", + name="dstack-test", + ) + + +def test_create_cpu_instance_creates_ssh_key_and_cpu_vm(): compute = _compute() compute.api_client.create_cpu_vm.return_value = "456" @@ -201,7 +262,10 @@ def test_create_cpu_instance_registers_ssh_key_and_creates_cpu_vm(): ): provisioning_data = compute.create_instance(_cpu_offer(), _instance_config(), None) - compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test") + compute.api_client.create_ssh_key.assert_called_once_with( + public_key="ssh-rsa AAAA test", + key_name="dstack-cpu-0.key", + ) compute.api_client.create_cpu_vm.assert_called_once_with( vcpus=4, ram_gb=16, @@ -210,7 +274,8 @@ def test_create_cpu_instance_registers_ssh_key_and_creates_cpu_vm(): name="dstack-cpu", ) assert provisioning_data.instance_id == "456" - assert provisioning_data.backend_data is None + backend_data = JarvisLabsInstanceBackendData.load(provisioning_data.backend_data) + assert backend_data.ssh_key_ids == ["ssh-key-id"] def test_update_provisioning_data_sets_hostname_and_starts_runner(): @@ -291,7 +356,7 @@ def test_get_ssh_username_parses_jarvislabs_ssh_command(): assert _get_ssh_username({}) == "ubuntu" -def test_terminate_instance_delegates_to_api_client(): +def test_terminate_instance_delegates_to_api_client_without_backend_data(): compute = _compute() compute.terminate_instance("123", "india-noida-01") @@ -300,9 +365,28 @@ def test_terminate_instance_delegates_to_api_client(): machine_id="123", region="india-noida-01", ) + compute.api_client.delete_ssh_key.assert_not_called() -def test_create_instance_propagates_create_failure_without_cleanup(): +def test_terminate_instance_deletes_created_ssh_keys(): + compute = _compute() + backend_data = JarvisLabsInstanceBackendData( + ssh_key_ids=["ssh-key-id-1", "ssh-key-id-2"] + ).json() + + compute.terminate_instance("123", "india-noida-01", backend_data) + + compute.api_client.destroy_instance.assert_called_once_with( + machine_id="123", + region="india-noida-01", + ) + assert compute.api_client.delete_ssh_key.call_args_list == [ + call("ssh-key-id-1"), + call("ssh-key-id-2"), + ] + + +def test_create_instance_cleans_up_ssh_key_on_create_failure(): compute = _compute() compute.api_client.create_gpu_vm.side_effect = NoCapacityError( "L4 not available at this moment, please try again later" @@ -316,6 +400,31 @@ def test_create_instance_propagates_create_failure_without_cleanup(): compute.create_instance(_gpu_offer(spot=True), _instance_config(), None) compute.api_client.destroy_instance.assert_not_called() + compute.api_client.delete_ssh_key.assert_called_once_with("ssh-key-id") + + +def test_create_instance_cleans_up_created_ssh_key_if_later_ssh_key_create_fails(): + compute = _compute() + compute.api_client.create_ssh_key.side_effect = [ + "ssh-key-id-1", + BackendError("ssh create failed"), + ] + instance_config = _instance_config( + ssh_keys=[ + SSHKey(public="ssh-rsa AAAA test-1"), + SSHKey(public="ssh-rsa BBBB test-2"), + ] + ) + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-test", + ): + with pytest.raises(BackendError, match="ssh create failed"): + compute.create_instance(_gpu_offer(), instance_config, None) + + compute.api_client.create_gpu_vm.assert_not_called() + compute.api_client.delete_ssh_key.assert_called_once_with("ssh-key-id-1") def test_update_provisioning_data_raises_provisioning_error_from_failed_capacity_status():