Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 49 additions & 26 deletions src/dstack/_internal/server/services/exports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, nullcontext
from typing import Optional

from sqlalchemy import func, select
Expand Down Expand Up @@ -36,11 +36,7 @@
)
from dstack._internal.server.services.fleets import get_fleet_spec, list_project_fleet_models
from dstack._internal.server.services.gateways import list_project_gateway_models
from dstack._internal.server.services.locking import (
advisory_lock_ctx,
get_locker,
string_to_lock_id,
)
from dstack._internal.server.services.locking import get_locker, string_to_lock_id
from dstack._internal.server.services.projects import (
get_user_project_role,
list_project_models,
Expand Down Expand Up @@ -124,17 +120,35 @@ async def create_export(
" Global exports are automatically imported in all projects"
)

lock_namespace = f"export_names_{project.name}"
export_names_lock_namespace = f"export_names_{project.name}"
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif is_db_postgres():
await session.execute(
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
select(func.pg_advisory_xact_lock(string_to_lock_id(export_names_lock_namespace)))
)
lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
export_names_lock, _ = get_locker(get_db().dialect_name).get_lockset(
export_names_lock_namespace
)

async with lock:
if is_global:
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif is_db_postgres():
await session.execute(
select(
func.pg_advisory_xact_lock(string_to_lock_id(GLOBAL_EXPORTS_LOCK_NAMESPACE))
)
)
global_exports_lock, _ = get_locker(get_db().dialect_name).get_lockset(
GLOBAL_EXPORTS_LOCK_NAMESPACE
)
else:
global_exports_lock = nullcontext()

async with export_names_lock, global_exports_lock:
if await export_exists(session, project, name):
raise ResourceExistsError(
f"Export {name!r} already exists in project {project.name!r}"
Expand All @@ -150,15 +164,10 @@ async def create_export(
await add_importer_projects(session, user, export, importer_project_names)
await add_exported_fleets(session, export, exported_fleet_names)
await add_exported_gateways(session, export, exported_gateway_names)
session.add(export)
if is_global:
async with advisory_lock_ctx(
session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE
):
await set_as_global(session, export, user)
await session.commit() # commit before releasing the lock
else:
await session.commit()
await set_as_global(session, export, user)
session.add(export)
await session.commit()
return export_model_to_export(export)


Expand All @@ -176,7 +185,26 @@ async def update_export(
add_exported_gateway_names: list[str],
remove_exported_gateway_names: list[str],
) -> Export:
async with get_export_model_by_name_for_update(session, project, name) as export:
if set_global:
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif is_db_postgres():
await session.execute(
select(
func.pg_advisory_xact_lock(string_to_lock_id(GLOBAL_EXPORTS_LOCK_NAMESPACE))
)
)
global_exports_lock, _ = get_locker(get_db().dialect_name).get_lockset(
GLOBAL_EXPORTS_LOCK_NAMESPACE
)
else:
global_exports_lock = nullcontext()

async with (
global_exports_lock,
get_export_model_by_name_for_update(session, project, name) as export,
):
if export is None:
raise ResourceNotExistsError(f"Export {name!r} not found in project {project.name!r}")

Expand Down Expand Up @@ -237,13 +265,8 @@ async def update_export(
if unset_global:
await unset_as_global(export)
if set_global:
async with advisory_lock_ctx(
session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE
):
await set_as_global(session, export, user)
await session.commit() # commit before releasing the lock
else:
await session.commit()
await set_as_global(session, export, user)
await session.commit()
return export_model_to_export(export)


Expand Down
30 changes: 21 additions & 9 deletions src/dstack/_internal/server/services/locking.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,30 @@ async def advisory_lock_ctx(
bind: Union[AsyncConnection, AsyncSession], dialect_name: str, resource: str
):
"""
Take a global lock on `resource` across all dstack server replicas.
In-memory lock for SQLite, advisory lock for Postgres.
Acquire a Postgres advisory lock on `resource`. No-op for SQLite.

**NOTE**: The lock must be released by the same database connection that acquired it.
Attempts to release in a different connection will fail.

To prevent unreleased locks:

1. When possible, prefer using `pg_advisory_xact_lock` instead of this context manager.
`pg_advisory_xact_lock` is automatically released at the end of transaction.

1. Prefer using `AsyncConnection` as `bind`.

1. If using `AsyncSession` as `bind`, **do not** commit before exiting from the context manager.
Committing will prompt `AsyncSession` to start a new transaction for releasing the lock,
which may be assigned to a different database connection, which will fail to release.
"""

if dialect_name == "postgresql":
await bind.execute(select(func.pg_advisory_lock(string_to_lock_id(resource))))
lock, _ = get_locker(dialect_name).get_lockset(resource)
async with lock:
try:
yield
finally:
if dialect_name == "postgresql":
await bind.execute(select(func.pg_advisory_unlock(string_to_lock_id(resource))))
try:
yield
finally:
if dialect_name == "postgresql":
await bind.execute(select(func.pg_advisory_unlock(string_to_lock_id(resource))))


@asynccontextmanager
Expand Down
35 changes: 25 additions & 10 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from dstack._internal.core.models.runs import RunStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.const import GLOBAL_EXPORTS_LOCK_NAMESPACE
from dstack._internal.server.db import get_db
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
ExportModel,
FleetModel,
Expand All @@ -44,7 +44,10 @@
from dstack._internal.server.services.backends import (
get_backend_config_without_creds_from_backend_model,
)
from dstack._internal.server.services.locking import advisory_lock_ctx
from dstack._internal.server.services.locking import (
get_locker,
string_to_lock_id,
)
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.settings import DEFAULT_PROJECT_NAME
from dstack._internal.utils.common import get_current_datetime, run_async
Expand Down Expand Up @@ -629,18 +632,30 @@ async def create_project_model(
is_public=is_public,
templates_repo=templates_repo,
)
session.add(project)
events.emit(
session,
"Project created",
actor=events.UserActor.from_user(owner),
targets=[events.Target.from_model(project)],

if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif is_db_postgres():
await session.execute(
select(safunc.pg_advisory_xact_lock(string_to_lock_id(GLOBAL_EXPORTS_LOCK_NAMESPACE)))
)
global_exports_lock, _ = get_locker(get_db().dialect_name).get_lockset(
GLOBAL_EXPORTS_LOCK_NAMESPACE
)
async with advisory_lock_ctx(session, get_db().dialect_name, GLOBAL_EXPORTS_LOCK_NAMESPACE):

async with global_exports_lock:
res = await session.execute(select(ExportModel.id).where(ExportModel.is_global == True))
for export_id in res.scalars().all():
session.add(ImportModel(project=project, export_id=export_id))
await session.commit() # commit before releasing the lock
session.add(project)
events.emit(
session,
"Project created",
actor=events.UserActor.from_user(owner),
targets=[events.Target.from_model(project)],
)
await session.commit()
return project


Expand Down
Loading