diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..65745d3 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,22 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added + +- `connect_leader_to_label` (default `False`): extends each strain's + dashed leader line all the way to its text label. When on, the + chart's strain-axis labels are suppressed and replacement labels + are rendered alongside the tree on its chart-facing edge. +- `strain_label_font_size` (default `10`), `strain_label_font_weight` + (default `"normal"`), and `shift_tree_loc` (default `0`) for tuning + the size, weight, and placement of the connected labels. + +## [0.1.0] - 2026-05-04 + +Initial release. diff --git a/CLAUDE.md b/CLAUDE.md index 5dbd351..f9fa2fe 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,6 +3,11 @@ - **Keep CLAUDE.md, README.md, and the docs site updated** when changing the code. `CLAUDE.md` describes programming conventions; `README.md` describes basic use; `docs/` is the user-facing reference. Don't let them drift. +- **Record user-facing changes in `CHANGELOG.md`** under the `## [Unreleased]` + section as you make them — new parameters, behavior changes, removed + features, bug fixes — in [Keep a Changelog](https://keepachangelog.com/) + format. At release time, `## [Unreleased]` is renamed to + `## [X.Y.Z] - YYYY-MM-DD` and a fresh `## [Unreleased]` is created. - **Single source of truth — `pyproject.toml`**: canonical for dependencies, supported Python version, build config, tool settings. Don't restate any of these in prose; refer to `pyproject.toml`. 
diff --git a/README.md b/README.md index 60cefae..7157d08 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,10 @@ pip install git+https://github.com/jbloomlab/tree-annotated-plot.git For a development checkout, see [Installation (development)](#installation-development) below. +## Changelog + +See [CHANGELOG.md](CHANGELOG.md) for the version history. + ## Notes for developing the package ### Installation (development) diff --git a/docs/examples.md b/docs/examples.md index 711c108..eeee4ba 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -97,8 +97,39 @@ With the tree panel added by `tree-annotated-plot`: The strain axis is on `y`, so `tree_annotated_plot.plot` auto-picks the **vertical layout** (`tree_location` defaults to `"left"` on a y-encoded strain): result is an `HConcatChart` with the tree on the -left, tips flush against the chart's strain labels, and a centered -scale bar at the bottom of the tree panel. +left and a centered scale bar at the bottom of the tree panel. The +chart's natural strain-axis labels are kept exactly as the +chart-builder wrote them (fonts, ticks, axis title, and all), and the +tree's dashed leader lines stop at the tree panel's chart-facing +edge. + +### Optional: connect leaders all the way to the labels + +If you'd prefer the dashed leaders to run flush into the strain +labels themselves (with no break between tip and label), set +`connect_leader_to_label=True`. This involves moving the labels off +the chart's natural axis and into the tree panel, with a few +trade-offs to be aware of: + +- The chart's strain-axis is replaced: any labels, ticks, title, + or custom `axis=...` you set on that encoding are dropped, and + replacement labels are rendered alongside the tree. +- Label widths are estimated, not measured exactly, so layout may + need tuning. The two main knobs are `strain_label_font_size` + (default 10) and `shift_tree_loc` (a manual pixel offset that + moves the tree closer to the labels). 
+ +The example below turns on label connection, shrinks the labels to +9 px, and uses `shift_tree_loc=60` to bring the tree flush against +them: + +![H3N2 with label connection at 9px font](images/h3n2_combined_label_connect.svg) + +[Open the interactive chart in a new tab →](charts/h3n2_combined_label_connect.html){target="_blank"} + +CLI flags: `--connect-leader-to-label --strain-label-font-size 9 +--shift-tree-loc 60`. In Python: +`connect_leader_to_label=True, strain_label_font_size=9, shift_tree_loc=60`. ### Reproduce — command line diff --git a/scripts/generate_docs_assets.py b/scripts/generate_docs_assets.py index fe1a072..d51c13b 100644 --- a/scripts/generate_docs_assets.py +++ b/scripts/generate_docs_assets.py @@ -122,6 +122,33 @@ def _render_kikawa() -> None: ) _save_pair(out, f"{basename}_combined") + # H3N2 again, with `connect_leader_to_label=True` and a 9 px label + # font: opt-in label connection where leaders run flush into the + # labels rendered alongside the tree (the default keeps the chart's + # natural strain-axis labels untouched). 
+ h3n2_chart = builder.make_chart( + subtype="H3N2", + chart_type="iqr", + titers=titers, + viruses=viruses, + metadata=metadata, + all_cohorts=all_cohorts, + ) + out = tree_annotated_plot.plot( + DATA_DIR / "flu-seqneut-2025to2026_H3N2.json", + h3n2_chart, + chart_strain_field="axis_label", + tree_strain_field="derived_haplotype", + branch_length="div", + tree_size=140, + scale_bar=True, + branch_length_units="substitutions", + connect_leader_to_label=True, + strain_label_font_size=9, + shift_tree_loc=60, + ) + _save_pair(out, "h3n2_combined_label_connect") + def main() -> None: """Render every example to SVG + interactive HTML under `docs/`.""" diff --git a/src/tree_annotated_plot/_config.py b/src/tree_annotated_plot/_config.py index 95e3603..2eaaac4 100644 --- a/src/tree_annotated_plot/_config.py +++ b/src/tree_annotated_plot/_config.py @@ -121,6 +121,39 @@ class PlotConfig: "those become warnings and parsing proceeds.", ] = True + connect_leader_to_label: Annotated[ + bool, + "Off (default): the chart's strain-axis labels are kept as the " + "user wrote them and dashed leader lines stop at the tree " + "panel's chart-facing edge. On: leaders extend all the way to " + "the labels — which requires moving the labels off the chart's " + "strain axis and into the tree panel, so the chart's " + "strain-axis labels, ticks, axis line, and title are SUPPRESSED " + "(any user-supplied `axis=...` is overridden) and replacement " + "labels are rendered alongside the tree. 
Label widths are " + "estimated; for crowded charts tune `strain_label_font_size` or " + "`shift_tree_loc`.", + ] = False + + strain_label_font_size: Annotated[ + float, + "Font size (px) for the strain text labels rendered in the tree " + "panel when `connect_leader_to_label` is on.", + ] = 10.0 + + strain_label_font_weight: Annotated[ + Literal["normal", "bold"], + "Font weight for the strain text labels rendered in the tree panel " + "when `connect_leader_to_label` is on.", + ] = "normal" + + shift_tree_loc: Annotated[ + int, + "Pixels by which to shift the tree toward (positive) or away from " + "(negative) the chart. Default 0. Has no effect when " + "connect_leader_to_label is off.", + ] = 0 + # Sidecar for Python-docstring-only prose, keyed by PlotConfig field name. # Empty by default — add an entry when a field's docstring entry needs more @@ -170,7 +203,11 @@ def _render_data_param(name: str, description: str, width: int = 75) -> str: `_render_numpy_params`. """ body = textwrap.fill( - description, width=width, initial_indent=" ", subsequent_indent=" " + description, + width=width, + initial_indent=" ", + subsequent_indent=" ", + break_on_hyphens=False, ) return f"{name}\n{body}" @@ -200,6 +237,7 @@ def _render_numpy_params( width=width, initial_indent=" ", subsequent_indent=" ", + break_on_hyphens=False, ) ) if name in extras: @@ -210,6 +248,7 @@ def _render_numpy_params( width=width, initial_indent=" ", subsequent_indent=" ", + break_on_hyphens=False, ) ) return "\n".join(chunks) diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index bd963b8..562d821 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -42,6 +42,10 @@ def plot( branch_length_units: str | None = None, prune_tree_to_chart: bool = False, strict_version: bool = True, + connect_leader_to_label: bool = False, + strain_label_font_size: float = 10.0, + strain_label_font_weight: Literal["normal", "bold"] = "normal", + 
shift_tree_loc: int = 0, ) -> alt.HConcatChart | alt.VConcatChart: """Return an Altair chart with a phylogenetic tree drawn alongside `chart`.""" return _build( @@ -60,6 +64,10 @@ def plot( branch_length_units=branch_length_units, prune_tree_to_chart=prune_tree_to_chart, strict_version=strict_version, + connect_leader_to_label=connect_leader_to_label, + strain_label_font_size=strain_label_font_size, + strain_label_font_weight=strain_label_font_weight, + shift_tree_loc=shift_tree_loc, ), ) @@ -167,11 +175,18 @@ def _build( scale_bar=config.scale_bar, branch_length=config.branch_length, branch_length_units=config.branch_length_units, + connect_leader_to_label=config.connect_leader_to_label, + strain_label_font_size=config.strain_label_font_size, + strain_label_font_weight=config.strain_label_font_weight, + shift_tree_loc=config.shift_tree_loc, + tip_names=tip_names, ) new_chart = _apply_tree_order_to_chart_object( chart, config.chart_strain_field, tip_names ) + if config.connect_leader_to_label: + _suppress_chart_strain_axis(new_chart, config.chart_strain_field) hoisted_config, hoisted_other = _pop_toplevel_only_attrs(new_chart) combined = _concat_for_location( @@ -897,6 +912,33 @@ def _walk_and_apply_sort( _walk_and_apply_sort(spec, chart_strain_field, sort_order) +def _suppress_chart_strain_axis( + chart: alt.TopLevelMixin, chart_strain_field: str +) -> None: + """Hide labels, ticks, axis line, and title on every chart strain-axis encoding. + + Walks the live altair object the same way `_walk_and_apply_sort` does + and replaces the matching encoding's ``axis`` with one that suppresses + every visible bit of axis chrome. This **overrides** any user-supplied + ``axis=alt.Axis(...)`` on those encodings; that's documented on the + ``connect_leader_to_label`` description in `_config.py`. 
+ """ + enc = _live_attr(chart, "encoding") + if enc is not None: + for channel in ("x", "y"): + ch = _live_attr(enc, channel) + if ch is not None and _channel_field(ch) == chart_strain_field: + ch.axis = alt.Axis(labels=False, ticks=False, domain=False, title=None) + for attr in ("hconcat", "vconcat", "concat", "layer"): + sub = _live_attr(chart, attr) + if isinstance(sub, list): + for s in sub: + _suppress_chart_strain_axis(s, chart_strain_field) + spec = _live_attr(chart, "spec") + if spec is not None: + _suppress_chart_strain_axis(spec, chart_strain_field) + + def _live_attr(obj: Any, name: str) -> Any: """Return obj.name unless it's missing or altair's Undefined sentinel.""" v = getattr(obj, name, None) @@ -1144,6 +1186,12 @@ def _build_scale_bar_layer( return bar + text +_LABEL_PAD_PX_MIN = 4 +_LABEL_PAD_RATIO = 0.4 # `LABEL_PAD_PX = max(MIN, font_size * RATIO)` +_LABEL_CHAR_PX_RATIO = 0.6 # rough proportional sans-serif glyph-width estimate +_LABEL_HALO_RATIO = 0.6 # white-halo strokeWidth as a fraction of font_size + + def _build_tree_chart( root: _tree.TreeNode, *, @@ -1158,6 +1206,11 @@ def _build_tree_chart( scale_bar: bool = False, branch_length: str = "div", branch_length_units: str | None = None, + connect_leader_to_label: bool = False, + strain_label_font_size: float = 10.0, + strain_label_font_weight: str = "normal", + shift_tree_loc: int = 0, + tip_names: list[str] | None = None, ) -> alt.Chart: """Build the tree panel. @@ -1194,7 +1247,53 @@ def _build_tree_chart( ) branch_max = float(seg_df[["x", "x2"]].max().max()) branch_min = float(seg_df[["x", "x2"]].min().min()) - leader_df = tips_df[tips_df["x"] < branch_max].assign(x2=branch_max) + + # When connect_leader_to_label is on: + # - All leaders extend to a single point: the panel's chart-facing edge + # (`chart_edge_branch`). 
+ # - Each label is rendered as TWO `mark_text` layers stacked at the + # same position: a white "halo" layer (white fill + thick white + # stroke) drawn first, then the visible black text on top. The halo + # follows the actual rendered glyph outline — auto-sized to the + # text — and masks the dashed leader behind each label without any + # width estimation. Vega-Lite doesn't expose `paintOrder`, so the + # two-layer trick stands in for a single text-with-halo mark. + # - `shift_tree_loc` (pixels) shrinks the strip — bringing the tree + # visually closer to the labels — by reducing the data-units between + # branch_max and chart_edge_branch. The tree's tree_size-pixel + # allocation is unchanged. + # When connect_leader_to_label is off, the chart's natural strain-axis + # labels are kept and leaders stop at branch_max (the prior behavior). + halo_px = max(2.0, strain_label_font_size * _LABEL_HALO_RATIO) + if connect_leader_to_label: + names = tip_names if tip_names is not None else list(tips_df["name"]) + max_name_len = max((len(n) for n in names), default=0) + char_px = strain_label_font_size * _LABEL_CHAR_PX_RATIO + label_pad_px = max(_LABEL_PAD_PX_MIN, strain_label_font_size * _LABEL_PAD_RATIO) + # Strip needs to fit the longest label plus `halo_px / 2` of halo + # extension on the leader-facing side, plus a small fixed pad. + label_pixel_width = label_pad_px + max_name_len * char_px + halo_px / 2 + strip_pixel_width = label_pixel_width - shift_tree_loc + if strip_pixel_width <= 0: + raise ValueError( + f"shift_tree_loc={shift_tree_loc} eliminates the label strip " + f"(estimated label_pixel_width={label_pixel_width:.1f}); " + "reduce shift_tree_loc, lower strain_label_font_size, or " + "shorten the longest strain name." 
+ ) + branch_span = branch_max - branch_min + per_pixel = branch_span / tree_size if tree_size else 0.0 + extra_branch_units = strip_pixel_width * per_pixel + chart_edge_branch = branch_max + extra_branch_units + tips_df = tips_df.assign(x_label=chart_edge_branch) + leader_df = tips_df[tips_df["x"] < chart_edge_branch].assign( + x2=chart_edge_branch + ) + extended_tree_size = tree_size + strip_pixel_width + else: + chart_edge_branch = branch_max + leader_df = tips_df[tips_df["x"] < branch_max].assign(x2=branch_max) + extended_tree_size = tree_size # When scale_bar is on, extend the tip-axis past the last tip by # `_SCALE_BAR_EXTRA_PIXELS` and matching extra data units. Per-row pixel @@ -1233,12 +1332,18 @@ def _build_tree_chart( if strain_axis == "y": # Vertical: branch axis on chart x; tip axis with tip 0 on top. - # tree_location flips the branch domain so tips face the chart side. - branch_domain = ( - [branch_min, branch_max] - if tree_location == "left" - else [branch_max, branch_min] - ) + # When connect_leader_to_label is on, the branch domain is extended + # past `branch_max` to `chart_edge_branch` so the label strip has + # data-units to occupy; tips at `branch_max` still sit at pixel + # `tree_size` on the panel. Each label's chart-facing edge is + # anchored at `chart_edge_branch` and aligned outward (right for + # tree on the left, left for tree on the right). + if tree_location == "left": + branch_domain = [branch_min, chart_edge_branch] + text_align = "right" + else: + branch_domain = [chart_edge_branch, branch_min] + text_align = "left" branch_scale = alt.Scale(domain=branch_domain, nice=False, zero=False) # tip-axis domain[0] (at panel bottom) extends past last tip when # scale_bar=True; tip i still sits at the same on-screen pixel. 
@@ -1273,20 +1378,64 @@ def _build_tree_chart( tooltip=alt.Tooltip("name:N", title="strain"), ) ) + # Strain text label, drawn as two stacked layers: a white halo + # (white fill + thick white stroke) under the visible text. The + # halo auto-sizes to the rendered glyphs. + if connect_leader_to_label: + label_text_enc = dict( + x=alt.X("x_label:Q", scale=branch_scale, axis=None), + y=tip_enc, + text="name:N", + ) + layers.append( + alt.Chart(tips_df) + .mark_text( + align=text_align, + baseline="middle", + fontSize=strain_label_font_size, + fontWeight=strain_label_font_weight, + fill="white", + stroke="white", + strokeWidth=halo_px, + strokeJoin="round", + ) + .encode(**label_text_enc) + ) + layers.append( + alt.Chart(tips_df) + .mark_text( + align=text_align, + baseline="middle", + fontSize=strain_label_font_size, + fontWeight=strain_label_font_weight, + ) + .encode(**label_text_enc) + ) if scale_bar_layer is not None: layers.append(scale_bar_layer) layered = alt.layer(*layers) - layered = layered.properties(width=tree_size, height=extended_strain_dim) + layered = layered.properties( + width=extended_tree_size, height=extended_strain_dim + ) elif strain_axis == "x": # Horizontal: branch axis on chart y (Vega-Lite default has domain[1] # at the top); tip axis with tip 0 on the left. - # tree_location="top" → root at top → branch_max at bottom. - # tree_location="bottom" → root at bottom → branch_max at top. - branch_domain = ( - [branch_max, branch_min] - if tree_location == "top" - else [branch_min, branch_max] - ) + # tree_location="top" → root at top → branch_max at bottom (label + # strip at bottom, facing the chart below). + # tree_location="bottom" → root at bottom → branch_max at top (label + # strip at top, facing the chart above). + # Each label's chart-facing edge is anchored at `chart_edge_branch`. 
+ # The text mark is rotated 270° (reads bottom-to-top), which maps + # pre-rotation `align="right"` to a top anchor (text extends down) + # and `align="left"` to a bottom anchor (text extends up). For tree + # on the bottom, the chart-facing edge is the panel's top → "right". + # For tree on the top, it's the panel's bottom → "left". + if tree_location == "top": + branch_domain = [chart_edge_branch, branch_min] + text_align = "left" + else: + branch_domain = [branch_min, chart_edge_branch] + text_align = "right" branch_scale = alt.Scale(domain=branch_domain, nice=False, zero=False) # tip-axis domain[1] (at panel right) extends past last tip when # scale_bar=True. @@ -1321,10 +1470,46 @@ def _build_tree_chart( tooltip=alt.Tooltip("name:N", title="strain"), ) ) + # Strain text label as halo + visible text (see vertical-layout + # comment above). + if connect_leader_to_label: + label_text_enc = dict( + y=alt.Y("x_label:Q", scale=branch_scale, axis=None), + x=tip_enc, + text="name:N", + ) + layers.append( + alt.Chart(tips_df) + .mark_text( + align=text_align, + baseline="middle", + angle=270, + fontSize=strain_label_font_size, + fontWeight=strain_label_font_weight, + fill="white", + stroke="white", + strokeWidth=halo_px, + strokeJoin="round", + ) + .encode(**label_text_enc) + ) + layers.append( + alt.Chart(tips_df) + .mark_text( + align=text_align, + baseline="middle", + angle=270, + fontSize=strain_label_font_size, + fontWeight=strain_label_font_weight, + ) + .encode(**label_text_enc) + ) if scale_bar_layer is not None: layers.append(scale_bar_layer) layered = alt.layer(*layers) - layered = layered.properties(width=extended_strain_dim, height=tree_size) + layered = layered.properties( + width=extended_strain_dim, height=extended_tree_size + ) else: raise ValueError(f"strain_axis must be 'x' or 'y', got {strain_axis!r}") diff --git a/tests/test_cli.py b/tests/test_cli.py index 3f4410d..43148d1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -40,6 +40,10 @@ 
def test_help_lists_auto_generated_options() -> None: "--tree-line-width", "--scale-bar / --no-scale-bar", "--strict-version / --no-strict-version", + "--connect-leader-to-label / --no-connect-leader-to-label", + "--strain-label-font-size", + "--strain-label-font-weight", + "--shift-tree-loc", ): assert opt in result.output, f"missing option in --help: {opt}" diff --git a/tests/test_label_connection.py b/tests/test_label_connection.py new file mode 100644 index 0000000..91a2940 --- /dev/null +++ b/tests/test_label_connection.py @@ -0,0 +1,559 @@ +"""Tests for `connect_leader_to_label` and the strain-label font knobs.""" + +from __future__ import annotations + +from typing import Any + +import altair as alt +import pandas as pd +import pytest + +import tree_annotated_plot + + +def _auspice() -> dict: + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + {"name": "A", "node_attrs": {"div": 0.04}}, + {"name": "B", "node_attrs": {"div": 0.05}}, + {"name": "C", "node_attrs": {"div": 0.03}}, + {"name": "D", "node_attrs": {"div": 0.06}}, + ], + }, + } + + +def _vertical_chart() -> alt.Chart: + df = pd.DataFrame({"strain": ["A", "B", "C", "D"], "titer": [1.0, 2.0, 4.0, 8.0]}) + return ( + alt.Chart(df) + .mark_circle() + .encode(x="titer:Q", y=alt.Y("strain:N")) + .properties(width=200, height=200) + ) + + +def _horizontal_chart() -> alt.Chart: + df = pd.DataFrame({"strain": ["A", "B", "C", "D"], "titer": [1.0, 2.0, 4.0, 8.0]}) + return ( + alt.Chart(df) + .mark_circle() + .encode(x=alt.X("strain:N"), y="titer:Q") + .properties(width=200, height=200) + ) + + +def _kw(): + """Required kwargs only — `connect_leader_to_label` defaults to off.""" + return dict( + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + + +def _on_kw(): + """Required kwargs plus `connect_leader_to_label=True`.""" + return dict(_kw(), connect_leader_to_label=True) + + +def _references_strain(node: 
Any) -> bool: + """Recursively check whether `node` has any encoding referencing the + `"strain"` field on `x` or `y` (i.e. it's the user's chart panel).""" + if isinstance(node, dict): + enc = node.get("encoding") + if isinstance(enc, dict): + for ch in ("x", "y"): + chdef = enc.get(ch, {}) + if isinstance(chdef, dict) and chdef.get("field") == "strain": + return True + for v in node.values(): + if _references_strain(v): + return True + elif isinstance(node, list): + for v in node: + if _references_strain(v): + return True + return False + + +def _is_tree_panel(panel: dict) -> bool: + """The tree panel is the one that does NOT reference the strain field + on its x/y encoding (its layers encode the layout-internal x:Q / y:Q + quantitative fields). The user's chart panel references the strain + field somewhere on x or y.""" + return not _references_strain(panel) + + +def _tree_panel(out: alt.HConcatChart | alt.VConcatChart) -> dict: + d = out.to_dict() + panels = d.get("hconcat") or d.get("vconcat") + assert panels, "expected hconcat or vconcat output" + for panel in panels: + if _is_tree_panel(panel): + return panel + raise AssertionError("no tree panel found") + + +def _chart_panel(out: alt.HConcatChart | alt.VConcatChart) -> dict: + d = out.to_dict() + panels = d.get("hconcat") or d.get("vconcat") + assert panels, "expected hconcat or vconcat output" + for panel in panels: + if not _is_tree_panel(panel): + return panel + raise AssertionError("no chart panel found") + + +def _tree_layers(out: alt.HConcatChart | alt.VConcatChart) -> list[dict]: + return _tree_panel(out)["layer"] + + +def _strain_text_layers(out: alt.HConcatChart | alt.VConcatChart) -> list[dict]: + """Return all strain-label text layers (those whose `text` encoding + references `name`). 
With `connect_leader_to_label=True` there are two: + a white-halo shadow layer drawn first, then the visible black text.""" + out_layers = [] + for layer in _tree_layers(out): + mark = layer.get("mark") + if isinstance(mark, dict) and mark.get("type") == "text": + enc = layer.get("encoding", {}) + text_def = enc.get("text", {}) + if isinstance(text_def, dict) and text_def.get("field") == "name": + out_layers.append(layer) + return out_layers + + +def _text_layer(out: alt.HConcatChart | alt.VConcatChart) -> dict | None: + """Return the visible strain-label text layer (the second of the two + stacked text layers, i.e. the one *without* white fill/stroke). Returns + None if `connect_leader_to_label` is off.""" + for layer in _strain_text_layers(out): + mark = layer["mark"] + if mark.get("fill") != "white": + return layer + return None + + +def _halo_text_layer(out: alt.HConcatChart | alt.VConcatChart) -> dict | None: + """Return the white-halo shadow text layer (white fill + white stroke).""" + for layer in _strain_text_layers(out): + mark = layer["mark"] + if mark.get("fill") == "white" and mark.get("stroke") == "white": + return layer + return None + + +def _leader_layer(out: alt.HConcatChart | alt.VConcatChart) -> dict: + """The dashed-rule leader layer.""" + matches = [ + layer + for layer in _tree_layers(out) + if isinstance(layer.get("mark"), dict) + and layer["mark"].get("type") == "rule" + and "strokeDash" in layer["mark"] + ] + assert len(matches) == 1, f"expected exactly one leader layer, got {len(matches)}" + return matches[0] + + +def _resolve_dataset(out_dict: dict, layer: dict) -> list[dict]: + """Return inline rows for `layer`'s data — either layer.data.values or + the named dataset that altair hoists to the top-level `datasets` block.""" + data = layer.get("data") or {} + if "values" in data and isinstance(data["values"], list): + return data["values"] + name = data.get("name") + rows = out_dict.get("datasets", {}).get(name, []) + assert 
isinstance(rows, list), f"dataset {name!r} not a list" + return rows + + +# ---------- chart-axis suppression ---------- + + +def test_on_suppresses_chart_strain_axis() -> None: + """Setting `connect_leader_to_label=True` hides labels/ticks/domain/title + on the chart's strain-axis encoding.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + enc = _chart_panel(out)["encoding"]["y"] + axis = enc.get("axis") + assert isinstance(axis, dict), f"expected an axis dict, got {axis!r}" + assert axis.get("labels") is False + assert axis.get("ticks") is False + assert axis.get("domain") is False + assert axis.get("title") is None + + +def test_default_off_keeps_chart_strain_axis_untouched() -> None: + """The default (`connect_leader_to_label=False`) leaves the chart's + strain encoding exactly as the user wrote it — no suppressed-axis + block is injected.""" + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + **_kw(), + ) + enc = _chart_panel(out)["encoding"]["y"] + # User's _vertical_chart() didn't pass a custom axis, so no axis dict + # should be present at all (or at least none with our suppression keys). 
+ if "axis" in enc: + axis = enc["axis"] + assert axis.get("labels") is not False + assert axis.get("ticks") is not False + + +def test_default_off_keeps_horizontal_strain_axis_untouched() -> None: + """Same check, horizontal layout.""" + out = tree_annotated_plot.plot( + _auspice(), + _horizontal_chart(), + **_kw(), + ) + enc = _chart_panel(out)["encoding"]["x"] + if "axis" in enc: + axis = enc["axis"] + assert axis.get("labels") is not False + + +# ---------- tree-panel text layer ---------- + + +def test_on_adds_text_layer_to_tree_panel() -> None: + """When on, the tree panel gains a text layer whose `text` encoding + is the strain name.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + text = _text_layer(out) + assert text is not None, "expected a strain-label text layer" + assert text["mark"]["type"] == "text" + assert text["encoding"]["text"]["field"] == "name" + + +def test_default_off_does_not_add_text_layer() -> None: + """The default (off) leaves the tree panel with no `text` mark.""" + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + **_kw(), + ) + assert _text_layer(out) is None + # Only the three default layers (leaders + branches + tip-circles). + assert len(_tree_layers(out)) == 3 + + +# ---------- leader endpoint extension ---------- + + +def test_leaders_share_chart_edge_endpoint_when_on() -> None: + """With `connect_leader_to_label=True`, every leader extends to the + same `chart_edge_branch` (a single shared endpoint past `branch_max`). 
+ A white halo around each rendered text label masks the leader behind + each glyph — but in the underlying data every leader's `x2` is the + same.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + leader = _leader_layer(out) + rows = _resolve_dataset(out.to_dict(), leader) + branch_max = 0.06 + x2_values = {row["x2"] for row in rows} + assert len(x2_values) == 1, f"expected one shared x2 value, got {x2_values}" + (x2,) = x2_values + assert x2 > branch_max + 1e-9, f"expected x2 > branch_max ({branch_max}), got {x2}" + + +def test_default_leader_endpoint_stops_at_branch_max() -> None: + """The default (off) leaves leaders stopping at branch_max (the prior + behavior).""" + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + **_kw(), + ) + leader = _leader_layer(out) + rows = _resolve_dataset(out.to_dict(), leader) + branch_max = 0.06 + for row in rows: + assert row["x2"] == pytest.approx(branch_max) + + +# ---------- side-of-panel placement ---------- + + +def test_vertical_left_renders_text_with_right_align() -> None: + """Vertical layout with `connect_leader_to_label=True` and default + `tree_location="left"`: tree on left, chart on right; labels flush + against the chart on the panel's right (chart-facing) edge. Text + mark uses `align="right"` and `x_label > branch_max`.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + text = _text_layer(out) + assert text is not None + assert text["mark"].get("align") == "right" + rows = _resolve_dataset(out.to_dict(), text) + branch_max = 0.06 + for row in rows: + assert row["x_label"] > branch_max + + +def test_vertical_right_renders_text_with_left_align() -> None: + """Vertical layout, `tree_location="right"`: tree on right, chart on + left; labels are flush against the chart on the panel's left + (chart-facing) edge. 
Text mark uses `align="left"`; `x_label` still + extends past `branch_max` because the panel was widened in the + branch-max direction; chart's natural left-side strain axis is still + suppressed.""" + out = tree_annotated_plot.plot( + _auspice(), _vertical_chart(), tree_location="right", **_on_kw() + ) + text = _text_layer(out) + assert text is not None + assert text["mark"].get("align") == "left" + rows = _resolve_dataset(out.to_dict(), text) + branch_max = 0.06 + for row in rows: + assert row["x_label"] > branch_max + # Chart's natural left-side strain axis labels are still suppressed. + enc = _chart_panel(out)["encoding"]["y"] + axis = enc.get("axis") + assert isinstance(axis, dict) + assert axis.get("labels") is False + assert axis.get("ticks") is False + + +# ---------- horizontal layout ---------- + + +def test_horizontal_bottom_uses_right_align_and_270() -> None: + """Horizontal layout with `connect_leader_to_label=True` and default + `tree_location="bottom"`: tree below chart; labels flush against the + chart on the panel's top (chart-facing) edge. With `angle=270` (text + reads bottom-to-top), `align="right"` anchors the post-rotation top + of the text at the panel's top, so labels extend downward from the + chart.""" + out = tree_annotated_plot.plot(_auspice(), _horizontal_chart(), **_on_kw()) + text = _text_layer(out) + assert text is not None + assert text["mark"].get("angle") == 270 + assert text["mark"].get("align") == "right" + + +def test_horizontal_top_uses_left_align_and_270() -> None: + """Horizontal layout, `tree_location="top"`: tree above chart; + labels are flush against the chart on the panel's bottom (chart-facing) + edge. 
With `angle=270`, `align="left"` anchors the post-rotation + bottom of the text at the panel's bottom, so labels extend upward + from the chart.""" + out = tree_annotated_plot.plot( + _auspice(), _horizontal_chart(), tree_location="top", **_on_kw() + ) + text = _text_layer(out) + assert text is not None + assert text["mark"].get("angle") == 270 + assert text["mark"].get("align") == "left" + + +# ---------- font knobs ---------- + + +def test_strain_label_font_size_and_weight_propagate() -> None: + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + strain_label_font_size=14, + strain_label_font_weight="bold", + **_on_kw(), + ) + text = _text_layer(out) + assert text is not None + assert text["mark"].get("fontSize") == 14 + assert text["mark"].get("fontWeight") == "bold" + + +def test_strain_label_font_weight_invalid_raises() -> None: + """`strain_label_font_weight` is a Literal["normal","bold"]; an invalid + value should fail (in practice altair's schema-validation rejects it + when constructing the text mark).""" + with pytest.raises(Exception, match="(?i)fontweight|heavy"): + tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + strain_label_font_weight="heavy", # type: ignore[arg-type] + **_on_kw(), + ) + + +# ---------- multi-encoding case ---------- + + +def test_layered_chart_suppresses_every_strain_encoding() -> None: + """If the user's chart references the strain field on multiple + sub-encodings (e.g. layered LayerChart), every match should be + suppressed — same walker pattern as `_walk_and_apply_sort`.""" + df = pd.DataFrame({"strain": ["A", "B", "C", "D"], "titer": [1.0, 2.0, 4.0, 8.0]}) + base = alt.Chart(df).encode(y=alt.Y("strain:N"), x="titer:Q") + layered = (base.mark_circle() + base.mark_line()).properties(width=200, height=200) + out = tree_annotated_plot.plot(_auspice(), layered, **_on_kw()) + chart_panel = _chart_panel(out) + # Each layer's y encoding should be suppressed. 
+ found_axes: list[Any] = [] + + def walk(node: Any) -> None: + if isinstance(node, dict): + enc = node.get("encoding") + if isinstance(enc, dict): + y = enc.get("y") + if isinstance(y, dict) and y.get("field") == "strain": + found_axes.append(y.get("axis")) + for v in node.values(): + walk(v) + elif isinstance(node, list): + for v in node: + walk(v) + + walk(chart_panel) + assert found_axes, "expected at least one strain encoding in chart panel" + for axis in found_axes: + assert isinstance(axis, dict), f"expected suppressed axis dict, got {axis!r}" + assert axis.get("labels") is False + assert axis.get("ticks") is False + + +# ---------- white-halo shadow text layer (leader mask) ---------- + + +def test_on_adds_halo_text_layer() -> None: + """When on, a white-halo shadow `mark_text` (white fill + thick white + stroke, `strokeJoin="round"` for smooth glyph corners) is drawn under + the visible label, auto-sized to the rendered glyphs.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + halo = _halo_text_layer(out) + assert halo is not None, "expected a halo (shadow) text layer" + mark = halo["mark"] + assert mark.get("fill") == "white" + assert mark.get("stroke") == "white" + assert mark.get("strokeJoin") == "round" + # strokeWidth scales with font_size; default 10 pt × 0.6 ratio → 6.0. 
+ assert mark.get("strokeWidth") >= 5 + + +def test_halo_strokeWidth_scales_with_font_size() -> None: + out_small = tree_annotated_plot.plot( + _auspice(), _vertical_chart(), strain_label_font_size=8, **_on_kw() + ) + out_large = tree_annotated_plot.plot( + _auspice(), _vertical_chart(), strain_label_font_size=20, **_on_kw() + ) + sw_small = _halo_text_layer(out_small)["mark"]["strokeWidth"] + sw_large = _halo_text_layer(out_large)["mark"]["strokeWidth"] + assert sw_large > sw_small + + +def test_default_off_does_not_add_halo_layer() -> None: + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + **_kw(), + ) + assert _halo_text_layer(out) is None + assert _strain_text_layers(out) == [] + + +def test_halo_layer_drawn_before_visible_text() -> None: + """The halo (white-fill + white-stroke) shadow must come BEFORE the + visible black text in the layer list, so the visible text draws on + top of the halo.""" + out = tree_annotated_plot.plot(_auspice(), _vertical_chart(), **_on_kw()) + layers = _tree_layers(out) + halo_idx = next( + i + for i, layer in enumerate(layers) + if isinstance(layer.get("mark"), dict) + and layer["mark"].get("type") == "text" + and layer["mark"].get("fill") == "white" + and layer["mark"].get("stroke") == "white" + ) + visible_idx = next( + i + for i, layer in enumerate(layers) + if isinstance(layer.get("mark"), dict) + and layer["mark"].get("type") == "text" + and layer["mark"].get("fill") != "white" + and layer.get("encoding", {}).get("text", {}).get("field") == "name" + ) + assert halo_idx < visible_idx, "halo shadow must be drawn before the visible text" + + +# ---------- shift_tree_loc ---------- + + +def test_shift_tree_loc_shrinks_panel() -> None: + """Positive `shift_tree_loc` reduces the label strip's pixel width, so + the tree panel is narrower.""" + out_zero = tree_annotated_plot.plot( + _auspice(), _vertical_chart(), shift_tree_loc=0, **_on_kw() + ) + out_shift = tree_annotated_plot.plot( + _auspice(), 
_vertical_chart(), shift_tree_loc=2, **_on_kw() + ) + w_zero = _tree_panel(out_zero)["width"] + w_shift = _tree_panel(out_shift)["width"] + assert w_shift == pytest.approx(w_zero - 2) + + +def test_shift_tree_loc_too_large_raises() -> None: + """A `shift_tree_loc` that would erase the entire label strip is a + fail-fast error (per CLAUDE.md).""" + with pytest.raises(ValueError, match="eliminates the label strip"): + tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + shift_tree_loc=999, + **_on_kw(), + ) + + +def test_shift_tree_loc_no_effect_when_off() -> None: + """When `connect_leader_to_label=False` (the default), `shift_tree_loc` + is ignored — there's no label strip to shrink.""" + out = tree_annotated_plot.plot( + _auspice(), + _vertical_chart(), + shift_tree_loc=50, + **_kw(), + ) + # tree panel width == tree_size (default 100); no strip, no shift effect. + assert _tree_panel(out)["width"] == 100 + + +# ---------- default (off) keeps user's chart axis ---------- + + +def test_default_off_keeps_user_axis_labels_intact() -> None: + """The default (`connect_leader_to_label=False`) preserves the user's + strain-axis encoding untouched — no suppressed `axis` dict is injected. + This is the documented contract.""" + df = pd.DataFrame({"strain": ["A", "B", "C", "D"], "titer": [1.0, 2.0, 4.0, 8.0]}) + user_axis = alt.Axis(title="Strain ID", labelFontWeight="bold") + chart = ( + alt.Chart(df) + .mark_circle() + .encode(x="titer:Q", y=alt.Y("strain:N", axis=user_axis)) + .properties(width=200, height=200) + ) + out = tree_annotated_plot.plot( + _auspice(), + chart, + **_kw(), + ) + enc = _chart_panel(out)["encoding"]["y"] + axis = enc.get("axis") + assert axis is not None + # user's axis intact: title and labelFontWeight preserved; we did NOT + # set labels=False / ticks=False / domain=False. 
+ assert axis.get("title") == "Strain ID" + assert axis.get("labelFontWeight") == "bold" + assert axis.get("labels") is not False + assert axis.get("ticks") is not False diff --git a/tests/test_style.py b/tests/test_style.py index ae4df35..28046bf 100644 --- a/tests/test_style.py +++ b/tests/test_style.py @@ -48,14 +48,16 @@ def _layer_marks(layers: list[dict]) -> list[dict]: def test_defaults_yield_three_layers() -> None: - """With default styles (all three knobs > 0) and scale_bar=False, the - tree has three layers: leaders, branches, tip-circles.""" + """With default styles (all three knobs > 0), scale_bar=False, and + connect_leader_to_label=False, the tree has three layers: leaders, + branches, tip-circles.""" out = tree_annotated_plot.plot( _auspice(), _chart(), chart_strain_field="strain", tree_strain_field="name", branch_length="div", + connect_leader_to_label=False, ) assert len(_tree_layers(out)) == 3 @@ -186,6 +188,7 @@ def test_tree_node_size_zero_disables_tip_circles() -> None: tree_strain_field="name", branch_length="div", tree_node_size=0, + connect_leader_to_label=False, ) marks = _layer_marks(_tree_layers(out)) assert not any(m.get("type") == "circle" for m in marks) @@ -201,6 +204,7 @@ def test_leader_line_width_zero_disables_leader_layer() -> None: tree_strain_field="name", branch_length="div", leader_line_width=0, + connect_leader_to_label=False, ) marks = _layer_marks(_tree_layers(out)) leader_marks = [m for m in marks if m.get("type") == "rule" and "strokeDash" in m] @@ -218,6 +222,7 @@ def test_both_disabled_leaves_only_branches() -> None: branch_length="div", tree_node_size=0, leader_line_width=0, + connect_leader_to_label=False, ) assert len(_tree_layers(out)) == 1 diff --git a/tests/test_tree_location.py b/tests/test_tree_location.py index 5d6fe01..0c4ad84 100644 --- a/tests/test_tree_location.py +++ b/tests/test_tree_location.py @@ -59,10 +59,15 @@ def _horizontal_chart() -> alt.Chart: def _kw(): + # `connect_leader_to_label=False` keeps 
the tree panel's branch-axis + # dimension equal to `tree_size` (no label strip), which the + # panel-width/height assertions below pin exactly. The label-connection + # behavior has its own coverage in tests/test_label_connection.py. return dict( chart_strain_field="strain", tree_strain_field="name", branch_length="div", + connect_leader_to_label=False, )