diff --git a/tests/test_distinct.py b/tests/test_distinct.py new file mode 100644 index 000000000..47488c555 --- /dev/null +++ b/tests/test_distinct.py @@ -0,0 +1,264 @@ +import pytest + +from tests.testmodels import Author, Book, DefaultOrdered, SourceFieldPk, Tournament +from tortoise.contrib import test +from tortoise.contrib.test.condition import NotIn +from tortoise.exceptions import OperationalError +from tortoise.functions import Count + +# --------------------------------------------------------------------------- +# Basic DISTINCT (all databases) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_distinct_no_args(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + tournaments = await Tournament.all().distinct() + assert len(tournaments) == 2 + + +# --------------------------------------------------------------------------- +# DISTINCT ON (PostgreSQL only) +# --------------------------------------------------------------------------- + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_single_field(db): + tournament_1 = await Tournament.create(name="1", desc="1") + await Tournament.create(name="1", desc="2") + await Tournament.create(name="1", desc="3") + + tournaments = await Tournament.all().distinct("name") + assert tournaments == [tournament_1] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_single_field_with_order_by(db): + await Tournament.create(name="1", desc="1") + await Tournament.create(name="1", desc="2") + tournament_3 = await Tournament.create(name="1", desc="3") + + tournaments = await Tournament.all().distinct("name").order_by("name", "-desc") + assert tournaments == [tournament_3] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_multiple_fields(db): + tournament_1 = await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="a") + tournament_3 = await Tournament.create(name="2", desc="b") + + tournaments = await Tournament.all().distinct("name", "desc").order_by("name", "desc") + assert tournaments == [tournament_1, tournament_3] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_single_field(db): + """values_list selects one field, same as DISTINCT ON field.""" + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("name", flat=True) + assert tournaments == ["1", "2"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_multiple_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("name", "desc") + assert tournaments == [("1", "a"), ("2", "c")] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_extra_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("desc", flat=True) + assert tournaments == ["a", "c"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_extra_field_respects_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = ( + await Tournament.all() + .distinct("name") + .order_by("name", "-desc") + .values_list("desc", flat=True) + ) + assert tournaments == ["b", "c"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_single_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("name") + assert tournaments == [{"name": "1"}, {"name": "2"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_multiple_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("name", "desc") + assert tournaments == [{"name": "1", "desc": "a"}, {"name": "2", "desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_extra_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("desc") + assert tournaments == [{"desc": "a"}, {"desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_extra_field_respects_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").order_by("name", "-desc").values("desc") + assert tournaments == [{"desc": "b"}, {"desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_same_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").only("name") + assert [t.name for t in tournaments] == ["1", "2"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_extra_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").only("name", "desc") + assert [(t.name, t.desc) for t in tournaments] == [("1", "a"), ("2", "c")] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_with_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = ( + await Tournament.all().distinct("name").order_by("name", "-desc").only("name", "desc") + ) + assert [(t.name, t.desc) for t in tournaments] == [("1", "b"), ("2", "c")] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_filter_by_model(db): + tournament_1 = await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + tournament_2 = await Tournament.create(name="2", desc="c") + tournaments = await Tournament.filter(name__in=["1", "2"]).distinct("name") + assert [tournament_1, tournament_2] == tournaments + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_source_field(db): + await SourceFieldPk.create(name="1") + await SourceFieldPk.create(name="2") + await SourceFieldPk.all().distinct("id") + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_annotate_by_model(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + tournaments = ( + await Tournament.annotate(count_name=Count("name")).distinct("name").order_by("name") + ) + assert [1, 1] == [tournaments[0].count_name, tournaments[0].count_name] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_default_ordered(db): + await DefaultOrdered.create(one="1", second=1) + await DefaultOrdered.create(one="2", second=2) + await DefaultOrdered.all().distinct("one") + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_by_relation(db): + author_1 = await Author.create(name="1") + author_2 = await Author.create(name="1") + await Book.create(name="1", rating=1, subject="1", author=author_1) + await Book.create(name="2", rating=2, subject="2", author=author_2) + books = await Book.all().distinct("author__name") + assert len(books) == 1 + + +# --------------------------------------------------------------------------- +# DISTINCT ON validation errors +# --------------------------------------------------------------------------- + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_invalid_order_by(db): + await Tournament.create(name="1") + with pytest.raises(OperationalError): + await Tournament.all().distinct("name").order_by("desc") + + +@test.requireCapability(dialect=NotIn("postgres")) +@pytest.mark.asyncio +async def test_distinct_on_not_supported_outside_postgres(db): + with pytest.raises(OperationalError): + await Tournament.all().distinct("name") + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_invalid_default_ordered(db): + await DefaultOrdered.create(one="1", second=1) + await DefaultOrdered.create(one="2", second=2) + with pytest.raises(OperationalError): + await DefaultOrdered.all().distinct("second") diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 9865b78fc..10db19140 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -738,9 +738,10 @@ def _make_prefetch_queries(self) -> None: relation_field = self.model._meta.fields_map[field_name] related_model: type[Model] = relation_field.related_model # type: ignore related_query = related_model.all().using_db(self.db) - related_query.query = copy(related_query.model._meta.basequery) # type:ignore[assignment] if forwarded_prefetches: related_query = related_query.prefetch_related(*forwarded_prefetches) + if field_name not in self._prefetch_queries: + related_query.query = copy(related_query.model._meta.basequery) self._prefetch_queries.setdefault(field_name, []).append((to_attr, related_query)) async def _do_prefetch( diff --git a/tortoise/queryset.py b/tortoise/queryset.py index aecffc42d..2477c5c79 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -8,6 +8,7 @@ from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count +from pypika_tortoise.dialects import PostgreSQLQueryBuilder from pypika_tortoise.functions import Cast from pypika_tortoise.queries import QueryBuilder, _SetOperation from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper @@ -42,6 +43,7 @@ # Empty placeholder - Should never be edited. QUERY: QueryBuilder = QueryBuilder() +POSTGRES_QUERY: PostgreSQLQueryBuilder = PostgreSQLQueryBuilder() if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model @@ -84,6 +86,7 @@ def values( class _ChooseDBMixin(Generic[MODEL]): _db: BaseDBAsyncClient | None model: type[MODEL] + query: QueryBuilder | PostgreSQLQueryBuilder def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: """ @@ -99,9 +102,23 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db + def _apply_db(self, db: BaseDBAsyncClient | None) -> None: + """ + Set the database connection for this query and update the query builder dialect. + + Assigns ``db`` to ``_db`` and, when the connection targets PostgreSQL, + replaces the default ``query`` placeholder with ``POSTGRES_QUERY`` so + that subsequent query-building calls produce PostgreSQL-specific SQL. + + :param db: The database connection to use for this query. + """ + self._db = db + if db is not None and hasattr(self, "query") and db.capabilities.dialect == "postgres": + self.query = POSTGRES_QUERY + def _choose_db_if_not_chosen(self, for_write: bool = False) -> None: if self._db is None: - self._db = self._choose_db(for_write) + self._apply_db(self._choose_db(for_write)) class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): @@ -119,7 +136,7 @@ class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): def __init__(self, model: type[MODEL]) -> None: self._joined_tables: list[Table] = [] self.model: type[MODEL] = model - self.query: QueryBuilder = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore self._capabilities: Capabilities | None = None self._annotations: dict[str, Expression | Term] = {} @@ -274,6 +291,64 @@ def resolve_ordering( self.query = self.query.orderby(field, order=ordering[1]) + def resolve_distinct( + self, + distinct: bool, + distinct_on: list[str], + orderings: Iterable[tuple[str, str | Order]], + annotations: dict[str, Term | Expression], + ) -> None: + self.query._distinct = distinct + if isinstance(self.query, PostgreSQLQueryBuilder): + self.query._distinct_on = [] + if not distinct: + return + if not orderings and self.model._meta.ordering and not annotations: + orderings = self.model._meta.ordering + if distinct_on: + if not isinstance(self.query, PostgreSQLQueryBuilder): + raise OperationalError("DISTINCT ON is only supported by PostgreSQL") + ordering_fields = [ordering[0] for ordering in orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(distinct_on): + if ordering_fields and (i >= len_ordering_fields or ordering_fields[i] != field): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {distinct_on!r}." + ) + distinct_on_by_source_field = [] + for field_name in distinct_on: + field_object = self.model._meta.fields_map.get(field_name) + part_after = field_name + related_table = self.model._meta.basetable + related_model: type[Model] = self.model + while part_after: + related_field_name, __, part_after = part_after.partition("__") + if related_field_name in related_model._meta.fetch_fields: + related_field = cast( + RelationalField, self.model._meta.fields_map[related_field_name] + ) + related_table = self._join_table_by_field( + related_table, related_field_name, related_field + ) + related_model = related_field.model + else: + field_object = related_model._meta.fields_map.get(related_field_name) + + if not field_object: + raise FieldError( + f"Unknown field {related_field_name} for model {related_model.__name__}" + ) + related_table_field = related_table[ + field_object.source_field or related_field_name + ] + if func := field_object.get_for_dialect( + related_model._meta.db.capabilities.dialect, "function_cast" + ): + related_table_field = func(field_object, related_table_field) + distinct_on_by_source_field.append(related_table_field) + self.query.distinct_on(*distinct_on_by_source_field) + def _resolve_annotate(self, fields_for_select: Collection[str] | None = None) -> bool: if not self._annotations: return False @@ -362,6 +437,7 @@ def __init__(self, model: type[MODEL]) -> None: self._filter_kwargs: dict[str, Any] = {} self._orderings: list[tuple[str, Any]] = [] self._distinct: bool = False + self._distinct_on: list[str] = [] self._having: dict[str, Any] = {} self._fields_for_select: tuple[str, ...] = () self._group_bys: tuple[str, ...] = () @@ -387,7 +463,7 @@ def _clone(self) -> QuerySet[MODEL]: queryset._prefetch_queries = copy(self._prefetch_queries) queryset._single = self._single queryset._raise_does_not_exist = self._raise_does_not_exist - queryset._db = self._db + queryset._apply_db(self._db) queryset._limit = self._limit queryset._offset = self._offset queryset._fields_for_select = self._fields_for_select @@ -396,6 +472,7 @@ def _clone(self) -> QuerySet[MODEL]: queryset._joined_tables = copy(self._joined_tables) queryset._q_objects = copy(self._q_objects) queryset._distinct = self._distinct + queryset._distinct_on = copy(self._distinct_on) queryset._annotations = copy(self._annotations) queryset._having = copy(self._having) queryset._custom_filters = copy(self._custom_filters) @@ -580,15 +657,38 @@ def __getitem__(self, key: slice) -> QuerySet[MODEL]: queryset = queryset.limit(key.stop - start) return queryset - def distinct(self) -> QuerySet[MODEL]: + def distinct(self, *args: str) -> QuerySet[MODEL]: """ - Make QuerySet distinct. + Make QuerySet return distinct results. + + Without arguments, adds a plain ``DISTINCT`` to the query, which works on all databases + and is most useful with ``.values()`` or ``.values_list()``. + + With arguments (PostgreSQL only), generates ``DISTINCT ON (fields)`` which keeps one row + per unique combination of the given fields. ``ORDER BY`` is optional, but if specified + it must begin with the same fields in the same order as ``DISTINCT ON`` — otherwise an + :exc:`~tortoise.exceptions.OperationalError` is raised. + + Can be combined with ``.only()``, ``.values()``, and ``.values_list()`` — fields not + present in ``DISTINCT ON`` are taken from the row selected by the ordering. - Only makes sense in combination with a ``.values()`` or ``.values_list()`` as it - precedes all the fetched fields with a distinct. + .. code-block:: python3 + + # Plain DISTINCT — all databases + await Tournament.all().distinct().values("name") + + # DISTINCT ON without ORDER BY — PostgreSQL only + await Tournament.all().distinct("name") + + # DISTINCT ON with ORDER BY — ORDER BY must start with DISTINCT ON fields + await Tournament.all().distinct("name").order_by("name", "-desc") + + :param args: Field names for ``DISTINCT ON`` (PostgreSQL only). Omit for plain + ``DISTINCT``. """ queryset = self._clone() queryset._distinct = True + queryset._distinct_on = list(args) return queryset def union(self, *other_qs: QuerySet[Model], all: bool = False) -> UnionQuery[MODEL]: @@ -698,6 +798,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: @@ -753,6 +854,7 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def delete(self) -> DeleteQuery: @@ -1115,7 +1217,7 @@ def using_db(self, _db: BaseDBAsyncClient | None) -> QuerySet[MODEL]: Useful for transactions workaround. """ queryset = self._clone() - queryset._db = _db if _db else queryset._db + queryset._apply_db(_db if _db else queryset._db) return queryset def _join_select_related(self, lookup_expression: str) -> tuple[type[Model], Table]: @@ -1264,12 +1366,16 @@ def _make_query(self) -> None: self._fields_for_select, ) self.resolve_filters() + self.resolve_distinct( + self._distinct, + self._distinct_on, + self._orderings, + self._annotations, + ) if self._limit is not None: self.query._limit = self.query._wrapper_cls(self._limit) if self._offset is not None: self.query._offset = self.query._wrapper_cls(self._offset) - if self._distinct: - self.query._distinct = True if self._select_for_update: self.query = self.query.for_update( self._select_for_update_nowait, @@ -1288,8 +1394,7 @@ def _make_query(self) -> None: self.query = self.query.use_index(*self._use_indexes) def __await__(self) -> Generator[Any, None, list[MODEL]]: - if self._db is None: - self._db = self._choose_db(self._select_for_update) # type: ignore + self._choose_db_if_not_chosen(self._select_for_update) self._make_query() return self._execute().__await__() @@ -1343,7 +1448,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1422,7 +1527,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1467,7 +1572,7 @@ def __init__( ) -> None: super().__init__(model) self._q_objects = q_objects - self._db = db + self._apply_db(db) self._annotations = annotations self._custom_filters = custom_filters self._force_indexes = force_indexes @@ -1540,7 +1645,7 @@ def __init__( self._custom_filters = custom_filters self._limit = limit self._offset = offset or 0 - self._db = db + self._apply_db(db) self._force_indexes = force_indexes self._use_indexes = use_indexes @@ -1595,8 +1700,7 @@ def _join_table_with_forwarded_fields( if field in self.model._meta.fetch_fields and not forwarded_fields: raise ValueError( - f'Selecting relation "{field}" is not possible, select concrete ' - "field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_object = cast(RelationalField, model._meta.fields_map.get(field)) @@ -1627,8 +1731,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: if field in self.model._meta.fetch_fields: raise ValueError( - f'Selecting relation "{field}" is not possible, select ' - "concrete field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_, __, forwarded_fields = field.partition("__") @@ -1704,6 +1807,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): "_force_indexes", "_use_indexes", "_fields_to_select_sql", + "_distinct_on", ) def __init__( @@ -1724,6 +1828,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) if flat and (len(fields_for_select_list) != 1): @@ -1741,10 +1846,11 @@ def __init__( self._raise_does_not_exist = raise_does_not_exist self._fields_for_select_list = fields_for_select_list self._flat = flat - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on self._fields_to_select_sql = { *self._fields_for_select_list, *(key for key, value in self.fields.items() if value in self._fields_for_select_list), @@ -1765,12 +1871,16 @@ def _make_query(self) -> None: fields_for_select=self._fields_for_select_list, ) self.resolve_filters(self._fields_to_select_sql) + self.resolve_distinct( + self._distinct, + self._distinct_on, + self._orderings, + self._annotations, + ) if self._limit: self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: self.query._offset = self.query._wrapper_cls(self._offset) - if self._distinct: - self.query._distinct = True if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1837,6 +1947,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): "_group_bys", "_force_indexes", "_use_indexes", + "_distinct_on", ) def __init__( @@ -1856,6 +1967,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) self._fields_for_select = fields_for_select @@ -1867,10 +1979,11 @@ def __init__( self._q_objects = q_objects self._single = single self._raise_does_not_exist = raise_does_not_exist - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on def _make_query(self) -> None: self._joined_tables = [] @@ -1887,6 +2000,12 @@ def _make_query(self) -> None: fields_for_select=self._fields_for_select.keys(), ) self.resolve_filters() + self.resolve_distinct( + self._distinct, + self._distinct_on, + self._orderings, + self._annotations, + ) # remove annotations that are not in fields_for_select self.query._selects = [ @@ -1897,8 +2016,6 @@ def _make_query(self) -> None: self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: self.query._offset = self.query._wrapper_cls(self._offset) - if self._distinct: - self.query._distinct = True if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1963,7 +2080,7 @@ class RawSQLQuery(AwaitableQuery): def __init__(self, model: type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: super().__init__(model) self._sql = sql - self._db = db + self._apply_db(db) async def _execute(self) -> Any: instance_list = await self._db.executor_class( @@ -2094,7 +2211,7 @@ def __init__( self._objects = objects self._ignore_conflicts = ignore_conflicts self._batch_size = batch_size - self._db = db + self._apply_db(db) self._update_fields = update_fields self._on_conflict = on_conflict @@ -2298,7 +2415,7 @@ def __init__( ) -> None: super().__init__(model) self._union_query = union_query - self._db = db + self._apply_db(db) def _make_query(self) -> None: self._union_query._make_query() @@ -2344,11 +2461,11 @@ def __init__( all: bool = False, ): self.model = model - self.query = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._models: set[type[Model]] = {model, *(qs.model for qs in querysets)} self._union_query: QueryBuilder | _SetOperation | None = None self._selects: list[str] = [] - self._db = db + self._apply_db(db) self._qs = querysets self._all = all self._orderings: list[tuple[str, Order]] | None = None @@ -2429,7 +2546,7 @@ def _clone(self) -> UnionQuery[MODEL]: union._models = self._models union._union_query = None union._selects = self._selects - union._db = self._db + union._apply_db(self._db) union._qs = self._qs union._all = self._all union._orderings = self._orderings