Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.2.1] - 2026-05-06

### Fixed

- Apply tree-tip ordering to chart axes that use untyped Altair
shorthand (`alt.Y("strain")` without a `:N` / `:O` type suffix).
Previously the sort override was silently skipped, so the chart
rendered in data order instead of tree order.
- An internal consistency check now raises if the spec-level walk
and the live-object walk over the user's chart ever disagree on
the number of strain-axis encodings, so silent skips like the
one above can't recur in another shape.

## [0.2.0] - 2026-05-06

### Added
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tree-annotated-plot"
version = "0.2.0"
version = "0.2.1"
description = "Annotate the axis of an Altair / Vega-Lite plot with a phylogenetic tree."
readme = "README.md"
requires-python = ">=3.13"
Expand Down
186 changes: 119 additions & 67 deletions src/tree_annotated_plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import re
import warnings
from collections.abc import Iterator
from html.parser import HTMLParser
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -223,11 +224,15 @@ def _build(
legend_show=config.tree_color_legend_show,
)

new_chart = _apply_tree_order_to_chart_object(
chart, config.chart_strain_field, tip_names
)
if config.connect_leader_to_label:
_suppress_chart_strain_axis(new_chart, config.chart_strain_field)
new_chart = copy.deepcopy(chart)
suppress_axis_chrome = config.connect_leader_to_label
n_hits = 0
for ch in _iter_strain_axis_channels(new_chart, config.chart_strain_field):
ch.sort = list(tip_names)
if suppress_axis_chrome:
ch.axis = alt.Axis(labels=False, ticks=False, domain=False, title=None)
n_hits += 1
_check_walker_hits("strain-axis update", n_hits, len(axis_hits), axis)
hoisted_config, hoisted_other = _pop_toplevel_only_attrs(new_chart)

