From a4b08830ec0ead087813529c6aa9a180ffa4601f Mon Sep 17 00:00:00 2001 From: jbloom Date: Tue, 5 May 2026 15:39:13 -0700 Subject: [PATCH] add color_tree_by parameter for Auspice-style tree coloring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a color_tree_by parameter that colors the tree's branches and tip circles by: - A node_attrs key like "subclade" — colored by node_attrs[<key>].value on each node. - "genotype:<gene>:<site>" — colored by the inferred amino-acid (or nucleotide) state at the site, computed from branch_attrs.mutations walked from the root. - "genotype:<gene>:<site1>,<site2>,..." — same but for a haplotype across sites; sites that don't vary in the tree drop out of the label; if every requested site is invariant, every node gets a single placeholder category (the _NO_VARIATION label). Color and ordering are chosen to match Nextstrain views for the same tree: when the Auspice JSON defines meta.colorings[<key>].scale that palette is used; otherwise the same per-N palette Auspice's frontend uses (reproduced in _color.py with attribution to AGPL-licensed Auspice) fills in. Categories are sorted by descending frequency (ties broken alphabetically), matching Auspice's sortedDomain. The "unknown" category renders in gray and is hidden from the legend when only internal nodes lack the attribute. Other changes bundled in: - TreeNode now carries node_attrs and branch_attrs for downstream consumers; load_auspice_with_meta sibling helper exposes the JSON's top-level meta dict alongside the parsed root. - Default tree_line_width bumped from 1.5 to 2 and tree_node_size from 28 to 45 with explicit opacity=1, since the prior defaults were tuned for unicolor black trees and read poorly when colored. - New "Color the tree" subsection in docs/examples.md with a second H3N2 example colored by genotype HA1:158. 
Co-Authored-By: Claude Opus 4.7 --- CHANGELOG.md | 14 + docs/examples.md | 41 ++ scripts/generate_docs_assets.py | 40 +- src/tree_annotated_plot/_color.py | 469 +++++++++++++++++ src/tree_annotated_plot/_config.py | 19 +- src/tree_annotated_plot/_plot.py | 126 ++++- src/tree_annotated_plot/_tree.py | 93 +++- tests/test_color_tree.py | 800 +++++++++++++++++++++++++++++ 8 files changed, 1560 insertions(+), 42 deletions(-) create mode 100644 src/tree_annotated_plot/_color.py create mode 100644 tests/test_color_tree.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 65745d3..4391fcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `strain_label_font_size` (default `10`), `strain_label_font_weight` (default `"normal"`), and `shift_tree_loc` (default `0`) for tuning the size, weight, and placement of the connected labels. +- `color_tree_by` (default `None`): color the tree's branches and tip + circles by an Auspice node attribute (e.g. `"subclade"`) or by the + inferred genotype state at one or more sites + (e.g. `"genotype:HA1:158"` or `"genotype:HA1:158,189"`). Colors, + category ordering, and the bottom-of-plot legend match the + Nextstrain view of the same tree. + +### Changed + +- Default `tree_line_width` bumped from `1.5` to `2`, default + `tree_node_size` from `28` to `45`. Tree branch lines and tip + circles are now drawn at full opacity. The thicker / fuller + defaults read better when the tree is colored (the prior values + were tuned for unicolor black trees). ## [0.1.0] - 2026-05-04 diff --git a/docs/examples.md b/docs/examples.md index eeee4ba..c159454 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -103,6 +103,11 @@ chart-builder wrote them (fonts, ticks, axis title, and all), and the tree's dashed leader lines stop at the tree panel's chart-facing edge. 
+The H3N2 example above is rendered with `color_tree_by="subclade"`, +which colors the tree's branches and tip circles by the +`node_attrs.subclade` value at each node and adds a categorical legend +below the plot. See "Color the tree" below for the full set of options. + ### Optional: connect leaders all the way to the labels If you'd prefer the dashed leaders to run flush into the strain @@ -131,6 +136,40 @@ CLI flags: `--connect-leader-to-label --strain-label-font-size 9 --shift-tree-loc 60`. In Python: `connect_leader_to_label=True, strain_label_font_size=9, shift_tree_loc=60`. +### Color the tree + +Pass `color_tree_by` to color the tree's branches and tip circles by +any property the Auspice JSON exposes — broadly, anything that +appears in the "Color By" dropdown on the Nextstrain view of the same +tree. Two forms are supported: + +- A named attribute, e.g. `color_tree_by="subclade"` (used in the + example above). Common alternatives include `"clade_membership"`, + `"region"`, `"country"` — whichever the tree provides. +- A genotype at one or more sites in a gene. For a single site, + `color_tree_by="genotype:HA1:158"` colors each tip by the amino + acid at HA1 site 158. A comma-separated list gives a haplotype: + `"genotype:HA1:158,189"`. Sites that don't vary in the tree are + dropped from the haplotype label. + +Colors match what you'd see on the Nextstrain view of the same tree — +either from the JSON's palette information when the build provides it, +or from the same default palette Auspice uses when it doesn't. +Categories are ordered by descending frequency in both cases. Missing +values render in gray, and the legend is drawn at the bottom of the +combined plot. 
+ +The example below colors the same H3N2 chart by genotype at HA1 +site 158, which has two mutations in the tree (`N158K`, `N158D`) and +so renders three states (N, K, D): + +![H3N2 combined chart, colored by genotype HA1:158](images/h3n2_combined_genotype_158.svg) + +[Open the interactive chart in a new tab →](charts/h3n2_combined_genotype_158.html){target="_blank"} + +CLI flag: `--color-tree-by genotype:HA1:158`. In Python: +`color_tree_by="genotype:HA1:158"`. + ### Reproduce — command line ```bash @@ -145,6 +184,7 @@ tree-annotated-plot \ --tree-size 140 \ --scale-bar \ --branch-length-units substitutions \ + --color-tree-by subclade \ --output examples/data/h3n2_combined.json ``` @@ -162,6 +202,7 @@ out = tree_annotated_plot.plot( tree_size=140, scale_bar=True, branch_length_units="substitutions", + color_tree_by="subclade", ) ``` diff --git a/scripts/generate_docs_assets.py b/scripts/generate_docs_assets.py index d51c13b..976ce75 100644 --- a/scripts/generate_docs_assets.py +++ b/scripts/generate_docs_assets.py @@ -110,9 +110,7 @@ def _render_kikawa() -> None: # Render the bare chart (no tree) so the docs page can show what # the chart looks like before tree-annotated-plot wraps it. _save_pair(chart, f"{basename}_chart_only") - out = tree_annotated_plot.plot( - DATA_DIR / f"flu-seqneut-2025to2026_{subtype}.json", - chart, + plot_kwargs = dict( chart_strain_field="axis_label", tree_strain_field="derived_haplotype", branch_length="div", @@ -120,6 +118,16 @@ def _render_kikawa() -> None: scale_bar=True, branch_length_units="substitutions", ) + if subtype == "H3N2": + # Color H3N2 by subclade so the docs SVG matches what users see + # on Nextstrain. The Auspice JSON's meta.colorings.subclade has + # no `scale` defined, so colors come from the default palette. 
+ plot_kwargs["color_tree_by"] = "subclade" + out = tree_annotated_plot.plot( + DATA_DIR / f"flu-seqneut-2025to2026_{subtype}.json", + chart, + **plot_kwargs, + ) _save_pair(out, f"{basename}_combined") # H3N2 again, with `connect_leader_to_label=True` and a 9-point label @@ -146,9 +154,35 @@ def _render_kikawa() -> None: connect_leader_to_label=True, strain_label_font_size=9, shift_tree_loc=60, + color_tree_by="subclade", ) _save_pair(out, "h3n2_combined_label_connect") + # H3N2 once more, colored by genotype at HA1 site 158: same chart and + # default layout as `h3n2_combined`, with `color_tree_by` switched to + # the genotype form. Site 158 has two mutations (N158K, N158D) in the + # tree, so this renders three states (N, K, D). + h3n2_chart_genotype = builder.make_chart( + subtype="H3N2", + chart_type="iqr", + titers=titers, + viruses=viruses, + metadata=metadata, + all_cohorts=all_cohorts, + ) + out = tree_annotated_plot.plot( + DATA_DIR / "flu-seqneut-2025to2026_H3N2.json", + h3n2_chart_genotype, + chart_strain_field="axis_label", + tree_strain_field="derived_haplotype", + branch_length="div", + tree_size=140, + scale_bar=True, + branch_length_units="substitutions", + color_tree_by="genotype:HA1:158", + ) + _save_pair(out, "h3n2_combined_genotype_158") + def main() -> None: """Render every example to SVG + interactive HTML under `docs/`.""" diff --git a/src/tree_annotated_plot/_color.py b/src/tree_annotated_plot/_color.py new file mode 100644 index 0000000..87a681c --- /dev/null +++ b/src/tree_annotated_plot/_color.py @@ -0,0 +1,469 @@ +"""Resolve per-node color categories and the associated `alt.Scale` arrays. + +Public entry point: :func:`compute_node_color_values`. + +Spec mini-syntax (see `PlotConfig.color_tree_by`): + +- ``"<key>"`` — look up ``node_attrs[<key>]`` on each node, auto-unwrapping + the Auspice ``{"value": ...}`` convention. Missing → ``"unknown"``. 
+- ``"genotype:<gene>:<site>"`` — single-site genotype state inferred by walking + ``branch_attrs.mutations[<gene>]`` from the root. Site with zero mutations + in the tree → all nodes are labeled with the ``_NO_VARIATION`` placeholder. +- ``"genotype:<gene>:<site1>,<site2>,..."`` — haplotype across sites. Per-node + label is the ``/``-joined letter+site for each *varying* site, in + user-supplied order. Sites that are invariant in the tree are dropped from + the label. If all requested sites are invariant, every node gets + the ``_NO_VARIATION`` placeholder. + +Color resolution prefers ``meta.colorings[<key>].scale`` from the Auspice JSON +when defined (matches Nextstrain views). When unset, partial, or missing for +the requested key, the per-N table in :data:`_AUSPICE_PALETTE` fills in — the +same hand-tuned categorical palette Auspice's frontend uses for trees that +don't ship explicit color scales, so output still matches the Nextstrain +view. Categories are ordered by descending frequency (ties broken +alphabetically) before being mapped positionally onto the palette, again +matching Auspice's behavior. ``"unknown"`` always renders as ``#888888`` +(gray) and is reserved — no entry in :data:`_AUSPICE_PALETTE` is gray, so +a category drawn from the palette can never be confused with missing. +""" + +from __future__ import annotations + +import re +from collections import Counter +from dataclasses import dataclass +from typing import Iterator + +from ._tree import TreeNode + +# Auspice's frontend categorical palette, indexed by category count. +# `_AUSPICE_PALETTE[N]` is a hand-tuned `N`-color palette, perceptually spaced +# along a viridis-like sweep (purple → blue → green → yellow → orange → red); +# entries 0..36 are reproduced from Auspice's `src/util/globals.js`. We cap at +# 36 (matching Auspice's own `colors[colors.length - 1]` fallback) and reuse +# the largest palette for >36-category trees. Reproducing this table lets +# rendered output match Nextstrain's view of the same tree even when the +# Auspice JSON omits an explicit `meta.colorings[<key>].scale`. 
+# +# Source: https://github.com/nextstrain/auspice/blob/master/src/util/globals.js +# Auspice is licensed under AGPL-3.0; the per-N color tables are reproduced +# here as factual color data with attribution. +# fmt: off +_AUSPICE_PALETTE: tuple[tuple[str, ...], ...] = ( + (), + ("#4C90C0",), + ("#4C90C0", "#CBB742"), + ("#4988C5", "#7EB876", "#CBB742"), + ("#4580CA", "#6BB28D", "#AABD52", "#DFA43B"), + ("#4377CD", "#61AB9D", "#94BD61", "#CDB642", "#E68133"), + ("#416DCE", "#59A3AA", "#84BA6F", "#BBBC49", "#E29D39", "#E1502A"), + ("#3F63CF", "#529AB6", "#75B681", "#A6BE55", "#D4B13F", "#E68133", "#DC2F24"), + ("#3E58CF", "#4B8EC1", "#65AE96", "#8CBB69", "#B8BC4A", "#DCAB3C", "#E67932", "#DC2F24"), + ("#3F4DCB", "#4681C9", "#5AA4A8", "#78B67E", "#9EBE5A", "#C5B945", "#E0A23A", "#E67231", "#DC2F24"), + ("#4042C7", "#4274CE", "#5199B7", "#69B091", "#88BB6C", "#ADBD51", "#CEB541", "#E39B39", "#E56C2F", "#DC2F24"), + ("#4137C2", "#4066CF", "#4B8DC2", "#5DA8A3", "#77B67F", "#96BD60", "#B8BC4B", "#D4B13F", "#E59638", "#E4672F", "#DC2F24"), + ("#462EB9", "#3E58CF", "#4580CA", "#549DB2", "#69B091", "#83BA70", "#A2BE57", "#C1BA47", "#D9AD3D", "#E69136", "#E4632E", "#DC2F24"), + ("#4B26B1", "#3F4ACA", "#4272CE", "#4D92BF", "#5DA8A3", "#74B583", "#8EBC66", "#ACBD51", "#C8B944", "#DDA93C", "#E68B35", "#E3602D", "#DC2F24"), + ("#511EA8", "#403DC5", "#4063CF", "#4785C7", "#559EB1", "#67AF94", "#7EB877", "#98BD5E", "#B4BD4C", "#CDB642", "#DFA53B", "#E68735", "#E35D2D", "#DC2F24"), + ("#511EA8", "#403AC4", "#3F5ED0", "#457FCB", "#5098B9", "#60AA9F", "#73B583", "#8BBB6A", "#A4BE56", "#BDBB48", "#D3B240", "#E19F3A", "#E68234", "#E25A2C", "#DC2F24"), + ("#511EA8", "#4138C3", "#3E59CF", "#4379CD", "#4D92BE", "#5AA5A8", "#6BB18E", "#7FB975", "#96BD5F", "#AFBD4F", "#C5B945", "#D8AE3E", "#E39B39", "#E67D33", "#E2572B", "#DC2F24"), + ("#511EA8", "#4236C1", "#3F55CE", "#4273CE", "#4A8CC2", "#569FAF", "#64AD98", "#76B680", "#8BBB6A", "#A1BE58", "#B7BC4B", "#CCB742", "#DCAB3C", "#E59638", 
"#E67932", "#E1552B", "#DC2F24"), + ("#511EA8", "#4335BF", "#3F51CC", "#416ECE", "#4887C6", "#529BB6", "#5FA9A0", "#6EB389", "#81B973", "#95BD61", "#AABD52", "#BFBB48", "#D1B340", "#DEA63B", "#E69237", "#E67531", "#E1522A", "#DC2F24"), + ("#511EA8", "#4333BE", "#3F4ECB", "#4169CF", "#4682C9", "#4F96BB", "#5AA5A8", "#68AF92", "#78B77D", "#8BBB6A", "#9EBE59", "#B3BD4D", "#C5B945", "#D5B03F", "#E0A23A", "#E68D36", "#E67231", "#E1502A", "#DC2F24"), + ("#511EA8", "#4432BD", "#3F4BCA", "#4065CF", "#447ECC", "#4C91BF", "#56A0AE", "#63AC9A", "#71B486", "#81BA72", "#94BD62", "#A7BE54", "#BABC4A", "#CBB742", "#D9AE3E", "#E29E39", "#E68935", "#E56E30", "#E14F2A", "#DC2F24"), + ("#511EA8", "#4531BC", "#3F48C9", "#3F61D0", "#4379CD", "#4A8CC2", "#539CB4", "#5EA9A2", "#6BB18E", "#7AB77B", "#8BBB6A", "#9CBE5B", "#AFBD4F", "#C0BA47", "#CFB541", "#DCAB3C", "#E39B39", "#E68534", "#E56B2F", "#E04D29", "#DC2F24"), + ("#511EA8", "#4530BB", "#3F46C8", "#3F5ED0", "#4375CD", "#4988C5", "#5098B9", "#5AA5A8", "#66AE95", "#73B583", "#82BA71", "#93BC62", "#A4BE56", "#B5BD4C", "#C5B945", "#D3B240", "#DEA73B", "#E59738", "#E68234", "#E4682F", "#E04C29", "#DC2F24"), + ("#511EA8", "#462FBA", "#3F44C8", "#3E5BD0", "#4270CE", "#4784C8", "#4E95BD", "#57A1AD", "#61AB9C", "#6DB38A", "#7BB879", "#8BBB6A", "#9BBE5C", "#ABBD51", "#BBBC49", "#CBB843", "#D6AF3E", "#DFA43B", "#E69537", "#E67F33", "#E4662E", "#E04A29", "#DC2F24"), + ("#511EA8", "#462EB9", "#4042C7", "#3E58CF", "#416DCE", "#4580CA", "#4C90C0", "#549DB2", "#5DA8A3", "#69B091", "#75B681", "#83BA70", "#92BC63", "#A2BE57", "#B2BD4D", "#C1BA47", "#CEB541", "#D9AD3D", "#E1A03A", "#E69136", "#E67C32", "#E4632E", "#E04929", "#DC2F24"), + ("#511EA8", "#462EB9", "#4040C6", "#3F55CE", "#4169CF", "#447DCC", "#4A8CC2", "#529AB7", "#5AA5A8", "#64AD98", "#70B487", "#7DB878", "#8BBB6A", "#99BD5D", "#A9BD53", "#B7BC4B", "#C5B945", "#D1B340", "#DCAB3C", "#E29D39", "#E68D36", "#E67932", "#E3612D", "#E04828", "#DC2F24"), + ("#511EA8", "#472DB8", "#403EC6", 
"#3F53CD", "#4066CF", "#4379CD", "#4989C5", "#4F97BB", "#57A1AD", "#61AA9E", "#6BB18E", "#77B67F", "#84BA70", "#92BC64", "#A0BE58", "#AFBD4F", "#BCBB49", "#CAB843", "#D4B13F", "#DEA83C", "#E39B39", "#E68A35", "#E67732", "#E35F2D", "#DF4728", "#DC2F24"), + ("#511EA8", "#472CB7", "#403DC5", "#3F50CC", "#4063CF", "#4375CD", "#4785C7", "#4D93BE", "#559EB1", "#5DA8A3", "#67AF94", "#72B485", "#7EB877", "#8BBB6A", "#98BD5E", "#A6BE55", "#B4BD4C", "#C1BA47", "#CDB642", "#D7AF3E", "#DFA53B", "#E49838", "#E68735", "#E67431", "#E35D2D", "#DF4628", "#DC2F24"), + ("#511EA8", "#482CB7", "#403BC5", "#3F4ECB", "#3F61D0", "#4272CE", "#4682C9", "#4C90C0", "#529BB5", "#5AA5A8", "#63AC9A", "#6DB28B", "#78B77D", "#84BA6F", "#91BC64", "#9EBE59", "#ACBD51", "#B9BC4A", "#C5B945", "#D0B441", "#DAAD3D", "#E0A23A", "#E59637", "#E68434", "#E67231", "#E35C2C", "#DF4528", "#DC2F24"), + ("#511EA8", "#482BB6", "#403AC4", "#3F4CCB", "#3F5ED0", "#426FCE", "#457FCB", "#4A8CC2", "#5098B9", "#58A2AC", "#60AA9F", "#69B091", "#73B583", "#7FB976", "#8BBB6A", "#97BD5F", "#A4BE56", "#B1BD4E", "#BDBB48", "#C9B843", "#D3B240", "#DCAB3C", "#E19F3A", "#E69337", "#E68234", "#E67030", "#E25A2C", "#DF4428", "#DC2F24"), + ("#511EA8", "#482BB6", "#4039C3", "#3F4ACA", "#3E5CD0", "#416CCE", "#447CCD", "#4989C4", "#4E96BC", "#559FB0", "#5DA8A4", "#66AE96", "#6FB388", "#7AB77C", "#85BA6F", "#91BC64", "#9DBE5A", "#AABD53", "#B6BD4B", "#C2BA46", "#CDB642", "#D6B03F", "#DDA83C", "#E29D39", "#E69036", "#E67F33", "#E56D30", "#E2592C", "#DF4428", "#DC2F24"), + ("#511EA8", "#482AB5", "#4138C3", "#3F48C9", "#3E59CF", "#4169CF", "#4379CD", "#4886C6", "#4D92BE", "#539CB4", "#5AA5A8", "#62AB9B", "#6BB18E", "#75B581", "#7FB975", "#8BBB6A", "#96BD5F", "#A2BE57", "#AFBD4F", "#BABC4A", "#C5B945", "#CFB541", "#D8AE3E", "#DFA63B", "#E39B39", "#E68D36", "#E67D33", "#E56B2F", "#E2572B", "#DF4328", "#DC2F24"), + ("#511EA8", "#492AB5", "#4137C2", "#3F47C9", "#3E57CE", "#4067CF", "#4376CD", "#4783C8", "#4C8FC0", "#519AB7", "#58A2AC", 
"#5FA9A0", "#68AF93", "#70B486", "#7BB77A", "#85BA6F", "#90BC65", "#9CBE5B", "#A8BE54", "#B3BD4D", "#BEBB48", "#C9B843", "#D2B340", "#DAAD3D", "#E0A33B", "#E49838", "#E68B35", "#E67B32", "#E5692F", "#E2562B", "#DF4227", "#DC2F24"), + ("#511EA8", "#492AB5", "#4236C1", "#3F45C8", "#3F55CE", "#4064CF", "#4273CE", "#4681CA", "#4A8CC2", "#4F97BA", "#569FAF", "#5CA7A4", "#64AD98", "#6DB28B", "#76B680", "#80B974", "#8BBB6A", "#96BD60", "#A1BE58", "#ACBD51", "#B7BC4B", "#C2BA46", "#CCB742", "#D4B13F", "#DCAB3C", "#E1A13A", "#E59638", "#E68835", "#E67932", "#E4672F", "#E1552B", "#DF4227", "#DC2F24"), + ("#511EA8", "#4929B4", "#4235C0", "#3F44C8", "#3F53CD", "#3F62CF", "#4270CE", "#457ECB", "#4989C4", "#4E95BD", "#549DB3", "#5AA5A8", "#61AB9C", "#69B090", "#72B485", "#7BB879", "#85BA6E", "#90BC65", "#9BBE5C", "#A6BE55", "#B1BD4E", "#BBBC49", "#C5B945", "#CEB541", "#D6AF3E", "#DDA93C", "#E29F39", "#E69537", "#E68634", "#E67732", "#E4662E", "#E1532B", "#DF4127", "#DC2F24"), + ("#511EA8", "#4929B4", "#4335BF", "#3F42C7", "#3F51CC", "#3F60D0", "#416ECE", "#447CCD", "#4887C6", "#4D92BF", "#529BB6", "#58A2AB", "#5FA9A0", "#66AE95", "#6EB389", "#77B67E", "#81B973", "#8BBB6A", "#95BD61", "#A0BE59", "#AABD52", "#B5BD4C", "#BFBB48", "#C9B843", "#D1B340", "#D8AE3E", "#DEA63B", "#E29C39", "#E69237", "#E68434", "#E67531", "#E4642E", "#E1522A", "#DF4127", "#DC2F24"), + ("#511EA8", "#4928B4", "#4334BF", "#4041C7", "#3F50CC", "#3F5ED0", "#416CCE", "#4379CD", "#4784C7", "#4B8FC1", "#5098B9", "#56A0AF", "#5CA7A4", "#63AC99", "#6BB18E", "#73B583", "#7CB878", "#86BB6E", "#90BC65", "#9ABD5C", "#A4BE56", "#AFBD4F", "#B9BC4A", "#C2BA46", "#CCB742", "#D3B240", "#DAAC3D", "#DFA43B", "#E39B39", "#E68F36", "#E68234", "#E67431", "#E4632E", "#E1512A", "#DF4027", "#DC2F24"), +) +# fmt: on + +_UNKNOWN = "unknown" +_NO_VARIATION = "" +_GRAY = "#888888" + +# Auspice mutation strings are like "N158K" / "*123A" / "-456N": one non-digit +# char, then digits, then one non-digit char. 
+_MUTATION_RE = re.compile(r"^(\D)(\d+)(\D)$") + + +@dataclass(frozen=True) +class ColorMapping: + """Resolved color information for a single `color_tree_by` invocation.""" + + values_by_node: dict[str, str] + domain: list[str] + range_: list[str] + legend_title: str + # When None, the legend shows the full domain. When set, it restricts the + # legend display without altering the scale — used to hide ``"unknown"`` + # when only internal nodes (not tips) lack the attribute, since the gray + # entry in that case just flags internal-node bookkeeping rather than + # any missing tip-level data. + legend_values: list[str] | None = None + + +def compute_node_color_values( + root: TreeNode, + color_spec: str, + auspice_meta: dict | None = None, +) -> ColorMapping: + """Walk the tree and resolve per-node color categories + scale arrays. + + Parameters + ---------- + root + The root of the tree to color. Internal-node identity is by + ``TreeNode.name`` (Auspice's ``NODE_xxxx``); tips by their resolved + strain name. + color_spec + The user-supplied spec string (see module docstring). + auspice_meta + The Auspice JSON's top-level ``meta`` dict, or ``None`` when no JSON + is available (caller passed a pre-built `TreeNode`). Used only to + consult ``meta.colorings[].scale`` and ``.title`` for node-attr + specs; ignored for genotype specs. + + Returns + ------- + ColorMapping + ``values_by_node[node.name]`` is the category string for each node. + ``domain`` and ``range_`` are parallel lists for ``alt.Scale``. + ``legend_title`` is the resolved legend header. ``legend_values``, + when set, restricts which categories appear in the legend (leaving + the scale untouched). 
+ """ + parsed = _parse_color_spec(color_spec) + if parsed[0] == "attr": + _, key = parsed + values_by_node = _color_by_node_attr(root, key) + else: + _, gene, sites = parsed + values_by_node = _color_by_genotype(root, gene, sites) + + categories = _ordered_categories(values_by_node.values()) + domain, range_ = _resolve_scale(categories, parsed, auspice_meta) + legend_title = _resolve_legend_title(color_spec, parsed, auspice_meta) + legend_values = _resolve_legend_values(domain, values_by_node, root) + return ColorMapping( + values_by_node=values_by_node, + domain=domain, + range_=range_, + legend_title=legend_title, + legend_values=legend_values, + ) + + +def _resolve_legend_values( + domain: list[str], + values_by_node: dict[str, str], + root: TreeNode, +) -> list[str] | None: + """Decide whether to hide ``"unknown"`` from the legend. + + Returns ``None`` when the full domain should appear — either because + ``"unknown"`` isn't a category at all, or because at least one tip + carries it (in which case the user looking at a gray tip needs the + legend to explain it). When ``"unknown"`` is present but only on + internal nodes, returns the domain with ``"unknown"`` filtered out; + internal segments still render gray via the unchanged scale, but the + legend doesn't dangle a misleading entry. + """ + if _UNKNOWN not in domain: + return None + for node in _walk_nodes(root): + if node.is_tip and values_by_node.get(node.name) == _UNKNOWN: + return None + return [c for c in domain if c != _UNKNOWN] + + +def _parse_color_spec( + spec: str, +) -> tuple[str, str] | tuple[str, str, list[int]]: + """Parse the spec mini-syntax. See module docstring for the grammar.""" + if not isinstance(spec, str) or not spec or any(c.isspace() for c in spec): + raise ValueError( + f"color_tree_by={spec!r}: expected either a node_attrs key, " + '"genotype::", or ' + '"genotype::,,...".' 
+ ) + if ":" not in spec: + return ("attr", spec) + parts = spec.split(":") + if len(parts) != 3 or parts[0] != "genotype" or not parts[1] or not parts[2]: + raise ValueError( + f"color_tree_by={spec!r}: expected either a node_attrs key, " + '"genotype::", or ' + '"genotype::,,...".' + ) + _, gene, site_str = parts + site_strs = site_str.split(",") + sites: list[int] = [] + for s in site_strs: + try: + n = int(s) + except ValueError as e: + raise ValueError( + f"color_tree_by={spec!r}: site {s!r} must be a positive integer." + ) from e + if n <= 0: + raise ValueError( + f"color_tree_by={spec!r}: site {s!r} must be a positive integer." + ) + sites.append(n) + seen: set[int] = set() + dups: set[int] = set() + for n in sites: + if n in seen: + dups.add(n) + else: + seen.add(n) + if dups: + raise ValueError( + f"color_tree_by={spec!r}: site list has duplicates ({sorted(dups)}); " + "each site must appear at most once." + ) + return ("genotype", gene, sites) + + +def _walk_nodes(root: TreeNode) -> Iterator[TreeNode]: + """Yield every node in the tree in pre-order.""" + yield root + for c in root.children: + yield from _walk_nodes(c) + + +def _color_by_node_attr(root: TreeNode, key: str) -> dict[str, str]: + """Resolve per-node category by reading `node_attrs[key]`.""" + values: dict[str, str] = {} + found = False + observed_keys: set[str] = set() + for node in _walk_nodes(root): + observed_keys.update(node.node_attrs.keys()) + attr = node.node_attrs.get(key) + if attr is None: + values[node.name] = _UNKNOWN + continue + if isinstance(attr, dict) and "value" in attr: + attr = attr["value"] + if attr is None or attr == "": + values[node.name] = _UNKNOWN + continue + found = True + values[node.name] = str(attr) + if not found: + observed = ", ".join(repr(k) for k in sorted(observed_keys)) + raise ValueError( + f"color_tree_by={key!r} is not a node_attrs key in this tree. 
" + f"Observed keys: [{observed}]" + ) + return values + + +def _color_by_genotype(root: TreeNode, gene: str, sites: list[int]) -> dict[str, str]: + """Resolve per-node category by walking branch_attrs.mutations[gene]. + + See :func:`_color_by_genotype_single_site` for the per-site walk. + Per-node label is the ``/``-joined ```` for each varying + site in user-supplied order; if every requested site is invariant, every + node gets ``""``. + """ + observed_genes: set[str] = set() + any_mutation = False + for node in _walk_nodes(root): + gene_map = node.branch_attrs.get("mutations", {}) or {} + if gene_map: + any_mutation = True + observed_genes.update(gene_map.keys()) + if not any_mutation: + raise ValueError( + f"color_tree_by='genotype:{gene}:{','.join(map(str, sites))}': " + "the Auspice JSON has no branch_attrs.mutations annotations." + ) + if gene not in observed_genes: + observed = ", ".join(repr(g) for g in sorted(observed_genes)) + raise ValueError( + f"color_tree_by='genotype:{gene}:{','.join(map(str, sites))}': " + f"gene {gene!r} not in branch_attrs.mutations. " + f"Observed genes: [{observed}]" + ) + + per_site_states: dict[int, dict[str, str] | None] = {} + for site in sites: + per_site_states[site] = _color_by_genotype_single_site(root, gene, site) + + varying_sites = [s for s in sites if per_site_states[s] is not None] + if not varying_sites: + # Every requested site is invariant in the tree. + return {node.name: _NO_VARIATION for node in _walk_nodes(root)} + + values: dict[str, str] = {} + for node in _walk_nodes(root): + parts = [] + for site in varying_sites: + states = per_site_states[site] + assert states is not None + parts.append(f"{states[node.name]}{site}") + values[node.name] = "/".join(parts) + return values + + +def _color_by_genotype_single_site( + root: TreeNode, gene: str, site: int +) -> dict[str, str] | None: + """Pre-order walk inferring the state at (gene, site) for every node. 
+ + Returns ``None`` when the site has zero mutations in the tree (invariant — + the caller drops it from the haplotype label, or labels every node + ``""`` when *all* requested sites are invariant). + Otherwise returns ``{node.name: letter}``. + """ + # Pass 1: collect every (gene, site) mutation along with the path-from-root + # depth at which it was applied. The earliest mutation's "from" letter is + # the root's state at the site. + mutations_seen: list[tuple[int, str, str]] = [] # (depth, from_letter, to_letter) + + def collect(node: TreeNode, depth: int) -> None: + for mut in node.branch_attrs.get("mutations", {}).get(gene, []) or []: + m = _MUTATION_RE.match(mut) + if m is None: + continue + from_letter, site_str, to_letter = m.group(1), m.group(2), m.group(3) + if int(site_str) == site: + mutations_seen.append((depth, from_letter, to_letter)) + for c in node.children: + collect(c, depth + 1) + + collect(root, 0) + if not mutations_seen: + return None + + mutations_seen.sort(key=lambda t: t[0]) + root_letter = mutations_seen[0][1] + + # Pass 2: assign per-node state. + states: dict[str, str] = {} + + def assign(node: TreeNode, current: str) -> None: + for mut in node.branch_attrs.get("mutations", {}).get(gene, []) or []: + m = _MUTATION_RE.match(mut) + if m is None: + continue + site_str, to_letter = m.group(2), m.group(3) + if int(site_str) == site: + current = to_letter + states[node.name] = current + for c in node.children: + assign(c, current) + + assign(root, root_letter) + return states + + +def _ordered_categories(values: Iterator[str] | list[str]) -> list[str]: + """Return unique category labels in legend display order. + + Real categories are sorted by descending count (ties broken + alphabetically), matching Auspice's `sortedDomain` for non- + clade-membership traits — the most common category lands at index 0 + so it gets the first slot of :data:`_AUSPICE_PALETTE[N]`. + ``"unknown"`` (if present) is pinned to the end regardless of count. 
+ ``""`` is only ever the sole category (when every + requested site is invariant); in that case it stands alone. + """ + counts = Counter(values) + if set(counts) == {_NO_VARIATION}: + return [_NO_VARIATION] + has_unknown = _UNKNOWN in counts + real = [c for c in counts if c != _UNKNOWN] + real.sort(key=lambda c: (-counts[c], c)) + if has_unknown: + real.append(_UNKNOWN) + return real + + +def _resolve_scale( + categories: list[str], + parsed_spec: tuple, + auspice_meta: dict | None, +) -> tuple[list[str], list[str]]: + """Build (domain, range_) parallel arrays for `alt.Scale`. + + Prefers ``meta.colorings[].scale`` for node-attr specs. Categories + not covered by the auspice scale (and every category for genotype specs + or when no `auspice_meta` is supplied) fall back to + :data:`_AUSPICE_PALETTE[K]`, where K is the number of *unmapped* + categories — so the fallback hues come from the same per-N palette + Auspice uses, capped at 36. ``"unknown"`` always maps to + :data:`_GRAY`. + """ + auspice_map: dict[str, str] = {} + if parsed_spec[0] == "attr" and auspice_meta is not None: + key = parsed_spec[1] + for c in auspice_meta.get("colorings", []) or []: + if c.get("key") == key: + scale = c.get("scale") or [] + for entry in scale: + # Each entry is [value, color]; skip malformed rows. 
+ if ( + isinstance(entry, (list, tuple)) + and len(entry) == 2 + and isinstance(entry[0], str) + and isinstance(entry[1], str) + ): + auspice_map[entry[0]] = entry[1] + break + + real_categories = [c for c in categories if c != _UNKNOWN] + unmapped = [c for c in real_categories if c not in auspice_map] + palette_idx = min(len(unmapped), len(_AUSPICE_PALETTE) - 1) + fallback_palette = _AUSPICE_PALETTE[palette_idx] + + domain: list[str] = [] + range_: list[str] = [] + fallback_pos = 0 + for cat in categories: + if cat == _UNKNOWN: + domain.append(cat) + range_.append(_GRAY) + continue + domain.append(cat) + if cat in auspice_map: + range_.append(auspice_map[cat]) + else: + if fallback_palette: + range_.append(fallback_palette[fallback_pos % len(fallback_palette)]) + else: + range_.append(_GRAY) + fallback_pos += 1 + return domain, range_ + + +def _resolve_legend_title( + color_spec: str, + parsed_spec: tuple, + auspice_meta: dict | None, +) -> str: + """Use `meta.colorings[].title` for attr specs when present; else + the literal spec string.""" + if parsed_spec[0] == "attr" and auspice_meta is not None: + key = parsed_spec[1] + for c in auspice_meta.get("colorings", []) or []: + if c.get("key") == key: + title = c.get("title") + if isinstance(title, str) and title: + return title + break + return color_spec diff --git a/src/tree_annotated_plot/_config.py b/src/tree_annotated_plot/_config.py index 2eaaac4..e827f41 100644 --- a/src/tree_annotated_plot/_config.py +++ b/src/tree_annotated_plot/_config.py @@ -70,15 +70,15 @@ class PlotConfig: tree_line_width: Annotated[ float, - "Stroke width (px) for the tree's branch lines. Default 1.5.", - ] = 1.5 + "Stroke width (px) for the tree's branch lines. Default 2.", + ] = 2.0 tree_node_size: Annotated[ float, "Area (px²) of the small filled circles drawn at each tip. " - "Default 28. Setting tree_node_size=0 disables the tip-circle " + "Default 45. 
Setting tree_node_size=0 disables the tip-circle " "layer entirely.", - ] = 28 + ] = 45 leader_line_width: Annotated[ float, @@ -154,6 +154,17 @@ class PlotConfig: "connect_leader_to_label is off.", ] = 0 + color_tree_by: Annotated[ + str | None, + "Color the tree by an Auspice attribute. Pass a node_attrs key " + '(e.g. "subclade"), or "genotype::" / ' + '"genotype::,,..." (e.g. "genotype:HA1:158" ' + 'or "genotype:HA1:158,189") to color by the inferred genotype ' + "state at a site or haplotype across sites. A categorical legend " + "is drawn below the plot; missing values are gray. None (default) " + "leaves the tree black.", + ] = None + # Sidecar for Python-docstring-only prose, keyed by PlotConfig field name. # Empty by default — add an entry when a field's docstring entry needs more diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index 562d821..25b4a74 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -14,7 +14,7 @@ import altair as alt import pandas as pd -from . import _config, _tree +from . import _color, _config, _tree from ._config import PlotConfig, TreeLocation # Accepted chart input forms for the public `plot` function. 
@@ -35,8 +35,8 @@ def plot( branch_length: Literal["div", "num_date"], tree_size: int = 100, tree_location: TreeLocation | None = None, - tree_line_width: float = 1.5, - tree_node_size: float = 28, + tree_line_width: float = 2.0, + tree_node_size: float = 45, leader_line_width: float = 1.0, scale_bar: bool = False, branch_length_units: str | None = None, @@ -46,6 +46,7 @@ def plot( strain_label_font_size: float = 10.0, strain_label_font_weight: Literal["normal", "bold"] = "normal", shift_tree_loc: int = 0, + color_tree_by: str | None = None, ) -> alt.HConcatChart | alt.VConcatChart: """Return an Altair chart with a phylogenetic tree drawn alongside `chart`.""" return _build( @@ -68,6 +69,7 @@ def plot( strain_label_font_size=strain_label_font_size, strain_label_font_weight=strain_label_font_weight, shift_tree_loc=shift_tree_loc, + color_tree_by=color_tree_by, ), ) @@ -121,7 +123,7 @@ def _build( set on a config object); `config` carries every styling / behavior knob. Both surfaces converge here so they can never disagree. """ - root = _ensure_tree( + root, auspice_meta = _ensure_tree( tree, config.tree_strain_field, branch_length=config.branch_length, @@ -162,6 +164,12 @@ def _build( strain_dim = _coerce_dim( _chart_strain_dim(spec, axis_hits, axis), n_tips=len(tip_names) ) + if config.color_tree_by is not None: + color_mapping = _color.compute_node_color_values( + root, config.color_tree_by, auspice_meta=auspice_meta + ) + else: + color_mapping = None tree_chart = _build_tree_chart( root, n_tips=len(tip_names), @@ -180,6 +188,7 @@ def _build( strain_label_font_weight=config.strain_label_font_weight, shift_tree_loc=config.shift_tree_loc, tip_names=tip_names, + color_mapping=color_mapping, ) new_chart = _apply_tree_order_to_chart_object( @@ -256,10 +265,17 @@ def _ensure_tree( *, branch_length: str, strict_version: bool, -) -> _tree.TreeNode: +) -> tuple[_tree.TreeNode, dict | None]: + """Return ``(root, auspice_meta)``. 
+ + ``auspice_meta`` is the loaded Auspice JSON's top-level ``meta`` dict, or + ``None`` when the caller passed a pre-built ``TreeNode`` (in which case + we have no JSON to read ``meta.colorings`` from, and color resolution + falls back to the default palette). + """ if isinstance(tree, _tree.TreeNode): - return tree - return _tree.load_auspice( + return tree, None + return _tree.load_auspice_with_meta( tree, tree_strain_field=tree_strain_field, branch_length=branch_length, @@ -1200,8 +1216,8 @@ def _build_tree_chart( strain_dim: int | float | alt.Step, strain_axis: str, tree_location: TreeLocation, - tree_line_width: float = 1.5, - tree_node_size: float = 28, + tree_line_width: float = 2.0, + tree_node_size: float = 45, leader_line_width: float = 1.0, scale_bar: bool = False, branch_length: str = "div", @@ -1211,6 +1227,7 @@ def _build_tree_chart( strain_label_font_weight: str = "normal", shift_tree_loc: int = 0, tip_names: list[str] | None = None, + color_mapping: _color.ColorMapping | None = None, ) -> alt.Chart: """Build the tree panel. @@ -1248,6 +1265,37 @@ def _build_tree_chart( branch_max = float(seg_df[["x", "x2"]].max().max()) branch_min = float(seg_df[["x", "x2"]].min().min()) + # When color_tree_by is set, attach the per-node color category to both + # frames. The `mark_rule` for branches (one row per `seg_df` segment) and + # the `mark_circle` for tips share the same color encoding so Altair + # collapses them into a single legend at the bottom. 
+ if color_mapping is not None: + seg_df = seg_df.assign( + color_value=seg_df["color_node"].map(color_mapping.values_by_node) + ) + tips_df = tips_df.assign( + color_value=tips_df["name"].map(color_mapping.values_by_node) + ) + legend_kwargs: dict = { + "title": color_mapping.legend_title, + "orient": "bottom", + } + if color_mapping.legend_values is not None: + # Restrict the legend display without touching the scale, so + # internal-node segments still render gray when "unknown" is on + # the tree but no tip is. + legend_kwargs["values"] = list(color_mapping.legend_values) + color_enc = alt.Color( + "color_value:N", + scale=alt.Scale( + domain=list(color_mapping.domain), + range=list(color_mapping.range_), + ), + legend=alt.Legend(**legend_kwargs), + ) + else: + color_enc = None + # When connect_leader_to_label is on: # - All leaders extend to a single point: the panel's chart-facing edge # (`chart_edge_branch`). @@ -1363,20 +1411,34 @@ def _build_tree_chart( ) .encode(x=branch_enc, x2="x2:Q", y=tip_enc) ) + seg_enc_kwargs: dict = { + "x": branch_enc, + "x2": "x2:Q", + "y": tip_enc, + "y2": "y2:Q", + } + if color_enc is not None: + seg_enc_kwargs["color"] = color_enc layers.append( alt.Chart(seg_df) - .mark_rule(strokeWidth=tree_line_width) - .encode(x=branch_enc, x2="x2:Q", y=tip_enc, y2="y2:Q") + .mark_rule(strokeWidth=tree_line_width, opacity=1.0) + .encode(**seg_enc_kwargs) ) if tree_node_size > 0: + tip_enc_kwargs: dict = { + "x": branch_enc, + "y": tip_enc, + "tooltip": alt.Tooltip("name:N", title="strain"), + } + tip_mark_kwargs: dict = {"size": tree_node_size, "opacity": 1.0} + if color_enc is None: + tip_mark_kwargs["color"] = "black" + else: + tip_enc_kwargs["color"] = color_enc layers.append( alt.Chart(tips_df) - .mark_circle(size=tree_node_size, color="black") - .encode( - x=branch_enc, - y=tip_enc, - tooltip=alt.Tooltip("name:N", title="strain"), - ) + .mark_circle(**tip_mark_kwargs) + .encode(**tip_enc_kwargs) ) # Strain text label, drawn as two 
stacked layers: a white halo # (white fill + thick white stroke) under the visible text. The @@ -1455,20 +1517,34 @@ def _build_tree_chart( ) .encode(y=branch_enc, y2="x2:Q", x=tip_enc) ) + seg_enc_kwargs = { + "y": branch_enc, + "y2": "x2:Q", + "x": tip_enc, + "x2": "y2:Q", + } + if color_enc is not None: + seg_enc_kwargs["color"] = color_enc layers.append( alt.Chart(seg_df) - .mark_rule(strokeWidth=tree_line_width) - .encode(y=branch_enc, y2="x2:Q", x=tip_enc, x2="y2:Q") + .mark_rule(strokeWidth=tree_line_width, opacity=1.0) + .encode(**seg_enc_kwargs) ) if tree_node_size > 0: + tip_enc_kwargs = { + "y": branch_enc, + "x": tip_enc, + "tooltip": alt.Tooltip("name:N", title="strain"), + } + tip_mark_kwargs = {"size": tree_node_size, "opacity": 1.0} + if color_enc is None: + tip_mark_kwargs["color"] = "black" + else: + tip_enc_kwargs["color"] = color_enc layers.append( alt.Chart(tips_df) - .mark_circle(size=tree_node_size, color="black") - .encode( - y=branch_enc, - x=tip_enc, - tooltip=alt.Tooltip("name:N", title="strain"), - ) + .mark_circle(**tip_mark_kwargs) + .encode(**tip_enc_kwargs) ) # Strain text label as halo + visible text (see vertical-layout # comment above). diff --git a/src/tree_annotated_plot/_tree.py b/src/tree_annotated_plot/_tree.py index f499cba..e6cd8d1 100644 --- a/src/tree_annotated_plot/_tree.py +++ b/src/tree_annotated_plot/_tree.py @@ -23,12 +23,20 @@ class TreeNode: `y` is set by :func:`layout` and is the integer index of the node's tip (for tips) or the midpoint of its descendants' tip indices (for internal nodes). + + `node_attrs` and `branch_attrs` carry the raw Auspice JSON dicts unchanged + so downstream code (e.g. `_color`) can read them directly. `node_attrs.div` + has already been consumed by `_resolve_branch_length`; `node_attrs[X]` is + where things like `subclade` live; `branch_attrs.mutations[<gene>]` is + where mutation strings like `"N158K"` live.
""" name: str x: float children: list["TreeNode"] = field(default_factory=list) y: float | None = None + node_attrs: dict = field(default_factory=dict) + branch_attrs: dict = field(default_factory=dict) @property def is_tip(self) -> bool: @@ -66,6 +74,29 @@ def load_auspice( ``False`` the same case becomes a ``warnings.warn``. A missing ``version`` field always warns and proceeds. """ + root, _ = load_auspice_with_meta( + source, + tree_strain_field=tree_strain_field, + branch_length=branch_length, + strict_version=strict_version, + ) + return root + + +def load_auspice_with_meta( + source: str | Path | dict, + *, + tree_strain_field: str, + branch_length: str = "div", + strict_version: bool = True, +) -> tuple[TreeNode, dict]: + """Like :func:`load_auspice`, but also returns the Auspice top-level + ``meta`` dict. + + Used by code paths that need ``meta.colorings`` (e.g. resolving the + color palette for ``color_tree_by``). Returns ``(root, meta)`` where + ``meta`` is ``data.get("meta", {})``. 
+ """ _validate_tree_strain_field(tree_strain_field) if branch_length not in ("div", "num_date"): raise ValueError( @@ -85,7 +116,9 @@ def load_auspice( if "tree" not in data: raise ValueError("Auspice JSON must have a top-level 'tree' field") - return _parse_node(data["tree"], tree_strain_field, branch_length) + root = _parse_node(data["tree"], tree_strain_field, branch_length) + meta = data.get("meta", {}) or {} + return root, meta def _check_auspice_version(data: dict, *, strict_version: bool) -> None: @@ -160,7 +193,13 @@ def _parse_node(d: dict, tree_strain_field: str, branch_length: str) -> TreeNode children = [ _parse_node(c, tree_strain_field, branch_length) for c in children_dicts ] - return TreeNode(name=name, x=float(branch_value), children=children) + return TreeNode( + name=name, + x=float(branch_value), + children=children, + node_attrs=d.get("node_attrs", {}) or {}, + branch_attrs=d.get("branch_attrs", {}) or {}, + ) def _resolve_branch_length(node_dict: dict, branch_length: str) -> float: @@ -225,7 +264,13 @@ def _prune_recursive(node: TreeNode, keep_strains: set[str]) -> TreeNode | None: """ if node.is_tip: if node.name in keep_strains: - return TreeNode(name=node.name, x=node.x, children=[]) + return TreeNode( + name=node.name, + x=node.x, + children=[], + node_attrs=node.node_attrs, + branch_attrs=node.branch_attrs, + ) return None new_children: list[TreeNode] = [] @@ -239,7 +284,13 @@ def _prune_recursive(node: TreeNode, keep_strains: set[str]) -> TreeNode | None: if len(new_children) == 1: # Collapse this single-child internal into its child. 
return new_children[0] - return TreeNode(name=node.name, x=node.x, children=new_children) + return TreeNode( + name=node.name, + x=node.x, + children=new_children, + node_attrs=node.node_attrs, + branch_attrs=node.branch_attrs, + ) def tips(root: TreeNode) -> Iterator[TreeNode]: @@ -278,12 +329,20 @@ def _assign_internal_y(node: TreeNode) -> float: def segments(root: TreeNode) -> pd.DataFrame: """Build a DataFrame of line segments for drawing the tree. - Returns columns `x`, `x2`, `y`, `y2`. Each row is one segment: + Returns columns `x`, `x2`, `y`, `y2`, `color_node`. Each row is one + segment: - For each internal node, one vertical connector from the topmost to the - bottommost child (x == x2 == node.x). + bottommost child (x == x2 == node.x). `color_node` is the parent's + name — the connector takes the parent's color. - For each non-root node, one horizontal branch from its parent's x to its - own x at its own y (y == y2). + own x at its own y (y == y2). `color_node` is the *child*'s name — + the branch into a node is colored by that node, matching how Auspice + colors trees. + + `color_node` is the join key into the per-node color map produced by + `_color.compute_node_color_values`. When no coloring is requested, the + column is harmless (consumers ignore it). Assumes :func:`layout` has already been called on `root`. 
""" @@ -297,11 +356,25 @@ def walk(node: TreeNode) -> None: return child_ys = [c.y for c in node.children] rows.append( - {"x": node.x, "x2": node.x, "y": min(child_ys), "y2": max(child_ys)} + { + "x": node.x, + "x2": node.x, + "y": min(child_ys), + "y2": max(child_ys), + "color_node": node.name, + } ) for c in node.children: - rows.append({"x": node.x, "x2": c.x, "y": c.y, "y2": c.y}) + rows.append( + { + "x": node.x, + "x2": c.x, + "y": c.y, + "y2": c.y, + "color_node": c.name, + } + ) walk(c) walk(root) - return pd.DataFrame(rows, columns=["x", "x2", "y", "y2"]) + return pd.DataFrame(rows, columns=["x", "x2", "y", "y2", "color_node"]) diff --git a/tests/test_color_tree.py b/tests/test_color_tree.py new file mode 100644 index 0000000..0f6e79d --- /dev/null +++ b/tests/test_color_tree.py @@ -0,0 +1,800 @@ +"""Tests for `color_tree_by`: per-node attr / genotype / haplotype coloring, +Auspice-scale preference, default-palette fallback, gray-for-missing, +legend wiring.""" + +from __future__ import annotations + +from typing import Any + +import altair as alt +import pandas as pd +import pytest + +import tree_annotated_plot +from tree_annotated_plot import _color, _tree + + +def _attr_auspice() -> dict: + """Tiny tree with `subclade` on every node and tips A..D in two clades.""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0, "subclade": {"value": "X"}}, + "children": [ + { + "name": "INT_LEFT", + "node_attrs": {"div": 0.02, "subclade": {"value": "X"}}, + "children": [ + { + "name": "A", + "node_attrs": { + "div": 0.04, + "subclade": {"value": "X"}, + }, + }, + { + "name": "B", + "node_attrs": { + "div": 0.05, + "subclade": {"value": "X"}, + }, + }, + ], + }, + { + "name": "INT_RIGHT", + "node_attrs": {"div": 0.03, "subclade": {"value": "Y"}}, + "children": [ + { + "name": "C", + "node_attrs": { + "div": 0.06, + "subclade": {"value": "Y"}, + }, + }, + { + "name": "D", + "node_attrs": { + "div": 0.07, + 
"subclade": {"value": "Z"}, + }, + }, + ], + }, + ], + }, + } + + +def _genotype_auspice(*, mutations_at_158: bool = True) -> dict: + """Tree with HA1 mutations along selected branches. + + With `mutations_at_158=True` (default): + - tip_A inherits N158K via its branch (state K). + - tip_B is on a no-mutation branch (root state N). + - INT1 carries N158D, so its descendants tip_C, tip_D are state D. + + With `mutations_at_158=False`: same topology, but no HA1:158 mutation + anywhere — used for the invariant-site case. + """ + site_158_mut = "N158K" if mutations_at_158 else None + int1_mut = "N158D" if mutations_at_158 else None + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": ( + {"mutations": {"HA1": [site_158_mut]}} if site_158_mut else {} + ), + }, + {"name": "tip_B", "node_attrs": {"div": 0.05}}, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": ( + {"mutations": {"HA1": [int1_mut]}} if int1_mut else {} + ), + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, + ], + }, + } + + +def _haplotype_auspice() -> dict: + """Tree carrying HA1 mutations at sites 158 *and* 189 along independent + branches, so a 2-site haplotype gives several distinct categories.""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + }, + { + "name": "tip_B", + "node_attrs": {"div": 0.05}, + "branch_attrs": {"mutations": {"HA1": ["S189T"]}}, + }, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158K", "S189T"]}}, + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, 
+ ], + }, + } + + +def _load(d: dict) -> _tree.TreeNode: + return _tree.load_auspice(d, tree_strain_field="name", branch_length="div") + + +# ----------------------------------------------------------------------------- +# node_attrs path +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_node_attr_assigns_per_tip(): + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["A"] == "X" + assert m.values_by_node["B"] == "X" + assert m.values_by_node["C"] == "Y" + assert m.values_by_node["D"] == "Z" + + +def test_color_tree_by_node_attr_internal_nodes(): + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["INT_LEFT"] == "X" + assert m.values_by_node["INT_RIGHT"] == "Y" + assert m.values_by_node["ROOT"] == "X" + + +def test_color_tree_by_node_attr_unwraps_value(): + # `node_attrs.subclade = {"value": "X"}` resolves to "X", not the dict. + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert "X" in m.values_by_node.values() + + +def test_color_tree_by_node_attr_missing_marked_unknown(): + d = _attr_auspice() + # Strip subclade off a single tip. + d["tree"]["children"][0]["children"][0]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["A"] == "unknown" + # Domain places "unknown" last and pairs it with #888888. + assert m.domain[-1] == "unknown" + assert m.range_[-1] == "#888888" + # A is a tip, so the legend must keep "unknown" visible. 
+ assert m.legend_values is None + + +def test_color_tree_by_unknown_omitted_from_legend_when_only_internal(): + """When only internal nodes lack the attribute (every tip is annotated), + "unknown" stays in the scale (so internal segments render gray) but is + hidden from the legend.""" + d = _attr_auspice() + # Strip subclade off the ROOT internal node only — all tips remain + # annotated. + d["tree"]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["ROOT"] == "unknown" + # Domain still includes "unknown" so the seg_df row for ROOT renders gray. + assert "unknown" in m.domain + # But the legend display drops it. + assert m.legend_values is not None + assert "unknown" not in m.legend_values + # And the rest of the legend matches domain-minus-unknown. + assert m.legend_values == [c for c in m.domain if c != "unknown"] + + +def test_color_tree_by_unknown_absent_legend_unrestricted(): + """Fully annotated tree (no `"unknown"` anywhere) -> legend_values stays + None, i.e. the legend uses the full domain.""" + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert "unknown" not in m.values_by_node.values() + assert m.legend_values is None + + +def test_color_tree_by_missing_attr_raises_lists_keys(): + root = _load(_attr_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "nonexistent_field") + msg = str(exc.value) + assert "nonexistent_field" in msg + # Observed keys appear in the message — div and subclade at minimum. 
+ assert "'div'" in msg + assert "'subclade'" in msg + + +# ----------------------------------------------------------------------------- +# genotype path: single-site +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_genotype_root_state_inference(): + root = _load(_genotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + assert m.values_by_node["tip_A"] == "K158" # branch carries N158K + assert m.values_by_node["tip_B"] == "N158" # root state + assert m.values_by_node["tip_C"] == "D158" # inherited from INT1's N158D + assert m.values_by_node["tip_D"] == "D158" + + +def test_color_tree_by_genotype_multiple_mutations_along_path(): + d = { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + {"name": "tip_X", "node_attrs": {"div": 0.05}}, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + "children": [ + { + "name": "tip_Y", + "node_attrs": {"div": 0.04}, + }, + { + "name": "tip_Z", + "node_attrs": {"div": 0.06}, + "branch_attrs": {"mutations": {"HA1": ["K158R"]}}, + }, + ], + }, + ], + }, + } + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + assert m.values_by_node["tip_X"] == "N158" # root, no muts on path + assert m.values_by_node["tip_Y"] == "K158" # parent N158K + assert m.values_by_node["tip_Z"] == "R158" # N158K then K158R + + +def test_color_tree_by_genotype_single_site_invariant_renders_no_variation(): + root = _load(_genotype_auspice(mutations_at_158=False)) + # The JSON has no mutations anywhere, so this fires the "no mutation + # annotations" error rather than the invariant path. Add a stray + # mutation at a different site so mutations *exist* in the tree but + # NOT at site 158. 
+ d = _genotype_auspice(mutations_at_158=False) + d["tree"]["children"][0]["branch_attrs"] = {"mutations": {"HA1": ["S145N"]}} + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + # Every node gets the literal "" category. + assert set(m.values_by_node.values()) == {""} + assert m.domain == [""] + + +def test_color_tree_by_genotype_no_mutations_anywhere_raises(): + d = _genotype_auspice(mutations_at_158=False) + root = _load(d) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:158") + assert "no branch_attrs.mutations annotations" in str(exc.value) + + +def test_color_tree_by_genotype_missing_gene_raises_lists_genes(): + root = _load(_genotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:NONEXISTENT:158") + msg = str(exc.value) + assert "'NONEXISTENT'" in msg + assert "'HA1'" in msg + + +# ----------------------------------------------------------------------------- +# genotype path: haplotype +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_haplotype_basic(): + root = _load(_haplotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + # tip_A: branch N158K, no 189 mut -> K158/S189 + # tip_B: branch S189T, no 158 mut -> N158/T189 + # INT1's children inherit both N158K and S189T -> K158/T189 + assert m.values_by_node["tip_A"] == "K158/S189" + assert m.values_by_node["tip_B"] == "N158/T189" + assert m.values_by_node["tip_C"] == "K158/T189" + assert m.values_by_node["tip_D"] == "K158/T189" + + +def test_color_tree_by_haplotype_drops_invariant_sites(): + # Same tree as the single-site test, but ask for a 2-site haplotype + # where only 158 has mutations: the haplotype label collapses to just + # the 158 locus. 
+ d = _genotype_auspice() # only 158 has mutations + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + # 189 is invariant in this tree -> dropped from the label. + assert m.values_by_node["tip_A"] == "K158" + assert m.values_by_node["tip_B"] == "N158" + assert m.values_by_node["tip_C"] == "D158" + + +def test_color_tree_by_haplotype_all_invariant_renders_no_variation(): + d = _genotype_auspice(mutations_at_158=False) + # Need mutations to exist somewhere so we don't trip the + # "no mutations anywhere" guard. Add a HA1 mutation at a third site. + d["tree"]["children"][0]["branch_attrs"] = {"mutations": {"HA1": ["S145N"]}} + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + assert set(m.values_by_node.values()) == {""} + assert m.domain == [""] + + +def test_color_tree_by_haplotype_preserves_user_site_order(): + root = _load(_haplotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:189,158") + # User wrote 189 first, so the label has 189 first. 
+ assert m.values_by_node["tip_A"] == "S189/K158" + assert m.values_by_node["tip_B"] == "T189/N158" + + +def test_color_tree_by_haplotype_duplicate_sites_raises(): + root = _load(_haplotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:158,158") + assert "duplicates" in str(exc.value) + + +def test_color_tree_by_genotype_site_int_form_required(): + root = _load(_haplotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:foo") + assert "positive integer" in str(exc.value) + + +# ----------------------------------------------------------------------------- +# scale resolution: Auspice meta vs default palette +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_uses_auspice_scale_when_present(): + root = _load(_attr_auspice()) + meta = { + "colorings": [ + { + "key": "subclade", + "type": "categorical", + "scale": [["X", "#ff0000"], ["Y", "#00ff00"], ["Z", "#0000ff"]], + }, + ], + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + assert color_for["X"] == "#ff0000" + assert color_for["Y"] == "#00ff00" + assert color_for["Z"] == "#0000ff" + + +def test_color_tree_by_falls_back_to_default_palette_for_unmapped(): + """Auspice scale only covers X and Z; Y must take the first slot of the + Auspice fallback palette sized to the unmapped count (1 here, so + `_AUSPICE_PALETTE[1][0]`). 
Auspice-mapped slots don't consume + fallback-palette indices.""" + root = _load(_attr_auspice()) + meta = { + "colorings": [ + { + "key": "subclade", + "type": "categorical", + "scale": [["X", "#ff0000"], ["Z", "#0000ff"]], + }, + ], + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + assert color_for["X"] == "#ff0000" + assert color_for["Z"] == "#0000ff" + # 1 unmapped category (Y) -> _AUSPICE_PALETTE[1] -> single color. + assert color_for["Y"] == _color._AUSPICE_PALETTE[1][0] + + +def test_color_tree_by_legend_title_uses_auspice_meta_title(): + root = _load(_attr_auspice()) + meta = { + "colorings": [{"key": "subclade", "type": "categorical", "title": "Subclade"}] + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + assert m.legend_title == "Subclade" + + +def test_color_tree_by_genotype_ignores_auspice_scale(): + """Even if meta.colorings happens to have a 'genotype' entry, the + genotype path doesn't consult it — colors come from `_AUSPICE_PALETTE` + sized to the category count.""" + root = _load(_genotype_auspice()) + meta = { + "colorings": [ + { + "key": "genotype", + "type": "categorical", + "scale": [["K158", "#ff0000"]], + }, + ], + } + m = _color.compute_node_color_values(root, "genotype:HA1:158", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + # K158 should NOT pick up the Auspice color; it gets a palette slot. 
+ assert color_for["K158"] != "#ff0000" + n_real = sum(1 for c in m.domain if c != "unknown") + assert color_for["K158"] in _color._AUSPICE_PALETTE[n_real] + + +def test_color_tree_by_no_auspice_meta_uses_default_palette(): + """No `auspice_meta` and a node-attr spec -> all real categories come + from `_AUSPICE_PALETTE[N]` where N is the number of real categories.""" + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade", auspice_meta=None) + n_real = sum(1 for c in m.domain if c != "unknown") + palette_n = _color._AUSPICE_PALETTE[n_real] + for cat, col in zip(m.domain, m.range_): + if cat == "unknown": + assert col == "#888888" + else: + assert col in palette_n + + +def test_color_tree_by_categories_sorted_by_descending_frequency(): + """Categories should be ordered by descending count so the most common + one ends up at index 0 (where Auspice's per-N palette puts its + deepest-blue start).""" + # Build a tree where subclade frequencies are X=4 (ROOT, INT_LEFT, A, B), + # Y=2 (INT_RIGHT, C), Z=1 (D). After sorting by descending count we + # expect domain[:3] == ["X", "Y", "Z"]. + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + real = [c for c in m.domain if c != "unknown"] + assert real == ["X", "Y", "Z"] + + +def test_color_tree_by_default_palette_matches_auspice_for_six_categories(): + """For a 6-category attribute, the resolved range must equal Auspice's + `colors[6]` exactly. Pins the visual match against Nextstrain.""" + # Build a synthetic tree with 6 categories at distinct frequencies so + # ordering is unambiguous. 
+ d = { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0, "clade": {"value": "C1"}}, + "children": [ + { + "name": f"tip_{i}", + "node_attrs": {"div": 0.01 * i, "clade": {"value": v}}, + } + # Frequencies: C1=6, C2=5, C3=4, C4=3, C5=2, C6=1 -> total 21 + for i, v in enumerate( + ["C1"] * 5 # plus the one on ROOT, total C1=6 + + ["C2"] * 5 + + ["C3"] * 4 + + ["C4"] * 3 + + ["C5"] * 2 + + ["C6"] + ) + ], + }, + } + root = _load(d) + m = _color.compute_node_color_values(root, "clade") + assert m.domain == ["C1", "C2", "C3", "C4", "C5", "C6"] + assert tuple(m.range_) == _color._AUSPICE_PALETTE[6] + + +def test_color_tree_by_gray_reserved_for_unknown(): + """For a tree with non-missing categories plus 'unknown', gray (#888888) + must appear *only* at the 'unknown' slot — never inside the per-N + Auspice palette so a fallback-mapped category cannot collide with it.""" + d = _attr_auspice() + # Drop subclade off one tip so "unknown" enters the legend. + d["tree"]["children"][0]["children"][0]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + gray_positions = [i for i, c in enumerate(m.range_) if c == "#888888"] + assert len(gray_positions) == 1 + assert m.domain[gray_positions[0]] == "unknown" + # And gray is not anywhere in the Auspice palette. 
+    for entry in _color._AUSPICE_PALETTE:
+        assert "#888888" not in entry
+        assert "#7f7f7f" not in entry
+
+
+# -----------------------------------------------------------------------------
+# plot()-level: encoding placement, legend orient, default-none
+# -----------------------------------------------------------------------------
+
+
+def _vertical_chart(strains: list[str]) -> alt.Chart:
+    """Build a 200x200 circle chart with ``titer`` on x and ``strain`` on y.
+
+    One row is created per strain. Titers are successive powers of two
+    (1.0, 2.0, 4.0, 8.0, ...) so the helper works for any number of strains
+    instead of requiring exactly four.
+    """
+    titers = [2.0**i for i in range(len(strains))]
+    df = pd.DataFrame({"strain": strains, "titer": titers})
+    return (
+        alt.Chart(df)
+        .mark_circle()
+        .encode(x="titer:Q", y=alt.Y("strain:N"))
+        .properties(width=200, height=200)
+    )
+
+
+def _kw() -> dict[str, str]:
+    """Return the field-mapping keyword arguments shared by the plot() tests."""
+    return dict(
+        chart_strain_field="strain", tree_strain_field="name", branch_length="div"
+    )
+
+
+def _find_color_encodings(node: Any) -> list[tuple[str, Any]]:
+    """Recursively find every encoding-block 'color' on the spec, returning
+    (json-pointer-ish path, color encoding) pairs.
+
+    The color encoding is returned exactly as stored under
+    ``encoding["color"]``; it is not guaranteed to be a dict, so callers
+    filter with ``isinstance`` before reading keys from it.
+    """
+    hits: list[tuple[str, Any]] = []
+
+    def walk(o: Any, path: str) -> None:
+        if isinstance(o, dict):
+            enc = o.get("encoding")
+            if isinstance(enc, dict) and "color" in enc:
+                hits.append((path, enc["color"]))
+            # Keep descending: layered / concatenated specs nest further
+            # encoding blocks under arbitrary keys.
+            for k, v in o.items():
+                walk(v, f"{path}.{k}")
+        elif isinstance(o, list):
+            for i, v in enumerate(o):
+                walk(v, f"{path}[{i}]")
+
+    walk(node, "")
+    return hits
+
+
+def _tree_panel(spec: dict) -> Any:
+    """Return the first panel of a concatenated spec (``hconcat`` preferred,
+    then ``vconcat``), which is the tree panel for our 'left'-default
+    vertical layout. Fails the test if the spec has no panels."""
+    panels = spec.get("hconcat") or spec.get("vconcat") or []
+    assert panels, "expected an hconcat/vconcat spec with at least one panel"
+    return panels[0]
+
+
+def _color_value_encodings(panel: Any) -> list[dict]:
+    """Return the dict-valued color encodings under ``panel`` whose ``field``
+    is 'color_value'."""
+    return [
+        enc
+        for _, enc in _find_color_encodings(panel)
+        if isinstance(enc, dict) and enc.get("field") == "color_value"
+    ]
+
+
+def _tree_panel_color_encodings(out) -> list[dict]:
+    """Color encodings on the tree panel only (which is hconcat[0] for our
+    'left'-default vertical layout). Filtered to those backed by the
+    'color_value' field — pre-existing chart colors are not."""
+    return _color_value_encodings(_tree_panel(out.to_dict()))
+
+
+def test_color_tree_by_default_none_no_color_encoding():
+    """Without ``color_tree_by`` the tree panel carries no 'color_value'
+    color encoding."""
+    out = tree_annotated_plot.plot(
+        _attr_auspice(),
+        _vertical_chart(["A", "B", "C", "D"]),
+        **_kw(),
+    )
+    encs = _tree_panel_color_encodings(out)
+    assert encs == []
+
+
+def test_color_tree_by_legend_orient_bottom():
+    """Every tree color encoding places its legend at the bottom."""
+    out = tree_annotated_plot.plot(
+        _attr_auspice(),
+        _vertical_chart(["A", "B", "C", "D"]),
+        **_kw(),
+        color_tree_by="subclade",
+    )
+    encs = _tree_panel_color_encodings(out)
+    assert len(encs) >= 1
+    for enc in encs:
+        legend = enc.get("legend") or {}
+        assert legend.get("orient") == "bottom"
+
+
+def test_color_tree_by_legend_title_is_spec_string():
+    # No auspice_meta.colorings.title -> falls back to the literal spec.
+    out = tree_annotated_plot.plot(
+        _attr_auspice(),
+        _vertical_chart(["A", "B", "C", "D"]),
+        **_kw(),
+        color_tree_by="subclade",
+    )
+    encs = _tree_panel_color_encodings(out)
+    # Guard the index below so a missing encoding fails as an assertion
+    # rather than an IndexError.
+    assert encs
+    assert encs[0]["legend"]["title"] == "subclade"
+
+
+def test_color_tree_by_branches_and_tips_share_field():
+    """The mark_rule (branches) and mark_circle (tips) both reference the
+    same `color_value:N` field, so Altair collapses them into one legend."""
+    out = tree_annotated_plot.plot(
+        _attr_auspice(),
+        _vertical_chart(["A", "B", "C", "D"]),
+        **_kw(),
+        color_tree_by="subclade",
+    )
+    encs = _tree_panel_color_encodings(out)
+    # Two encodings: one on the seg_df rule, one on the tips_df circle.
+    assert len(encs) == 2
+    fields = {enc["field"] for enc in encs}
+    assert fields == {"color_value"}
+
+
+def test_color_tree_by_legend_hides_internal_only_unknown():
+    """When only an internal node lacks the attribute, the rendered spec's
+    legend.values must be set and must not contain 'unknown'. Internal-node
+    branches still render gray via the unchanged scale."""
+    d = _attr_auspice()
+    d["tree"]["node_attrs"].pop("subclade")  # drop subclade off ROOT only
+    out = tree_annotated_plot.plot(
+        d,
+        _vertical_chart(["A", "B", "C", "D"]),
+        **_kw(),
+        color_tree_by="subclade",
+    )
+    encs = _tree_panel_color_encodings(out)
+    assert len(encs) >= 1
+    for enc in encs:
+        legend = enc.get("legend") or {}
+        assert "values" in legend
+        assert "unknown" not in legend["values"]
+        # Scale still carries "unknown" so the gray rendering works.
+        assert "unknown" in enc["scale"]["domain"]
+
+
+# -----------------------------------------------------------------------------
+# CLI
+# -----------------------------------------------------------------------------
+
+
+def _write_chart_json(path, chart: alt.Chart) -> None:
+    """Save ``chart`` to ``path`` via ``Chart.save``."""
+    chart.save(str(path))
+
+
+def _write_tree_json(path, tree: dict) -> None:
+    """Serialize the tree dict as JSON text at ``path``."""
+    import json
+
+    path.write_text(json.dumps(tree))
+
+
+def _load_spec_json(path) -> dict:
+    """Parse the JSON spec the CLI wrote to ``path``."""
+    import json
+
+    return json.loads(path.read_text())
+
+
+def _run_cli(args, expect_success: bool = True):
+    """Invoke the CLI entry point in-process and return the click result.
+
+    Exceptions propagate (``catch_exceptions=False``). When
+    ``expect_success`` is true, a non-zero exit code raises AssertionError
+    carrying the CLI output.
+    """
+    from click.testing import CliRunner
+
+    from tree_annotated_plot import cli as cli_module
+
+    runner = CliRunner()
+    result = runner.invoke(cli_module.main, args, catch_exceptions=False)
+    if expect_success and result.exit_code != 0:
+        raise AssertionError(f"CLI exit {result.exit_code}\n{result.output}")
+    return result
+
+
+def _cli_setup(tmp_path, tree_dict: dict, chart: alt.Chart):
+    """Write the tree and chart inputs under ``tmp_path`` and return
+    ``(tree_path, chart_path, out_path)``; ``out_path`` is not created."""
+    tree_path = tmp_path / "tree.json"
+    chart_path = tmp_path / "chart.json"
+    out_path = tmp_path / "out.json"
+    _write_tree_json(tree_path, tree_dict)
+    _write_chart_json(chart_path, chart)
+    return tree_path, chart_path, out_path
+
+
+def _color_tree_cli_args(
+    tree_path, chart_path, out_path, color_tree_by: str
+) -> list[str]:
+    """Return the argv shared by the CLI tests; only ``--color-tree-by``
+    differs between them."""
+    return [
+        "--tree",
+        str(tree_path),
+        "--chart",
+        str(chart_path),
+        "--output",
+        str(out_path),
+        "--chart-strain-field",
+        "strain",
+        "--tree-strain-field",
+        "name",
+        "--branch-length",
+        "div",
+        "--color-tree-by",
+        color_tree_by,
+    ]
+
+
+def test_cli_color_tree_by_subclade(tmp_path):
+    """`--color-tree-by subclade` writes a spec whose tree panel has a
+    'color_value' color encoding."""
+    tree_path, chart_path, out_path = _cli_setup(
+        tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"])
+    )
+    _run_cli(_color_tree_cli_args(tree_path, chart_path, out_path, "subclade"))
+    assert out_path.exists()
+    encs = _color_value_encodings(_tree_panel(_load_spec_json(out_path)))
+    assert encs
+
+
+def test_cli_color_tree_by_genotype(tmp_path):
+    """`--color-tree-by genotype:HA1:158` runs and writes the output file."""
+    tree_path, chart_path, out_path = _cli_setup(
+        tmp_path,
+        _genotype_auspice(),
+        _vertical_chart(["tip_A", "tip_B", "tip_C", "tip_D"]),
+    )
+    _run_cli(
+        _color_tree_cli_args(tree_path, chart_path, out_path, "genotype:HA1:158")
+    )
+    assert out_path.exists()
+
+
+def test_cli_color_tree_by_haplotype(tmp_path):
+    """Comma in `genotype:HA1:158,189` must survive click's argument parsing."""
+    tree_path, chart_path, out_path = _cli_setup(
+        tmp_path,
+        _haplotype_auspice(),
+        _vertical_chart(["tip_A", "tip_B", "tip_C", "tip_D"]),
+    )
+    _run_cli(
+        _color_tree_cli_args(
+            tree_path, chart_path, out_path, "genotype:HA1:158,189"
+        )
+    )
+    assert out_path.exists()
+    encs = _color_value_encodings(_tree_panel(_load_spec_json(out_path)))
+    assert encs
+    # Domain should contain at least one slash-joined haplotype label.
+    domain = encs[0].get("scale", {}).get("domain", [])
+    assert any("/" in cat for cat in domain)