diff --git a/CHANGELOG.rst b/CHANGELOG.rst index be14099cd..8346d83ec 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,6 +25,7 @@ Fixed - ``MigrationRecorder`` no longer emits tortoise's own ``pk`` field ``DeprecationWarning`` when applying migrations; it now builds its bookkeeping model with ``primary_key=True``. (#2203) - ``QuerySet.count()`` now matches the limited query result for the LIMIT/OFFSET edge cases: it returns ``0`` (instead of a negative number) when ``offset()`` exceeds the total row count, and ``0`` (instead of the total) for ``limit(0)``. (#2208) - Field declarations on models now resolve to their concrete type (e.g. ``CharField[str]``) in Pyright/Pylance instead of ``Field[Unknown]``; the ``Field.__new__`` type-check stub now returns ``Self``. (#2216) +- ``select_related()`` joins are now preserved by ``.values()`` and ``.values_list()``, so an annotation, filter or ordering that references the related table no longer raises a "no such column" error. (#2004) 1.1.7 ----- diff --git a/tests/test_values.py b/tests/test_values.py index 6a5ffeccf..866b13553 100644 --- a/tests/test_values.py +++ b/tests/test_values.py @@ -5,7 +5,7 @@ from tortoise.contrib import test from tortoise.contrib.test.condition import In, NotEQ from tortoise.exceptions import FieldError -from tortoise.expressions import Case, Function, Q, When +from tortoise.expressions import Case, Function, Q, RawSQL, When from tortoise.functions import Length, Trim @@ -288,3 +288,42 @@ async def test_order_by_annotation_not_in_values_list(db): .values_list("name") ) assert tournaments == [("1",), ("2",), ("3",)] + + +@pytest.mark.asyncio +async def test_select_related_join_preserved_in_values(db): + # Regression for #2004: a select_related() join must survive .values(), so an + # annotation/ordering that references the joined table still resolves instead + # of raising "no such column". + t_b = await Tournament.create(name="b") + t_a = await Tournament.create(name="a") + e1 = await Event.create(name="e1", tournament=t_b) + e2 = await Event.create(name="e2", tournament=t_a) + + events = ( + await Event.all() + .select_related("tournament") + .annotate(tournament_name=RawSQL("event__tournament.name")) + .order_by("tournament_name") + .values("event_id") + ) + # ordered by the related tournament name: "a" (e2) before "b" (e1) + assert [e["event_id"] for e in events] == [e2.event_id, e1.event_id] + + +@pytest.mark.asyncio +async def test_select_related_join_preserved_in_values_list(db): + # Regression for #2004 (the .values_list() path). + t_b = await Tournament.create(name="b") + t_a = await Tournament.create(name="a") + e1 = await Event.create(name="e1", tournament=t_b) + e2 = await Event.create(name="e2", tournament=t_a) + + events = ( + await Event.all() + .select_related("tournament") + .annotate(tournament_name=RawSQL("event__tournament.name")) + .order_by("tournament_name") + .values_list("event_id", flat=True) + ) + assert events == [e2.event_id, e1.event_id] diff --git a/tortoise/queryset.py b/tortoise/queryset.py index fe7d4a3f7..64f78931d 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -179,6 +179,17 @@ def _join_table(self, table_criterio_tuple: TableCriterionTuple) -> None: ) self._joined_tables.append(table_criterio_tuple[0]) + def _join_select_related_tables(self, lookup_expression: str) -> None: + """Add the JOINs requested by a ``select_related`` lookup without selecting + the related columns. Value queries (``.values()`` / ``.values_list()``) only + need the joins so that orderings, filters and annotations referencing the + related table resolve to a real alias instead of an unknown one (#2004). + """ + table = self.model._meta.basetable + for field in expand_lookup_expression(self.model, lookup_expression): + field = cast(RelationalField, field) + table = self._join_table_by_field(table, field.model_field_name, field) + @staticmethod def _resolve_ordering_string(ordering: str, reverse: bool = False) -> tuple[str, Order]: order_type = Order.asc @@ -698,6 +709,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + select_related=self._select_related, ) def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: @@ -753,6 +765,7 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + select_related=self._select_related, ) def delete(self) -> DeleteQuery: @@ -1709,6 +1722,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): "_force_indexes", "_use_indexes", "_fields_to_select_sql", + "_select_related", ) def __init__( @@ -1729,6 +1743,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + select_related: set[str], ) -> None: super().__init__(model, annotations) if flat and (len(fields_for_select_list) != 1): @@ -1750,6 +1765,7 @@ def __init__( self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._select_related = select_related self._fields_to_select_sql = { *self._fields_for_select_list, *(key for key, value in self.fields.items() if value in self._fields_for_select_list), @@ -1770,6 +1786,8 @@ def _make_query(self) -> None: fields_for_select=self._fields_for_select_list, ) self.resolve_filters(self._fields_to_select_sql) + for select_related in self._select_related: + self._join_select_related_tables(select_related) if self._limit: self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: @@ -1842,6 +1860,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): "_group_bys", "_force_indexes", "_use_indexes", + "_select_related", ) def __init__( @@ -1861,6 +1880,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + select_related: set[str], ) -> None: super().__init__(model, annotations) self._fields_for_select = fields_for_select @@ -1876,6 +1896,7 @@ def __init__( self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._select_related = select_related def _make_query(self) -> None: self._joined_tables = [] @@ -1892,6 +1913,8 @@ def _make_query(self) -> None: fields_for_select=self._fields_for_select.keys(), ) self.resolve_filters() + for select_related in self._select_related: + self._join_select_related_tables(select_related) # remove annotations that are not in fields_for_select self.query._selects = [