From 51ed8685ee0859240995349fb740232403b69a4b Mon Sep 17 00:00:00 2001 From: jbloom Date: Wed, 6 May 2026 15:19:01 -0700 Subject: [PATCH] fix tip ordering for untyped Altair shorthand axes (0.2.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Charts whose strain axis used positional shorthand without a type suffix — `alt.Y("strain")` instead of `alt.Y("strain:N")` — silently rendered in data order instead of tree-tip order. The field-extraction helper called `to_dict()` on the bare channel, which raises when the shorthand has no explicit type because Altair needs a chart-data context to infer it. The blanket `except Exception` returned `None`, the live-object walker found no match, and the sort override was silently skipped. Replace the to_dict-based reader with one that inspects `_kwds` directly, handles plain-str fields, untyped/typed/aggregate shorthand, and the `FieldName(SchemaBase)` wrapper that `from_dict` introduces; raise on unrecognized shapes instead of swallowing them. Add a structural fail-fast tripwire: count strain-axis encodings on the spec walk and on the live-object walk, and raise on any mismatch. Catches both this bug and any future divergence between the two surfaces. Consolidate the two near-identical live-object walkers (sort application and axis suppression) into a single read-only generator yielding live channel objects; the call site in `_build` applies sort and (conditionally) axis suppression at each yielded channel and counts hits for the tripwire. Removes ~70 lines of duplicated traversal logic. Tests: 11 new unit and end-to-end tests covering typed/untyped/ explicit/aggregate shorthand, the from_dict-roundtripped wrapper, value-only encodings (legitimate None), unparseable shorthand (raises), non-channel objects (raises), tip ordering for untyped shorthand on x and y, and the tripwire firing when the live walker misses a strain encoding the spec walker found. 
--- CHANGELOG.md | 13 +++ pyproject.toml | 2 +- src/tree_annotated_plot/_plot.py | 186 ++++++++++++++++++++----------- tests/test_introspection.py | 69 +++++++++++- tests/test_plot.py | 67 +++++++++++ 5 files changed, 268 insertions(+), 69 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ee23dc..55b3848 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.2.1] - 2026-05-06 + +### Fixed + +- Apply tree-tip ordering to chart axes that use untyped Altair + shorthand (`alt.Y("strain")` without a `:N` / `:O` type suffix). + Previously the sort override was silently skipped, so the chart + rendered in data order instead of tree order. +- An internal consistency check now raises if the spec-level walk + and the live-object walk over the user's chart ever disagree on + the number of strain-axis encodings, so silent skips like the + one above can't recur in another shape. + ## [0.2.0] - 2026-05-06 ### Added diff --git a/pyproject.toml b/pyproject.toml index 2347e0f..1f8bac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tree-annotated-plot" -version = "0.2.0" +version = "0.2.1" description = "Annotate the axis of an Altair / Vega-Lite plot with a phylogenetic tree." 
readme = "README.md" requires-python = ">=3.13" diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index c39d894..ae4ad13 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -7,6 +7,7 @@ import math import re import warnings +from collections.abc import Iterator from html.parser import HTMLParser from pathlib import Path from typing import Any, Literal @@ -223,11 +224,15 @@ def _build( legend_show=config.tree_color_legend_show, ) - new_chart = _apply_tree_order_to_chart_object( - chart, config.chart_strain_field, tip_names - ) - if config.connect_leader_to_label: - _suppress_chart_strain_axis(new_chart, config.chart_strain_field) + new_chart = copy.deepcopy(chart) + suppress_axis_chrome = config.connect_leader_to_label + n_hits = 0 + for ch in _iter_strain_axis_channels(new_chart, config.chart_strain_field): + ch.sort = list(tip_names) + if suppress_axis_chrome: + ch.axis = alt.Axis(labels=False, ticks=False, domain=False, title=None) + n_hits += 1 + _check_walker_hits("strain-axis update", n_hits, len(axis_hits), axis) hoisted_config, hoisted_other = _pop_toplevel_only_attrs(new_chart) combined = _concat_for_location( @@ -916,11 +921,45 @@ def walk(n: dict) -> None: return out -def _apply_tree_order_to_chart_object( - chart: alt.TopLevelMixin, chart_strain_field: str, sort_order: list[str] -) -> alt.TopLevelMixin: - """Return a deepcopy of the user's chart with `sort=sort_order` applied - to every axis encoding that references `chart_strain_field`. +def _check_walker_hits(operation: str, actual: int, expected: int, axis: str) -> None: + """Cross-check the live-object walk's hit count against the spec walk's. + + `_find_strain_encoding` walks the chart's serialized dict form to count + how many `x`/`y` strain encodings the chart has and to validate + consistency (axis agreement, type=nominal/ordinal, etc.). 
The live + iteration in `_build` (driven by `_iter_strain_axis_channels`) must + visit exactly the same number of encodings — fewer would silently skip + applying the tree's tip order (or the axis suppression) to part of the + chart, more would mean we mutated structures the dict walker doesn't + know about. + + The check is symmetric (`!=`) rather than one-sided (`<`) because + either direction signals that spec-level introspection and + live-object traversal have diverged, and continuing would render an + unverified chart. + """ + if actual != expected: + raise RuntimeError( + f"internal consistency check failed for {operation!r}: " + f"_find_strain_encoding located {expected} strain {axis!r}-axis " + f"encoding(s) in the chart spec, but the live-object walk " + f"updated {actual}. Spec-level and live-object traversal have " + "diverged, which would render the chart with a wrong tip order " + "or leave axis chrome behind. Please file a bug at " + "https://github.com/jbloomlab/tree-annotated-plot/issues with a " + "minimal reproducer." + ) + + +def _iter_strain_axis_channels(node: Any, chart_strain_field: str) -> Iterator[Any]: + """Yield every live x/y channel object whose field matches + `chart_strain_field`. + + Pure read — no mutation, no count. The caller iterates and applies + whatever mutation it needs (currently `sort` and, when + `connect_leader_to_label=True`, an axis-suppression `alt.Axis(...)`), + counting hits as it goes so the cross-check in `_check_walker_hits` + can compare against the spec walker. Walks the live altair object tree (Chart / LayerChart / FacetChart / HConcatChart / VConcatChart / ConcatChart) rather than its dict form, @@ -929,62 +968,25 @@ def _apply_tree_order_to_chart_object( dict approach has to fight. Modifying the object in place is robust as long as altair's container attribute names hold (.hconcat / .vconcat / .concat / .layer / .spec / .encoding), which is stable in altair 5+. 
+ `FacetChart.spec` is recursed into unconditionally so we descend to the + inner LayerChart / Chart that actually carries the encoding (gating on + `spec.encoding is not None` would skip LayerCharts, whose encodings + live on their layers rather than at the top level). """ - new_chart = copy.deepcopy(chart) - _walk_and_apply_sort(new_chart, chart_strain_field, sort_order) - return new_chart - - -def _walk_and_apply_sort( - node: Any, chart_strain_field: str, sort_order: list[str] -) -> None: - """Recursively set sort on every encoding whose field == chart_strain_field - on a live altair chart object.""" enc = _live_attr(node, "encoding") if enc is not None: for channel in ("x", "y"): ch = _live_attr(enc, channel) if ch is not None and _channel_field(ch) == chart_strain_field: - ch.sort = list(sort_order) + yield ch for attr in ("hconcat", "vconcat", "concat", "layer"): sub = _live_attr(node, attr) if isinstance(sub, list): for s in sub: - _walk_and_apply_sort(s, chart_strain_field, sort_order) - # FacetChart.spec is the chart being faceted; recurse unconditionally so - # we descend into the inner LayerChart / Chart that actually carries the - # encoding. Gating on `spec.encoding is not None` was wrong because a - # LayerChart has no top-level encoding — its encodings live on its layers. + yield from _iter_strain_axis_channels(s, chart_strain_field) spec = _live_attr(node, "spec") if spec is not None: - _walk_and_apply_sort(spec, chart_strain_field, sort_order) - - -def _suppress_chart_strain_axis( - chart: alt.TopLevelMixin, chart_strain_field: str -) -> None: - """Hide labels, ticks, axis line, and title on every chart strain-axis encoding. - - Walks the live altair object the same way `_walk_and_apply_sort` does - and replaces the matching encoding's ``axis`` with one that suppresses - every visible bit of axis chrome. 
This **overrides** any user-supplied - ``axis=alt.Axis(...)`` on those encodings; that's documented on the - ``connect_leader_to_label`` description in `_config.py`. - """ - enc = _live_attr(chart, "encoding") - if enc is not None: - for channel in ("x", "y"): - ch = _live_attr(enc, channel) - if ch is not None and _channel_field(ch) == chart_strain_field: - ch.axis = alt.Axis(labels=False, ticks=False, domain=False, title=None) - for attr in ("hconcat", "vconcat", "concat", "layer"): - sub = _live_attr(chart, attr) - if isinstance(sub, list): - for s in sub: - _suppress_chart_strain_axis(s, chart_strain_field) - spec = _live_attr(chart, "spec") - if spec is not None: - _suppress_chart_strain_axis(spec, chart_strain_field) + yield from _iter_strain_axis_channels(spec, chart_strain_field) def _live_attr(obj: Any, name: str) -> Any: @@ -1044,21 +1046,71 @@ def _apply_combined_config(combined: alt.HConcatChart, hoisted_config: Any) -> N def _channel_field(ch: Any) -> str | None: """Read the underlying `field` string from an altair channel object. - `ch.field` returns altair's `_PropertySetter` (used for fluent chaining), - not the stored field name. The stored value is reachable via `to_dict()`. + Returns the field name when the channel references a data field (either + via the `field=` keyword or via positional shorthand like + `alt.Y("strain")` / `alt.Y("strain:N")` / `alt.Y("mean(titer):Q")`). + Returns `None` when the channel has no field at all (a `value=` / + `datum=` constant encoding). Raises `ValueError` when the channel's + `_kwds` shape is unrecognized — silent fallthrough hid a real bug + where untyped shorthand axes were never reordered to match the tree. + + Reads `ch._kwds` directly rather than going through `ch.to_dict()`: + altair's `to_dict()` on a bare channel raises when the shorthand has + no explicit type (e.g. `alt.Y("strain")`), because the `nominal` / + `quantitative` inference needs the chart's data context. 
That + exception is what the previous catch-all hid. """ - to_dict = getattr(ch, "to_dict", None) - if not callable(to_dict): - return None - try: - d = to_dict() - except Exception: + kwds = getattr(ch, "_kwds", None) + if not isinstance(kwds, dict): + raise ValueError( + f"channel object {type(ch).__name__} has no `_kwds` mapping; " + "this isn't a recognized altair channel encoding." + ) + field = kwds.get("field") + if field is not None and field is not alt.Undefined: + # `from_dict`-roundtripped channels store the field as a + # `FieldName(SchemaBase)` wrapper rather than a plain str; its + # `to_dict()` returns the raw string, while `str(...)` returns the + # repr `FieldName('x')`. Cover both. + if isinstance(field, str): + if field: + return field + elif hasattr(field, "to_dict"): + unwrapped = field.to_dict() + if isinstance(unwrapped, str) and unwrapped: + return unwrapped + raise ValueError( + f"channel field wrapper {type(field).__name__} unwrapped " + f"to {unwrapped!r}; expected a non-empty string." + ) + else: + raise ValueError( + f"channel field has unexpected value {field!r} " + f"(type {type(field).__name__}); expected a string." + ) + shorthand = kwds.get("shorthand") + if shorthand is None or shorthand is alt.Undefined: return None - if isinstance(d, dict): - f = d.get("field") - if isinstance(f, str): - return f - return None + if not isinstance(shorthand, str) or not shorthand: + raise ValueError( + f"channel shorthand has unexpected value {shorthand!r}; " + "expected a string like 'strain', 'strain:N', or " + "'mean(strain):Q'." + ) + # Shorthand grammar: '[aggregate(]field[)][:type]'. + bare = shorthand.split(":", 1)[0] + if "(" in bare: + if not bare.endswith(")"): + raise ValueError( + f"channel shorthand {shorthand!r} has an unbalanced " + "aggregate wrapper; expected 'aggregate(field)[:type]'." + ) + bare = bare[bare.index("(") + 1 : -1] + if not bare: + raise ValueError( + f"channel shorthand {shorthand!r} parsed to an empty field name." 
+ ) + return bare def _chart_strain_dim( diff --git a/tests/test_introspection.py b/tests/test_introspection.py index 1fab38b..f12f8ce 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -15,7 +15,7 @@ import pandas as pd import pytest -from tree_annotated_plot._plot import _find_strain_encoding +from tree_annotated_plot._plot import _channel_field, _find_strain_encoding DATA_DIR = Path(__file__).resolve().parent.parent / "examples" / "data" @@ -139,3 +139,70 @@ def test_find_strain_encoding_quantitative_type_raises() -> None: spec = _flat_chart_spec(typ="quantitative") with pytest.raises(ValueError, match="type='quantitative'"): _find_strain_encoding(spec, "strain") + + +# ---------- _channel_field: covers the "untyped shorthand" silent-skip bug. + + +def test_channel_field_typed_shorthand() -> None: + """`alt.Y('strain:N')` — the form every existing test uses.""" + ch = alt.Y("strain:N") + assert _channel_field(ch) == "strain" + + +def test_channel_field_untyped_shorthand() -> None: + """`alt.Y('strain')` without a type — the form that triggered the + original silent-skip bug. `to_dict()` on this channel raises because + type inference needs the chart's data, but we can still recover the + field from `_kwds['shorthand']`.""" + ch = alt.Y("strain") + assert _channel_field(ch) == "strain" + + +def test_channel_field_explicit_field_kwarg() -> None: + """`alt.Y(field='strain', type='nominal')` — the explicit form.""" + ch = alt.Y(field="strain", type="nominal") + assert _channel_field(ch) == "strain" + + +def test_channel_field_after_from_dict_roundtrip() -> None: + """After `alt.Chart.from_dict(...)` (used by the CLI / JSON / HTML + chart loaders), the field is stored as a `FieldName(SchemaBase)` + wrapper, not a plain `str`. 
The helper must unwrap it.""" + df = pd.DataFrame({"strain": ["A", "B"], "titer": [1.0, 2.0]}) + chart = alt.Chart(df).mark_circle().encode(x="titer:Q", y=alt.Y("strain:N")) + roundtripped = alt.Chart.from_dict(chart.to_dict()) + assert _channel_field(roundtripped.encoding.y) == "strain" + + +def test_channel_field_aggregate_shorthand() -> None: + """`alt.X('mean(titer):Q')` aggregate-wrapped shorthand.""" + ch = alt.X("mean(titer):Q") + assert _channel_field(ch) == "titer" + + +def test_channel_field_value_only_returns_none() -> None: + """A `value=` constant encoding has no field — legitimate None, + not an error.""" + ch = alt.Y(value=5) + assert _channel_field(ch) is None + + +def test_channel_field_unbalanced_aggregate_raises() -> None: + """A shorthand string we can't parse must raise rather than silently + return None — silent fallthrough is what hid the original bug.""" + ch = alt.Y("strain") + ch._kwds["shorthand"] = "mean(titer:Q" + with pytest.raises(ValueError, match="unbalanced aggregate"): + _channel_field(ch) + + +def test_channel_field_non_kwds_object_raises() -> None: + """A non-altair object accidentally passed in must raise, not return + None — same fail-fast principle.""" + + class NotAChannel: + pass + + with pytest.raises(ValueError, match="no `_kwds` mapping"): + _channel_field(NotAChannel()) diff --git a/tests/test_plot.py b/tests/test_plot.py index 68feacc..c8acb46 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -110,6 +110,73 @@ def test_plot_overrides_y_sort_to_tree_tip_order(): assert user_panel_spec["encoding"]["y"]["sort"] == ["A", "B", "C", "D"] +def test_plot_overrides_sort_for_untyped_shorthand_y(): + """Regression: `alt.Y('strain')` (no `:N` suffix) used to silently + skip the sort override because `_channel_field` called `to_dict()` + on the bare channel, which raises without a chart-data context.""" + chart = ( + alt.Chart(_synthetic_df()) + .mark_line(point=True) + .encode( + x=alt.X("titer:Q", 
scale=alt.Scale(type="log")), + y=alt.Y("strain"), + color="serum:N", + ) + .properties(width=300, height=200) + ) + out = tree_annotated_plot.plot( + _synthetic_auspice(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + user_panel_spec = out.hconcat[1].to_dict() + assert user_panel_spec["encoding"]["y"]["sort"] == ["A", "B", "C", "D"] + + +def test_plot_overrides_sort_for_untyped_shorthand_x(): + """Same regression on the horizontal layout.""" + chart = ( + alt.Chart(_synthetic_df()) + .mark_line(point=True) + .encode( + x=alt.X("strain"), + y=alt.Y("titer:Q", scale=alt.Scale(type="log")), + color="serum:N", + ) + .properties(width=200, height=300) + ) + out = tree_annotated_plot.plot( + _synthetic_auspice(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + # tree_location defaults to "bottom" for x-axis strain → vconcat(user, tree) + user_panel_spec = out.vconcat[0].to_dict() + assert user_panel_spec["encoding"]["x"]["sort"] == ["A", "B", "C", "D"] + + +def test_plot_consistency_tripwire_fires_when_live_walker_misses(monkeypatch): + """If `_channel_field` ever reverts to silently failing, the spec walker + finds a strain encoding the live walker doesn't, and the count-based + consistency check raises rather than letting a wrong-ordered chart + render.""" + from tree_annotated_plot import _plot + + monkeypatch.setattr(_plot, "_channel_field", lambda ch: None) + with pytest.raises(RuntimeError, match="internal consistency check failed"): + tree_annotated_plot.plot( + _synthetic_auspice(), + _synthetic_chart(), + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + + def test_plot_strain_mismatch_raises(): bad = _synthetic_chart() bad.data = bad.data.replace({"strain": {"D": "X"}})