From c7f9941eb83dc4441db145d7cca37fe6f28bb9bc Mon Sep 17 00:00:00 2001 From: Vladislav Date: Wed, 27 May 2026 17:40:49 +0300 Subject: [PATCH 1/6] feat: rewrote the method for getting m2m objects --- tests/test_prefetching.py | 26 +++++++++ tortoise/backends/base/executor.py | 93 +++++++----------------------- 2 files changed, 46 insertions(+), 73 deletions(-) diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py index face651cc..492794fbd 100644 --- a/tests/test_prefetching.py +++ b/tests/test_prefetching.py @@ -154,6 +154,32 @@ async def test_prefetch_m2m_to_attr(db): assert list(event.to_attr_participants_2) == [team_second] +@pytest.mark.asyncio +async def test_prefetch_m2m_annotate(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team) + event = await Event.first().prefetch_related( + Prefetch("participants", Team.annotate(count_events=Count("events"))) + ) + for team in event.participants: + assert team.count_events == 1 + + +@pytest.mark.asyncio +async def test_prefetch_m2m_select_related(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await team.events.add(event) + team = await Team.first().prefetch_related( + Prefetch("events", Event.all().select_related("tournament")) + ) + for event in team.events: + assert event.tournament == tournament + + @pytest.mark.asyncio async def test_prefetch_o2o_to_attr(db): tournament = await Tournament.create(name="tournament") diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 9865b78fc..fbeffaf74 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -7,11 +7,11 @@ from copy import copy from typing import TYPE_CHECKING, Any, cast -from pypika_tortoise import JoinType, Parameter, Table +from pypika_tortoise import Parameter from pypika_tortoise.queries import QueryBuilder from tortoise.exceptions import OperationalError, UnSupportedError -from tortoise.expressions import Expression, ResolveContext +from tortoise.expressions import Expression, RawSQL, ResolveContext from tortoise.fields.base import DatabaseDefault from tortoise.fields.relational import ( BackwardFKRelation, @@ -19,7 +19,6 @@ ManyToManyFieldInstance, RelationalField, ) -from tortoise.query_utils import QueryModifier if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient @@ -602,85 +601,33 @@ async def _prefetch_m2m_relation( field: str, related_query: tuple[str | None, QuerySet], ) -> Iterable[Model]: - to_attr, related_query = related_query - instance_id_set: set = { - instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list - } + to_attr, queryset = related_query field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore - through_table = Table(field_object.through, schema=field_object.through_schema) - - subquery = ( - self.db.query_class.from_(through_table) - .select( - through_table[field_object.backward_key].as_("_backward_relation_key"), - through_table[field_object.forward_key].as_("_forward_relation_key"), - ) - .where(through_table[field_object.backward_key].isin(instance_id_set)) - ) + through = field_object.through + if field_object.through_schema: + through = f"{field_object.through_schema}.{through}" - related_query_table = related_query.model._meta.basetable - related_pk_field = related_query.model._meta.db_pk_column - related_query.resolve_ordering(related_query.model, related_query_table, [], {}) - query = ( - related_query.query.join(subquery) - .on(subquery._forward_relation_key == related_query_table[related_pk_field]) - .select( - subquery._backward_relation_key.as_("_backward_relation_key"), - *[related_query_table[field].as_(field) for field in related_query.fields], - ) - ) + related_objects = await queryset.filter( + **{f"{field_object.related_name}__in": instance_list} + ).annotate(_backward_relation_key=RawSQL(f'"{through}"."{field_object.backward_key}"')) - if related_query._q_objects: - joined_tables: list[Table] = [] - modifier = QueryModifier() - for node in related_query._q_objects: - modifier &= node.resolve( - ResolveContext( - model=related_query.model, - table=related_query_table, - annotations=related_query._annotations, - custom_filters=related_query._custom_filters, - ) - ) - - for join in modifier.joins: - if join[0] not in joined_tables: - query = query.join(join[0], how=JoinType.left_outer).on(join[1]) - joined_tables.append(join[0]) - - if modifier.where_criterion: - query = query.where(modifier.where_criterion) - - if modifier.having_criterion: - query = query.having(modifier.having_criterion) - - _, raw_results = await self.db.execute_query(*query.get_parameterized_sql()) - relations: list[tuple[Any, Any]] = [] - related_object_list: list[Model] = [] - model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk - for e in raw_results: - pk_values: tuple[Any, Any] = ( - model_pk.to_python_value(e["_backward_relation_key"]), - related_pk.to_python_value(e[related_pk_field]), - ) - relations.append(pk_values) - related_object_list.append(related_query.model._init_from_db(**e)) await self.__class__( - model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map - )._execute_prefetch_queries(related_object_list) - related_object_map = {e.pk: e for e in related_object_list} - relation_map: dict[str, list] = {} + model=queryset.model, db=self.db, prefetch_map=queryset._prefetch_map + )._execute_prefetch_queries(related_objects) - for object_id, related_object_id in relations: - if object_id not in relation_map: - relation_map[object_id] = [] - relation_map[object_id].append(related_object_map[related_object_id]) + model_pk = self.model._meta.pk + relation_map: dict = {} + for obj in related_objects: + bk = model_pk.to_python_value(obj._backward_relation_key) + relation_map.setdefault(bk, []).append(obj) + del obj._backward_relation_key for instance in instance_list: - relation_container = getattr(instance, field) - relation_container._set_result_for_query(relation_map.get(instance.pk, []), to_attr) + getattr(instance, field)._set_result_for_query( + relation_map.get(instance.pk, []), to_attr + ) return instance_list async def _prefetch_direct_relation( From f16c511c213701a1a343caa6c190536397deea1f Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 28 May 2026 01:05:18 +0300 Subject: [PATCH 2/6] fixed _prefetch_m2m_relation. --- tortoise/backends/base/executor.py | 54 ++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index fbeffaf74..8c86ec3a6 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -7,11 +7,11 @@ from copy import copy from typing import TYPE_CHECKING, Any, cast -from pypika_tortoise import Parameter -from pypika_tortoise.queries import QueryBuilder +from pypika_tortoise.queries import QueryBuilder, Table +from pypika_tortoise.terms import Parameter from tortoise.exceptions import OperationalError, UnSupportedError -from tortoise.expressions import Expression, RawSQL, ResolveContext +from tortoise.expressions import Expression, ResolveContext from tortoise.fields.base import DatabaseDefault from tortoise.fields.relational import ( BackwardFKRelation, @@ -605,24 +605,44 @@ async def _prefetch_m2m_relation( field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore - through = field_object.through - if field_object.through_schema: - through = f"{field_object.through_schema}.{through}" + model_pk = self.model._meta.pk + instance_pks = [model_pk.to_db_value(instance.pk, instance) for instance in instance_list] related_objects = await queryset.filter( - **{f"{field_object.related_name}__in": instance_list} - ).annotate(_backward_relation_key=RawSQL(f'"{through}"."{field_object.backward_key}"')) - - await self.__class__( - model=queryset.model, db=self.db, prefetch_map=queryset._prefetch_map - )._execute_prefetch_queries(related_objects) + **{f"{field_object.related_name}__in": instance_pks} + ) - model_pk = self.model._meta.pk relation_map: dict = {} - for obj in related_objects: - bk = model_pk.to_python_value(obj._backward_relation_key) - relation_map.setdefault(bk, []).append(obj) - del obj._backward_relation_key + if related_objects: + related_pk_map: dict = {obj.pk: obj for obj in related_objects} + related_model_pk = queryset.model._meta.pk + related_pks = [related_model_pk.to_db_value(pk, None) for pk in related_pk_map] + through_table = Table(field_object.through, schema=field_object.through_schema) + backward_field = through_table[field_object.backward_key] + forward_field = through_table[field_object.forward_key] + + _, (_, through_rows) = await asyncio.gather( + self.__class__( + model=queryset.model, db=self.db, prefetch_map=queryset._prefetch_map + )._execute_prefetch_queries(related_objects), + self.db.execute_query( + *( + self.db.query_class.from_(through_table) + .select(backward_field, forward_field) + .where(backward_field.isin(instance_pks)) + .where(forward_field.isin(related_pks)) + .get_parameterized_sql() + ) + ), + ) + + for row in through_rows: + backward_key_value = model_pk.to_python_value(row[field_object.backward_key]) + related_object = related_pk_map.get( + related_model_pk.to_python_value(row[field_object.forward_key]) + ) + if related_object is not None: + relation_map.setdefault(backward_key_value, []).append(related_object) for instance in instance_list: getattr(instance, field)._set_result_for_query( From bbafb5ba3197788ddc3c364a545b5b82f21bfcb6 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 28 May 2026 21:40:51 +0300 Subject: [PATCH 3/6] Fix order_by, only, double prefetch in _prefetch_m2m_relation. --- tests/test_prefetching.py | 29 +++++++++++++++++ tortoise/backends/base/executor.py | 50 +++++++++++++++--------------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py index 492794fbd..85f39413e 100644 --- a/tests/test_prefetching.py +++ b/tests/test_prefetching.py @@ -180,6 +180,35 @@ async def test_prefetch_m2m_select_related(db): assert event.tournament == tournament +@pytest.mark.asyncio +async def test_prefetch_m2m_order_by(db): + tournament = await Tournament.create(name="tournament") + team_1 = await Team.create(name="1") + team_2 = await Team.create(name="2") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team_1, team_2) + event_1 = await Event.first().prefetch_related( + Prefetch("participants", Team.all().order_by("name")) + ) + event_2 = await Event.first().prefetch_related( + Prefetch("participants", Team.all().order_by("-name")) + ) + assert [team.name for team in event_1.participants] == ["1", "2"] + assert [team.name for team in event_2.participants] == ["2", "1"] + + +@pytest.mark.asyncio +async def test_prefetch_m2m_only(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await team.events.add(event) + team = await Team.first().prefetch_related(Prefetch("events", Event.all().only("name"))) + assert len(team.events) == 1 + for event in team.events: + assert bool(event.pk) + + @pytest.mark.asyncio async def test_prefetch_o2o_to_attr(db): tournament = await Tournament.create(name="tournament") diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 8c86ec3a6..092381f4c 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -608,41 +608,41 @@ async def _prefetch_m2m_relation( model_pk = self.model._meta.pk instance_pks = [model_pk.to_db_value(instance.pk, instance) for instance in instance_list] - related_objects = await queryset.filter( - **{f"{field_object.related_name}__in": instance_pks} - ) + related_model_pk = queryset.model._meta.pk + model_field_name_pk = related_model_pk.model_field_name + fields_for_select = queryset._fields_for_select + if fields_for_select and model_field_name_pk not in fields_for_select: + queryset = queryset.only(*queryset._fields_for_select, model_field_name_pk) relation_map: dict = {} - if related_objects: - related_pk_map: dict = {obj.pk: obj for obj in related_objects} - related_model_pk = queryset.model._meta.pk - related_pks = [related_model_pk.to_db_value(pk, None) for pk in related_pk_map] + related_objects_by_pks = { + obj.pk: obj + for obj in await queryset.filter(**{f"{field_object.related_name}__in": instance_pks}) + } + if related_objects_by_pks: through_table = Table(field_object.through, schema=field_object.through_schema) backward_field = through_table[field_object.backward_key] forward_field = through_table[field_object.forward_key] - _, (_, through_rows) = await asyncio.gather( - self.__class__( - model=queryset.model, db=self.db, prefetch_map=queryset._prefetch_map - )._execute_prefetch_queries(related_objects), - self.db.execute_query( - *( - self.db.query_class.from_(through_table) - .select(backward_field, forward_field) - .where(backward_field.isin(instance_pks)) - .where(forward_field.isin(related_pks)) - .get_parameterized_sql() - ) - ), + _, through_rows = await self.db.execute_query( + *( + self.db.query_class.from_(through_table) + .select(backward_field, forward_field) + .where(backward_field.isin(instance_pks)) + .where(forward_field.isin(tuple(related_objects_by_pks))) + .get_parameterized_sql() + ) ) + reverse_map: dict = {} for row in through_rows: + forward_key_value = related_model_pk.to_python_value(row[field_object.forward_key]) backward_key_value = model_pk.to_python_value(row[field_object.backward_key]) - related_object = related_pk_map.get( - related_model_pk.to_python_value(row[field_object.forward_key]) - ) - if related_object is not None: - relation_map.setdefault(backward_key_value, []).append(related_object) + reverse_map.setdefault(forward_key_value, []).append(backward_key_value) + + for related_object in related_objects_by_pks.values(): + for instance_pk in reverse_map.get(related_object.pk, []): + relation_map.setdefault(instance_pk, []).append(related_object) for instance in instance_list: getattr(instance, field)._set_result_for_query( From 774c08158fbe7f8ca85211e004af8e35582599d8 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 29 May 2026 18:15:17 +0300 Subject: [PATCH 4/6] Fixed the _prefetch_m2m_relation method for pk equal to UUID in sqlite. --- tortoise/backends/base/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 092381f4c..f19e4caf9 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -616,7 +616,7 @@ async def _prefetch_m2m_relation( relation_map: dict = {} related_objects_by_pks = { - obj.pk: obj + related_model_pk.to_db_value(obj.pk, obj): obj for obj in await queryset.filter(**{f"{field_object.related_name}__in": instance_pks}) } if related_objects_by_pks: From b4c690bbf2fc0eb8959b07b69272166c6f38c929 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sat, 30 May 2026 12:24:24 +0300 Subject: [PATCH 5/6] Add test prefetch_related by UUID. --- tests/fields/test_m2m_uuid.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/fields/test_m2m_uuid.py b/tests/fields/test_m2m_uuid.py index 7648750b2..f1c41b891 100644 --- a/tests/fields/test_m2m_uuid.py +++ b/tests/fields/test_m2m_uuid.py @@ -163,3 +163,15 @@ async def test__add_uninstantiated(db, m2m_uuid_models): two = await UUIDM2MRelatedModel.create() with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"): await two.models.add(one) + + +@pytest.mark.asyncio +async def test_prefetch_related(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two = await UUIDM2MRelatedModel.create() + await one.peers.add(two) + + fetched = await UUIDPkModel.get(pk=one.pk).prefetch_related("peers") + + assert list(fetched.peers) == [two] From b7fa7db554c4da999a4a18b32ea5d8502cf2f423 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Mon, 1 Jun 2026 11:08:32 +0300 Subject: [PATCH 6/6] Corrected the names of variables in the _prefetch_m2m_relation method. --- tortoise/backends/base/executor.py | 37 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index f19e4caf9..f84fd0e34 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -601,23 +601,29 @@ async def _prefetch_m2m_relation( field: str, related_query: tuple[str | None, QuerySet], ) -> Iterable[Model]: - to_attr, queryset = related_query + to_attr, related_queryset = related_query field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore - model_pk = self.model._meta.pk - instance_pks = [model_pk.to_db_value(instance.pk, instance) for instance in instance_list] + pk_field = self.model._meta.pk + instance_id_set: set = { + pk_field.to_db_value(instance.pk, instance) for instance in instance_list + } - related_model_pk = queryset.model._meta.pk - model_field_name_pk = related_model_pk.model_field_name - fields_for_select = queryset._fields_for_select - if fields_for_select and model_field_name_pk not in fields_for_select: - queryset = queryset.only(*queryset._fields_for_select, model_field_name_pk) + related_pk_field = related_queryset.model._meta.pk + related_pk_field_name = related_pk_field.model_field_name + fields_for_select = related_queryset._fields_for_select + if fields_for_select and related_pk_field_name not in fields_for_select: + related_queryset = related_queryset.only( + *related_queryset._fields_for_select, related_pk_field_name + ) relation_map: dict = {} related_objects_by_pks = { - related_model_pk.to_db_value(obj.pk, obj): obj - for obj in await queryset.filter(**{f"{field_object.related_name}__in": instance_pks}) + related_pk_field.to_db_value(obj.pk, obj): obj + for obj in await related_queryset.filter( + **{f"{field_object.related_name}__in": instance_id_set} + ) } if related_objects_by_pks: through_table = Table(field_object.through, schema=field_object.through_schema) @@ -628,7 +634,7 @@ async def _prefetch_m2m_relation( *( self.db.query_class.from_(through_table) .select(backward_field, forward_field) - .where(backward_field.isin(instance_pks)) + .where(backward_field.isin(instance_id_set)) .where(forward_field.isin(tuple(related_objects_by_pks))) .get_parameterized_sql() ) @@ -636,8 +642,8 @@ async def _prefetch_m2m_relation( reverse_map: dict = {} for row in through_rows: - forward_key_value = related_model_pk.to_python_value(row[field_object.forward_key]) - backward_key_value = model_pk.to_python_value(row[field_object.backward_key]) + forward_key_value = related_pk_field.to_python_value(row[field_object.forward_key]) + backward_key_value = pk_field.to_python_value(row[field_object.backward_key]) reverse_map.setdefault(forward_key_value, []).append(backward_key_value) for related_object in related_objects_by_pks.values(): @@ -645,9 +651,8 @@ async def _prefetch_m2m_relation( relation_map.setdefault(instance_pk, []).append(related_object) for instance in instance_list: - getattr(instance, field)._set_result_for_query( - relation_map.get(instance.pk, []), to_attr - ) + relation_container = getattr(instance, field) + relation_container._set_result_for_query(relation_map.get(instance.pk, []), to_attr) return instance_list async def _prefetch_direct_relation(