combined = _concat_for_location(
Expand Down Expand Up @@ -916,11 +921,45 @@ def walk(n: dict) -> None:
return out


def _apply_tree_order_to_chart_object(
chart: alt.TopLevelMixin, chart_strain_field: str, sort_order: list[str]
) -> alt.TopLevelMixin:
"""Return a deepcopy of the user's chart with `sort=sort_order` applied
to every axis encoding that references `chart_strain_field`.
def _check_walker_hits(operation: str, actual: int, expected: int, axis: str) -> None:
"""Cross-check the live-object walk's hit count against the spec walk's.

`_find_strain_encoding` walks the chart's serialized dict form to count
how many `x`/`y` strain encodings the chart has and to validate
consistency (axis agreement, type=nominal/ordinal, etc.). The live
iteration in `_build` (driven by `_iter_strain_axis_channels`) must
visit exactly the same number of encodings — fewer would silently skip
applying the tree's tip order (or the axis suppression) to part of the
chart, more would mean we mutated structures the dict walker doesn't
know about.

The check is symmetric (`!=`) rather than one-sided (`<`) because
either direction signals that spec-level introspection and
live-object traversal have diverged, and continuing would render an
unverified chart.
"""
if actual != expected:
raise RuntimeError(
f"internal consistency check failed for {operation!r}: "
f"_find_strain_encoding located {expected} strain {axis!r}-axis "
f"encoding(s) in the chart spec, but the live-object walk "
f"updated {actual}. Spec-level and live-object traversal have "
"diverged, which would render the chart with a wrong tip order "
"or leave axis chrome behind. Please file a bug at "
"https://github.com/jbloomlab/tree-annotated-plot/issues with a "
"minimal reproducer."
)


def _iter_strain_axis_channels(node: Any, chart_strain_field: str) -> Iterator[Any]:
    """Yield each live ``x``/``y`` channel whose field is `chart_strain_field`.

    This generator only reads: it neither mutates the channels nor counts
    them. The caller applies whatever change it needs to each yielded channel
    and keeps its own hit count.

    Traversal order for a given ``node``:

    1. the node's own ``encoding.x`` then ``encoding.y``;
    2. every member of ``hconcat``, ``vconcat``, ``concat`` and ``layer``
       (in that order), recursively;
    3. the node's ``spec``, recursively.

    ``spec`` is descended into unconditionally. Gating on the spec having its
    own encoding would skip layered charts, whose encodings live on their
    layers rather than at the top level.

    Args:
        node: A live altair chart object (or sub-chart) to walk.
        chart_strain_field: Field name identifying the strain axis.

    Yields:
        The matching channel objects, in traversal order.
    """
    encoding = _live_attr(node, "encoding")
    if encoding is not None:
        for axis_name in ("x", "y"):
            candidate = _live_attr(encoding, axis_name)
            if candidate is None:
                continue
            if _channel_field(candidate) == chart_strain_field:
                yield candidate

    for container in ("hconcat", "vconcat", "concat", "layer"):
        members = _live_attr(node, container)
        if not isinstance(members, list):
            continue
        for member in members:
            yield from _iter_strain_axis_channels(member, chart_strain_field)

    inner = _live_attr(node, "spec")
    if inner is not None:
        yield from _iter_strain_axis_channels(inner, chart_strain_field)


def _live_attr(obj: Any, name: str) -> Any:
Expand Down Expand Up @@ -1044,21 +1046,71 @@ def _apply_combined_config(combined: alt.HConcatChart, hoisted_config: Any) -> N
def _channel_field(ch: Any) -> str | None:
    """Read the underlying `field` string from an altair channel object.

    The field is taken from ``ch._kwds`` rather than ``ch.to_dict()``:
    altair's ``to_dict()`` on a bare channel raises when the shorthand has no
    explicit type (e.g. ``alt.Y("strain")``), because type inference needs
    the chart's data context.

    Lookup order:

    1. ``_kwds["field"]``, either a plain non-empty ``str`` or a wrapper
       object exposing ``to_dict()`` that unwraps to a non-empty ``str``
       (the form produced by a ``from_dict`` round-trip).
    2. ``_kwds["shorthand"]``, parsed as
       ``[<aggregate>(]<field>[)][:<type>]``.

    Args:
        ch: An altair channel object (e.g. ``alt.X`` / ``alt.Y``).

    Returns:
        The field name, or ``None`` when the channel references no data
        field: a ``value=`` / ``datum=`` constant encoding, or the field-less
        ``count()`` aggregate shorthand.

    Raises:
        ValueError: If ``ch`` has no ``_kwds`` mapping, or the stored field /
            shorthand has a shape this parser does not recognize. Failing
            loudly is deliberate: a silent ``None`` here previously caused
            strain axes to be skipped without any error.
    """
    kwds = getattr(ch, "_kwds", None)
    if not isinstance(kwds, dict):
        raise ValueError(
            f"channel object {type(ch).__name__} has no `_kwds` mapping; "
            "this isn't a recognized altair channel encoding."
        )

    field = kwds.get("field")
    if field is not None and field is not alt.Undefined:
        # `from_dict`-roundtripped channels store the field as a
        # `FieldName(SchemaBase)` wrapper rather than a plain str; its
        # `to_dict()` returns the raw string, while `str(...)` returns the
        # repr `FieldName('x')`. Cover both.
        if isinstance(field, str):
            if field:
                return field
            # An empty string falls through to the shorthand lookup below.
        elif hasattr(field, "to_dict"):
            unwrapped = field.to_dict()
            if isinstance(unwrapped, str) and unwrapped:
                return unwrapped
            raise ValueError(
                f"channel field wrapper {type(field).__name__} unwrapped "
                f"to {unwrapped!r}; expected a non-empty string."
            )
        else:
            raise ValueError(
                f"channel field has unexpected value {field!r} "
                f"(type {type(field).__name__}); expected a string."
            )

    shorthand = kwds.get("shorthand")
    if shorthand is None or shorthand is alt.Undefined:
        return None
    if not isinstance(shorthand, str) or not shorthand:
        raise ValueError(
            f"channel shorthand has unexpected value {shorthand!r}; "
            "expected a string like 'strain', 'strain:N', or "
            "'mean(strain):Q'."
        )

    # Shorthand grammar: '[<aggregate>(]<field>[)][:<type>]'.
    bare = shorthand.split(":", 1)[0]
    if "(" in bare:
        if not bare.endswith(")"):
            raise ValueError(
                f"channel shorthand {shorthand!r} has an unbalanced "
                "aggregate wrapper; expected 'aggregate(field)[:type]'."
            )
        open_paren = bare.index("(")
        aggregate = bare[:open_paren]
        bare = bare[open_paren + 1 : -1]
        # `count()` is altair's one field-less aggregate shorthand. It
        # references no data field, so it is a legitimate "no field" channel
        # (e.g. the non-strain axis of a bar chart) rather than a parse error.
        if not bare and aggregate == "count":
            return None
    if not bare:
        raise ValueError(
            f"channel shorthand {shorthand!r} parsed to an empty field name."
        )
    return bare


def _chart_strain_dim(
Expand Down
69 changes: 68 additions & 1 deletion tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pandas as pd
import pytest

from tree_annotated_plot._plot import _find_strain_encoding
from tree_annotated_plot._plot import _channel_field, _find_strain_encoding

DATA_DIR = Path(__file__).resolve().parent.parent / "examples" / "data"

Expand Down Expand Up @@ -139,3 +139,70 @@ def test_find_strain_encoding_quantitative_type_raises() -> None:
spec = _flat_chart_spec(typ="quantitative")
with pytest.raises(ValueError, match="type='quantitative'"):
_find_strain_encoding(spec, "strain")


# ---------- _channel_field: covers the "untyped shorthand" silent-skip bug.


def test_channel_field_typed_shorthand() -> None:
    """Typed shorthand (`alt.Y('strain:N')`) yields the bare field name."""
    channel = alt.Y("strain:N")
    field_name = _channel_field(channel)
    assert field_name == "strain"


def test_channel_field_untyped_shorthand() -> None:
    """Untyped shorthand (`alt.Y('strain')`) still yields the field name.

    This is the form behind the original silent-skip bug: the field has to
    be recovered from `_kwds['shorthand']` because no type suffix is given.
    """
    channel = alt.Y("strain")
    field_name = _channel_field(channel)
    assert field_name == "strain"


def test_channel_field_explicit_field_kwarg() -> None:
    """An explicit `field=` keyword is returned as-is."""
    channel = alt.Y(field="strain", type="nominal")
    field_name = _channel_field(channel)
    assert field_name == "strain"


def test_channel_field_after_from_dict_roundtrip() -> None:
    """A chart rebuilt via `alt.Chart.from_dict(...)` still exposes its field.

    After the round-trip the field is stored as a wrapper object rather than
    a plain `str`, so the helper has to unwrap it.
    """
    frame = pd.DataFrame({"strain": ["A", "B"], "titer": [1.0, 2.0]})
    original = alt.Chart(frame).mark_circle().encode(
        x="titer:Q",
        y=alt.Y("strain:N"),
    )
    spec = original.to_dict()
    rebuilt = alt.Chart.from_dict(spec)
    assert _channel_field(rebuilt.encoding.y) == "strain"


def test_channel_field_aggregate_shorthand() -> None:
    """Aggregate-wrapped shorthand (`mean(titer):Q`) yields the inner field."""
    channel = alt.X("mean(titer):Q")
    field_name = _channel_field(channel)
    assert field_name == "titer"


def test_channel_field_value_only_returns_none() -> None:
    """A constant `value=` encoding has no field, so `None` is expected
    rather than an error."""
    channel = alt.Y(value=5)
    field_name = _channel_field(channel)
    assert field_name is None


def test_channel_field_unbalanced_aggregate_raises() -> None:
    """An unparseable shorthand must raise instead of quietly returning
    `None`; quiet fallthrough is what masked the original bug."""
    channel = alt.Y("strain")
    # Overwrite the stored shorthand with a wrapper missing its closing paren.
    channel._kwds["shorthand"] = "mean(titer:Q"
    with pytest.raises(ValueError, match="unbalanced aggregate"):
        _channel_field(channel)


def test_channel_field_non_kwds_object_raises() -> None:
    """An object without a `_kwds` mapping must raise rather than return
    `None`, following the same fail-fast rule as the other error cases."""

    class NotAChannel:
        pass

    impostor = NotAChannel()
    with pytest.raises(ValueError, match="no `_kwds` mapping"):
        _channel_field(impostor)
Loading