From b8af4acea20f16f7f14ee881bce00c906817e181 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 13 May 2026 15:43:06 +0200 Subject: [PATCH] Support global exports Users with the global admin role can mark any export as a global export. Global exports are automatically imported into all projects, and their imports cannot be deleted. ```shell $ dstack export create global-export --gateway shared-gateway --global NAME FLEETS GATEWAYS IMPORTERS global-export - shared-gateway * ``` Only promoting an export to global requires the global admin role. Regular project admins can add or remove resources, remove global status, or delete the export. --- mkdocs/docs/concepts/exports.md | 17 + src/dstack/_internal/cli/commands/export.py | 32 +- .../_internal/core/compatibility/exports.py | 6 + src/dstack/_internal/core/models/exports.py | 1 + src/dstack/_internal/server/const.py | 5 + ..._201cb7ccd0d3_add_exportmodel_is_global.py | 34 ++ src/dstack/_internal/server/models.py | 1 + .../_internal/server/routers/exports.py | 3 + .../_internal/server/schemas/exports.py | 3 + .../_internal/server/services/exports.py | 85 ++++- .../_internal/server/services/imports.py | 6 +- .../_internal/server/services/locking.py | 16 +- .../_internal/server/services/projects.py | 9 +- src/dstack/_internal/server/testing/common.py | 2 + src/dstack/api/server/_exports.py | 6 + .../_internal/server/routers/test_exports.py | 323 +++++++++++++++++- .../_internal/server/routers/test_imports.py | 34 ++ .../_internal/server/routers/test_projects.py | 40 +++ .../_internal/server/services/test_config.py | 36 +- 19 files changed, 644 insertions(+), 15 deletions(-) create mode 100644 src/dstack/_internal/server/const.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/05_13_0724_201cb7ccd0d3_add_exportmodel_is_global.py diff --git a/mkdocs/docs/concepts/exports.md b/mkdocs/docs/concepts/exports.md index d28aab2dea..86748367fd 100644 --- a/mkdocs/docs/concepts/exports.md +++ b/mkdocs/docs/concepts/exports.md @@ -114,6 +114,23 @@ Export my-export deleted Use `-y` to skip the confirmation prompt. +### Global exports + +Users with the global admin role can mark any export as a global export. Global exports are automatically imported into all projects, and their imports cannot be deleted. + +
+ +```shell +$ dstack export create global-export --gateway shared-gateway --global + NAME FLEETS GATEWAYS IMPORTERS + global-export - shared-gateway * + +``` + +Only promoting an export to global requires the global admin role. Regular project admins can add or remove resources, remove global status, or delete the export. + +
+ ## Access imported resources From the importer project's perspective, use `dstack import list` (or simply `dstack import`) to list all imports in the project — i.e., all exports from other projects that this project has been granted access to: diff --git a/src/dstack/_internal/cli/commands/export.py b/src/dstack/_internal/cli/commands/export.py index e8bf5db9ef..b21b58cfe0 100644 --- a/src/dstack/_internal/cli/commands/export.py +++ b/src/dstack/_internal/cli/commands/export.py @@ -50,6 +50,13 @@ def _register(self): help="Gateway name to export (can be specified multiple times)", default=[], ) + create_parser.add_argument( + "--global", + dest="is_global", + action="store_true", + help="Make this export global (automatically imported into all projects)", + default=False, + ) create_parser.set_defaults(subfunc=self._create) update_parser = subparsers.add_parser( @@ -101,6 +108,21 @@ def _register(self): help="Gateway name to remove (can be specified multiple times)", default=[], ) + global_group = update_parser.add_mutually_exclusive_group() + global_group.add_argument( + "--set-global", + dest="set_global", + action="store_true", + help="Make this export global (automatically imported into all projects)", + default=False, + ) + global_group.add_argument( + "--unset-global", + dest="unset_global", + action="store_true", + help="Remove the global flag from this export", + default=False, + ) update_parser.set_defaults(subfunc=self._update) delete_parser = subparsers.add_parser( @@ -128,6 +150,7 @@ def _create(self, args: argparse.Namespace): export = self.api.client.exports.create( project_name=self.api.project, name=args.name, + is_global=args.is_global, importer_projects=args.importers, exported_fleets=args.fleets, exported_gateways=args.gateways, @@ -139,6 +162,8 @@ def _update(self, args: argparse.Namespace): export = self.api.client.exports.update( project_name=self.api.project, name=args.name, + set_global=args.set_global, + unset_global=args.unset_global, add_importer_projects=args.add_importers, remove_importer_projects=args.remove_importers, add_exported_fleets=args.add_fleets, @@ -175,7 +200,12 @@ def print_exports_table(exports: list[Export]): if export.exported_gateways else "-" ) - importers = ", ".join([i.project_name for i in export.imports]) if export.imports else "-" + if export.is_global: + importers = "*" + else: + importers = ( + ", ".join([i.project_name for i in export.imports]) if export.imports else "-" + ) row = { "NAME": export.name, diff --git a/src/dstack/_internal/core/compatibility/exports.py b/src/dstack/_internal/core/compatibility/exports.py index 2b9e2c85ba..92f1d2dc6e 100644 --- a/src/dstack/_internal/core/compatibility/exports.py +++ b/src/dstack/_internal/core/compatibility/exports.py @@ -4,6 +4,8 @@ def get_create_export_excludes(request: CreateExportRequest) -> IncludeExcludeDictType: excludes: IncludeExcludeDictType = {} + if not request.is_global: + excludes["is_global"] = True if not request.exported_gateways: excludes["exported_gateways"] = True return excludes @@ -11,6 +13,10 @@ def get_create_export_excludes(request: CreateExportRequest) -> IncludeExcludeDi def get_update_export_excludes(request: UpdateExportRequest) -> IncludeExcludeDictType: excludes: IncludeExcludeDictType = {} + if not request.set_global: + excludes["set_global"] = True + if not request.unset_global: + excludes["unset_global"] = True if not request.add_exported_gateways: excludes["add_exported_gateways"] = True if not request.remove_exported_gateways: diff --git a/src/dstack/_internal/core/models/exports.py b/src/dstack/_internal/core/models/exports.py index d1f9ed6c61..eafe38b478 100644 --- a/src/dstack/_internal/core/models/exports.py +++ b/src/dstack/_internal/core/models/exports.py @@ -20,6 +20,7 @@ class ExportedGateway(CoreModel): class Export(CoreModel): id: uuid.UUID name: str + is_global: bool = False imports: list[ExportImport] exported_fleets: list[ExportedFleet] exported_gateways: list[ExportedGateway] = [] diff --git a/src/dstack/_internal/server/const.py b/src/dstack/_internal/server/const.py new file mode 100644 index 0000000000..0ebf4643cb --- /dev/null +++ b/src/dstack/_internal/server/const.py @@ -0,0 +1,5 @@ +GLOBAL_EXPORTS_LOCK_NAMESPACE = "global_exports" +""" +Lock used to avoid race conditions between promoting an export to global and creating new projects. +Ensures that all projects always import all global exports. +""" diff --git a/src/dstack/_internal/server/migrations/versions/2026/05_13_0724_201cb7ccd0d3_add_exportmodel_is_global.py b/src/dstack/_internal/server/migrations/versions/2026/05_13_0724_201cb7ccd0d3_add_exportmodel_is_global.py new file mode 100644 index 0000000000..677dac54a0 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/05_13_0724_201cb7ccd0d3_add_exportmodel_is_global.py @@ -0,0 +1,34 @@ +"""Add ExportModel.is_global + +Revision ID: 201cb7ccd0d3 +Revises: 205690dfeec2 +Create Date: 2026-05-13 07:24:06.321892+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "201cb7ccd0d3" +down_revision = "205690dfeec2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("exports", schema=None) as batch_op: + batch_op.add_column( + sa.Column("is_global", sa.Boolean(), server_default=sa.false(), nullable=False) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("exports", schema=None) as batch_op: + batch_op.drop_column("is_global") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d4976c59d5..d433244ea3 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -1139,6 +1139,7 @@ class ExportModel(BaseModel): ForeignKey("projects.id", ondelete="CASCADE"), index=True ) project: Mapped["ProjectModel"] = relationship() + is_global: Mapped[bool] = mapped_column(Boolean, default=False, server_default=false()) created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) imports: Mapped[List["ImportModel"]] = relationship( back_populates="export", diff --git a/src/dstack/_internal/server/routers/exports.py b/src/dstack/_internal/server/routers/exports.py index 1d682cbf3b..710fe17ce2 100644 --- a/src/dstack/_internal/server/routers/exports.py +++ b/src/dstack/_internal/server/routers/exports.py @@ -34,6 +34,7 @@ async def create_export( project=project, user=user, name=body.name, + is_global=body.is_global, importer_project_names=body.importer_projects, exported_fleet_names=body.exported_fleets, exported_gateway_names=body.exported_gateways, @@ -52,6 +53,8 @@ async def update_export( project=project, user=user, name=body.name, + set_global=body.set_global, + unset_global=body.unset_global, add_importer_project_names=body.add_importer_projects, remove_importer_project_names=body.remove_importer_projects, add_exported_fleet_names=body.add_exported_fleets, diff --git a/src/dstack/_internal/server/schemas/exports.py b/src/dstack/_internal/server/schemas/exports.py index 7f013c92ea..74828fb455 100644 --- a/src/dstack/_internal/server/schemas/exports.py +++ b/src/dstack/_internal/server/schemas/exports.py @@ -3,6 +3,7 @@ class CreateExportRequest(CoreModel): name: str + is_global: bool = False importer_projects: list[str] = [] exported_fleets: list[str] = [] exported_gateways: list[str] = [] @@ -10,6 +11,8 @@ class CreateExportRequest(CoreModel): class UpdateExportRequest(CoreModel): name: str + set_global: bool = False + unset_global: bool = False add_importer_projects: list[str] = [] remove_importer_projects: list[str] = [] add_exported_fleets: list[str] = [] diff --git a/src/dstack/_internal/server/services/exports.py b/src/dstack/_internal/server/services/exports.py index 3dc2a3b1c0..d22f38cc95 100644 --- a/src/dstack/_internal/server/services/exports.py +++ b/src/dstack/_internal/server/services/exports.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import selectinload from dstack._internal.core.errors import ( + ForbiddenError, ResourceExistsError, ResourceNotExistsError, ServerClientError, @@ -20,6 +21,7 @@ ) from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.server.const import GLOBAL_EXPORTS_LOCK_NAMESPACE from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( ExportedFleetModel, @@ -34,9 +36,14 @@ ) from dstack._internal.server.services.fleets import get_fleet_spec, list_project_fleet_models from dstack._internal.server.services.gateways import list_project_gateway_models -from dstack._internal.server.services.locking import get_locker, string_to_lock_id +from dstack._internal.server.services.locking import ( + advisory_lock_ctx, + get_locker, + string_to_lock_id, +) from dstack._internal.server.services.projects import ( get_user_project_role, + list_project_models, list_user_project_models, ) @@ -105,11 +112,17 @@ async def create_export( project: ProjectModel, user: UserModel, name: str, + is_global: bool, importer_project_names: list[str], exported_fleet_names: list[str], exported_gateway_names: list[str], ) -> Export: validate_dstack_resource_name(name) + if is_global and importer_project_names: + raise ServerClientError( + "Do not specify any importer projects when creating a global export." + " Global exports are automatically imported in all projects" + ) lock_namespace = f"export_names_{project.name}" if is_db_sqlite(): @@ -129,6 +142,7 @@ async def create_export( export = ExportModel( name=name, project=project, + is_global=False, imports=[], exported_fleets=[], exported_gateways=[], @@ -137,7 +151,14 @@ async def create_export( await add_exported_fleets(session, export, exported_fleet_names) await add_exported_gateways(session, export, exported_gateway_names) session.add(export) - await session.commit() + if is_global: + async with advisory_lock_ctx( + session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE + ): + await set_as_global(session, export, user) + await session.commit() # commit before releasing the lock + else: + await session.commit() return export_model_to_export(export) @@ -146,6 +167,8 @@ async def update_export( project: ProjectModel, user: UserModel, name: str, + set_global: bool, + unset_global: bool, add_importer_project_names: list[str], remove_importer_project_names: list[str], add_exported_fleet_names: list[str], @@ -158,7 +181,9 @@ async def update_export( raise ResourceNotExistsError(f"Export {name!r} not found in project {project.name!r}") if ( - not add_importer_project_names + not set_global + and not unset_global + and not add_importer_project_names and not remove_importer_project_names and not add_exported_fleet_names and not remove_exported_fleet_names @@ -166,6 +191,14 @@ async def update_export( and not remove_exported_gateway_names ): raise ServerClientError("No changes specified") + if set_global and unset_global: + raise ServerClientError("Cannot set and unset global at the same time") + if (set_global or unset_global) and ( + add_importer_project_names or remove_importer_project_names + ): + raise ServerClientError( + "Cannot change global status and add/remove importers at the same time" + ) add_importer_project_names = list(map(str.lower, add_importer_project_names)) remove_importer_project_names = list(map(str.lower, remove_importer_project_names)) @@ -201,11 +234,48 @@ async def update_export( await remove_importer_projects(export, remove_importer_project_names) await remove_exported_fleets(export, remove_exported_fleet_names) await remove_exported_gateways(export, remove_exported_gateway_names) - - await session.commit() + if unset_global: + await unset_as_global(export) + if set_global: + async with advisory_lock_ctx( + session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE + ): + await set_as_global(session, export, user) + await session.commit() # commit before releasing the lock + else: + await session.commit() return export_model_to_export(export) +async def set_as_global(session: AsyncSession, export: ExportModel, user: UserModel) -> None: + """ + **NOTE**: + Should be called with the `GLOBAL_EXPORTS_LOCK_NAMESPACE` lock acquired to prevent new + projects from being created while this export is being imported into existing ones. + """ + if export.is_global: + raise ServerClientError("The export is already global") + if user.global_role != GlobalRole.ADMIN: + raise ForbiddenError("Only global admins can make the export global") + all_projects = await list_project_models( + session, load_only_attrs=[ProjectModel.id, ProjectModel.name] + ) + already_importing = {imp.project_id for imp in export.imports} + for project in all_projects: + if project.id == export.project.id: + continue + if project.id in already_importing: + continue + export.imports.append(ImportModel(project=project)) + export.is_global = True + + +async def unset_as_global(export: ExportModel) -> None: + if not export.is_global: + raise ServerClientError("The export is already not global") + export.is_global = False + + async def add_importer_projects( session: AsyncSession, user: UserModel, export: ExportModel, names: list[str] ) -> None: @@ -270,6 +340,10 @@ async def add_exported_fleets( async def remove_importer_projects(export: ExportModel, names: list[str]) -> None: + if not names: + return + if export.is_global: + raise ServerClientError("Cannot remove importers from a global export") names = list(map(str.lower, names)) if len(names) != len(set(names)): raise ServerClientError("Some importer projects are listed for removal more than once") @@ -364,6 +438,7 @@ def export_model_to_export(export_model: ExportModel) -> Export: return Export( id=export_model.id, name=export_model.name, + is_global=export_model.is_global, imports=[ ExportImport( project_name=import_model.project.name, diff --git a/src/dstack/_internal/server/services/imports.py b/src/dstack/_internal/server/services/imports.py index cee432764a..6ee2e93201 100644 --- a/src/dstack/_internal/server/services/imports.py +++ b/src/dstack/_internal/server/services/imports.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload -from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError from dstack._internal.core.models.imports import ( Import, ImportExport, @@ -70,6 +70,10 @@ async def delete_import( raise not_found_error if project.name.lower() not in {imp.project.name.lower() for imp in export.imports}: raise not_found_error + if export.is_global: + raise ServerClientError( + f"'{export_project_name}/{export_name}' is a global export, cannot stop importing" + ) export.imports = [ imp for imp in export.imports if imp.project.name.lower() != project.name.lower() ] diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py index 71a4aa7bfe..4656d6e047 100644 --- a/src/dstack/_internal/server/services/locking.py +++ b/src/dstack/_internal/server/services/locking.py @@ -126,13 +126,19 @@ def string_to_lock_id(s: str) -> int: async def advisory_lock_ctx( bind: Union[AsyncConnection, AsyncSession], dialect_name: str, resource: str ): + """ + Take a global lock on `resource` across all dstack server replicas. + In-memory lock for SQLite, advisory lock for Postgres. + """ if dialect_name == "postgresql": await bind.execute(select(func.pg_advisory_lock(string_to_lock_id(resource)))) - try: - yield - finally: - if dialect_name == "postgresql": - await bind.execute(select(func.pg_advisory_unlock(string_to_lock_id(resource)))) + lock, _ = get_locker(dialect_name).get_lockset(resource) + async with lock: + try: + yield + finally: + if dialect_name == "postgresql": + await bind.execute(select(func.pg_advisory_unlock(string_to_lock_id(resource)))) @asynccontextmanager diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 499d6c039e..8d8e1a890e 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -26,6 +26,8 @@ ) from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.const import GLOBAL_EXPORTS_LOCK_NAMESPACE +from dstack._internal.server.db import get_db from dstack._internal.server.models import ( ExportModel, FleetModel, @@ -42,6 +44,7 @@ from dstack._internal.server.services.backends import ( get_backend_config_without_creds_from_backend_model, ) +from dstack._internal.server.services.locking import advisory_lock_ctx from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.settings import DEFAULT_PROJECT_NAME from dstack._internal.utils.common import get_current_datetime, run_async @@ -633,7 +636,11 @@ async def create_project_model( actor=events.UserActor.from_user(owner), targets=[events.Target.from_model(project)], ) - await session.commit() + async with advisory_lock_ctx(session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE): + res = await session.execute(select(ExportModel.id).where(ExportModel.is_global == True)) + for export_id in res.scalars().all(): + session.add(ImportModel(project=project, export_id=export_id)) + await session.commit() # commit before releasing the lock return project diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index a1deb4fadb..249780fcd8 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -594,10 +594,12 @@ async def create_export( exported_fleets: list[FleetModel], exported_gateways: Optional[list[GatewayModel]] = None, name: str = "test-export", + is_global: bool = False, ) -> ExportModel: export = ExportModel( name=name, project=exporter_project, + is_global=is_global, imports=[ImportModel(project=project) for project in importer_projects], exported_fleets=[ExportedFleetModel(fleet=fleet) for fleet in exported_fleets], exported_gateways=[ diff --git a/src/dstack/api/server/_exports.py b/src/dstack/api/server/_exports.py index ad18bbd7bb..f23016011d 100644 --- a/src/dstack/api/server/_exports.py +++ b/src/dstack/api/server/_exports.py @@ -25,12 +25,14 @@ def create( project_name: str, name: str, *, + is_global: bool = False, importer_projects: List[str] = [], exported_fleets: List[str] = [], exported_gateways: List[str] = [], ) -> Export: body = CreateExportRequest( name=name, + is_global=is_global, importer_projects=importer_projects, exported_fleets=exported_fleets, exported_gateways=exported_gateways, @@ -46,6 +48,8 @@ def update( project_name: str, name: str, *, + set_global: bool = False, + unset_global: bool = False, add_importer_projects: List[str] = [], remove_importer_projects: List[str] = [], add_exported_fleets: List[str] = [], @@ -55,6 +59,8 @@ def update( ) -> Export: body = UpdateExportRequest( name=name, + set_global=set_global, + unset_global=unset_global, add_importer_projects=add_importer_projects, remove_importer_projects=remove_importer_projects, add_exported_fleets=add_exported_fleets, diff --git a/src/tests/_internal/server/routers/test_exports.py b/src/tests/_internal/server/routers/test_exports.py index 5c0e798816..9a21f2cb4b 100644 --- a/src/tests/_internal/server/routers/test_exports.py +++ b/src/tests/_internal/server/routers/test_exports.py @@ -7,7 +7,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.models import ExportModel +from dstack._internal.server.models import ExportModel, ImportModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( create_backend, @@ -57,6 +57,21 @@ async def test_returns_403_if_not_admin(self, session: AsyncSession, client: Asy ) assert response.status_code == 403 + async def test_create_global_returns_403_if_not_global_admin( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={"name": "my-export", "is_global": True}, + ) + assert response.status_code == 403 + @pytest.mark.parametrize( ("global_role", "importer_project_role"), [(GlobalRole.ADMIN, None), (GlobalRole.USER, ProjectRole.ADMIN)], @@ -108,6 +123,7 @@ async def test_creates_export( assert response.status_code == 200 export_response = response.json() assert export_response["name"] == "test-export" + assert export_response["is_global"] == False assert len(export_response["imports"]) == 1 assert export_response["imports"][0]["project_name"] == "ImporterProject" assert len(export_response["exported_fleets"]) == 1 @@ -141,6 +157,33 @@ async def test_creates_empty_export(self, session: AsyncSession, client: AsyncCl res = await session.execute(select(ExportModel).where(ExportModel.name == "empty-export")) assert res.scalar() is not None + async def test_creates_global_export(self, session: AsyncSession, client: AsyncClient): + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project( + session=session, name="ExporterProject", owner=admin + ) + await add_project_member( + session=session, project=exporter_project, user=admin, project_role=ProjectRole.ADMIN + ) + project_a = await create_project(session=session, name="ProjectA", owner=admin) + project_b = await create_project(session=session, name="ProjectB", owner=admin) + + response = await client.post( + f"/api/project/{exporter_project.name}/exports/create", + headers=get_auth_headers(admin.token), + json={"name": "my-export", "is_global": True}, + ) + assert response.status_code == 200 + data = response.json() + assert data["is_global"] is True + imported_names = {imp["project_name"] for imp in data["imports"]} + assert imported_names == {project_a.name, project_b.name} + assert exporter_project.name not in imported_names + res = await session.execute(select(func.count()).select_from(ExportModel)) + assert res.scalar_one() == 1 + res = await session.execute(select(func.count()).select_from(ImportModel)) + assert res.scalar_one() == 2 + @pytest.mark.parametrize( "body,error", [ @@ -307,6 +350,28 @@ async def test_rejects_invalid_export( res = await session.execute(select(func.count()).select_from(ExportModel)) assert res.scalar_one() == 0 + async def test_rejects_invalid_global_export_with_importer_projects( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.ADMIN) + project = await create_project(session=session, name="ExporterProject", owner=user) + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + "is_global": True, + "importer_projects": ["ImporterProject"], + }, + ) + assert response.status_code == 400 + assert ( + "Do not specify any importer projects when creating a global export" + in response.json()["detail"][0]["msg"] + ) + res = await session.execute(select(func.count()).select_from(ExportModel)) + assert res.scalar_one() == 0 + async def test_rejects_export_on_name_conflict( self, session: AsyncSession, client: AsyncClient ): @@ -786,6 +851,41 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Gateways {'not-exported-gateway'} are listed for both addition and removal. Cannot add and remove at the same time", id="add-remove-same-gateway", ), + pytest.param( + { + "name": "test-export", + "set_global": True, + "unset_global": True, + }, + "Cannot set and unset global at the same time", + id="set-and-unset-global", + ), + pytest.param( + { + "name": "test-export", + "unset_global": True, + }, + "The export is already not global", + id="unset-non-global", + ), + pytest.param( + { + "name": "test-export", + "set_global": True, + "add_importer_projects": ["NotImporterProject"], + }, + "Cannot change global status and add/remove importers at the same time", + id="set-global-with-importer-changes", + ), + pytest.param( + { + "name": "test-export", + "unset_global": True, + "remove_importer_projects": ["ImporterProject"], + }, + "Cannot change global status and add/remove importers at the same time", + id="unset-global-with-importer-changes", + ), ], ) async def test_rejects_invalid_update( @@ -885,6 +985,200 @@ async def test_rejects_invalid_update( assert response.status_code == 200 assert response.json() == canonical_exports + async def test_set_global_returns_403_if_not_global_admin( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[], + name="my-export", + ) + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json={"name": "my-export", "set_global": True}, + ) + assert response.status_code == 403 + + async def test_project_admin_can_unset_global( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[], + name="my-export", + is_global=True, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json={"name": "my-export", "unset_global": True}, + ) + assert response.status_code == 200 + assert response.json()["is_global"] is False + + async def test_set_global(self, session: AsyncSession, client: AsyncClient): + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project( + session=session, name="ExporterProject", owner=admin + ) + already_importing = await create_project( + session=session, name="AlreadyImporting", owner=admin + ) + not_yet_importing = await create_project( + session=session, name="NotYetImporting", owner=admin + ) + export = await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[already_importing], + exported_fleets=[], + name="my-export", + ) + + response = await client.post( + f"/api/project/{exporter_project.name}/exports/update", + headers=get_auth_headers(admin.token), + json={"name": "my-export", "set_global": True}, + ) + assert response.status_code == 200 + data = response.json() + assert data["is_global"] is True + imported_names = {imp["project_name"] for imp in data["imports"]} + assert imported_names == {already_importing.name, not_yet_importing.name} + assert exporter_project.name not in imported_names + await session.refresh(export, ["imports"]) + assert len(export.imports) == 2 + + async def test_unset_global_keeps_imports(self, session: AsyncSession, client: AsyncClient): + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project( + session=session, name="ExporterProject", owner=admin + ) + importer = await create_project(session=session, name="ImporterProject", owner=admin) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer], + exported_fleets=[], + name="my-export", + is_global=True, + ) + + response = await client.post( + f"/api/project/{exporter_project.name}/exports/update", + headers=get_auth_headers(admin.token), + json={"name": "my-export", "unset_global": True}, + ) + assert response.status_code == 200 + data = response.json() + assert data["is_global"] is False + # imports still present + assert len(data["imports"]) == 1 + assert data["imports"][0]["project_name"] == importer.name + + async def test_cannot_remove_importer_from_global_export( + self, session: AsyncSession, client: AsyncClient + ): + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project(session=session, owner=admin) + importer = await create_project(session=session, name="ImporterProject", owner=admin) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer], + exported_fleets=[], + name="my-export", + is_global=True, + ) + + response = await client.post( + f"/api/project/{exporter_project.name}/exports/update", + headers=get_auth_headers(admin.token), + json={ + "name": "my-export", + "remove_importer_projects": [importer.name], + }, + ) + assert response.status_code == 400 + assert ( + "Cannot remove importers from a global export" in response.json()["detail"][0]["msg"] + ) + + async def test_can_add_missing_importer_to_global_export( + self, session: AsyncSession, client: AsyncClient + ): + """ + Global exports should always be imported in all projects, but in case this invariant + is ever violated (e.g., due to bugs or unforeseen race conditions), adding a missing + importer is still allowed. + """ + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project(session=session, owner=admin) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[], + exported_fleets=[], + name="my-export", + is_global=True, + ) + importer = await create_project(session=session, name="ImporterProject", owner=admin) + + response = await client.post( + f"/api/project/{exporter_project.name}/exports/update", + headers=get_auth_headers(admin.token), + json={ + "name": "my-export", + "add_importer_projects": [importer.name], + }, + ) + assert response.status_code == 200 + export_response = response.json() + assert len(export_response["imports"]) == 1 + assert export_response["imports"][0]["project_name"] == importer.name + + async def test_set_global_already_global_returns_400( + self, session: AsyncSession, client: AsyncClient + ): + admin = await create_user(session=session, global_role=GlobalRole.ADMIN) + project = await create_project(session=session, owner=admin) + await add_project_member( + session=session, project=project, user=admin, project_role=ProjectRole.ADMIN + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[], + name="my-export", + is_global=True, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(admin.token), + json={"name": "my-export", "set_global": True}, + ) + assert response.status_code == 400 + assert "The export is already global" in response.json()["detail"][0]["msg"] + class TestDeleteExport: async def test_returns_403_if_not_authenticated(self, client: AsyncClient): @@ -955,6 +1249,33 @@ async def test_returns_400_for_nonexistent_export( assert response.status_code == 400 assert response.json()["detail"][0]["code"] == "resource_not_exists" + async def test_project_admin_can_delete_global_export( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + export = await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[], + name="my-export", + is_global=True, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/delete", + headers=get_auth_headers(user.token), + json={"name": export.name}, + ) + assert response.status_code == 200 + + res = await session.execute(select(ExportModel)) + assert res.scalar() is None + class TestListExports: async def test_returns_403_if_not_authenticated(self, client: AsyncClient): diff --git a/src/tests/_internal/server/routers/test_imports.py b/src/tests/_internal/server/routers/test_imports.py index d4ef4a931f..c162d0d8ec 100644 --- a/src/tests/_internal/server/routers/test_imports.py +++ b/src/tests/_internal/server/routers/test_imports.py @@ -139,6 +139,40 @@ async def assert_not_found(export_project_name, export_name): # Import not found await assert_not_found(export_project_name="ExporterProject", export_name="test-export") + async def test_cannot_delete_import_of_global_export( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + await add_project_member( + session=session, project=importer_project, user=user, project_role=ProjectRole.ADMIN + ) + + exporter_project = await create_project( + session=session, name="ExporterProject", owner=user + ) + export = await create_export( + session=session, + is_global=True, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + name="test-export", + ) + + response = await client.post( + f"/api/project/{importer_project.name}/imports/delete", + headers=get_auth_headers(user.token), + json={"export_name": export.name, "export_project_name": exporter_project.name}, + ) + assert response.status_code == 400 + assert ( + response.json()["detail"][0]["msg"] + == "'ExporterProject/test-export' is a global export, cannot stop importing" + ) + class TestListImports: async def test_returns_403_if_not_authenticated(self, client: AsyncClient): diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 6d7bcca0ed..67afa77390 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -1155,6 +1155,46 @@ async def test_creates_private_project_explicitly( project = res.scalar_one() assert project.is_public is False + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_new_project_imports_global_exports( + self, session: AsyncSession, client: AsyncClient + ): + exporter_project = await create_project(session=session, name="ExporterProject") + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[], + exported_fleets=[], + name="non-global", + is_global=False, + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[], + exported_fleets=[], + name="global-export", + is_global=True, + ) + user = await create_user(session=session, global_role=GlobalRole.USER) + + response = await client.post( + "/api/projects/create", + headers=get_auth_headers(user.token), + json={"project_name": "new-project"}, + ) + assert response.status_code == 200 + + response = await client.post( + "/api/project/new-project/imports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + imports = response.json() + assert len(imports) == 1 + assert imports[0]["export"]["name"] == "global-export" + class TestDeleteProject: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/services/test_config.py b/src/tests/_internal/server/services/test_config.py index ea03141a17..81265445e8 100644 --- a/src/tests/_internal/server/services/test_config.py +++ b/src/tests/_internal/server/services/test_config.py @@ -9,10 +9,11 @@ from dstack._internal.core.backends.aws.configurator import DEFAULT_REGIONS from dstack._internal.server import settings -from dstack._internal.server.models import BackendModel, ProjectModel +from dstack._internal.server.models import BackendModel, ImportModel, ProjectModel from dstack._internal.server.services.config import ServerConfigManager from dstack._internal.server.testing.common import ( create_backend, + create_export, create_project, create_user, ) @@ -213,3 +214,36 @@ async def test_forces_update_when_current_backend_config_is_unavailable( manager.load_config() await manager.apply_config(session, owner) update_backend.assert_awaited_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_new_project_imports_global_exports( + self, test_db, session: AsyncSession, tmp_path: Path + ): + owner = await create_user(session=session, name="test_owner") + exporter_project = await create_project(session=session, owner=owner, name="exporter") + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[], + exported_fleets=[], + name="global-export", + is_global=True, + ) + config_filepath = tmp_path / "config.yml" + config = {"projects": [{"name": "new-project"}]} + with open(config_filepath, "w+") as f: + yaml.dump(config, f) + with patch.object(settings, "SERVER_CONFIG_FILE_PATH", config_filepath): + manager = ServerConfigManager() + manager.load_config() + await manager.apply_config(session, owner) + new_project_res = await session.execute( + select(ProjectModel).where(ProjectModel.name == "new-project") + ) + new_project = new_project_res.scalar_one() + imports_res = await session.execute( + select(ImportModel).where(ImportModel.project_id == new_project.id) + ) + imports = imports_res.scalars().all() + assert len(imports) == 1