diff --git a/CHANGELOG.md b/CHANGELOG.md index 65745d3..4391fcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `strain_label_font_size` (default `10`), `strain_label_font_weight` (default `"normal"`), and `shift_tree_loc` (default `0`) for tuning the size, weight, and placement of the connected labels. +- `color_tree_by` (default `None`): color the tree's branches and tip + circles by an Auspice node attribute (e.g. `"subclade"`) or by the + inferred genotype state at one or more sites + (e.g. `"genotype:HA1:158"` or `"genotype:HA1:158,189"`). Colors, + category ordering, and the bottom-of-plot legend match the + Nextstrain view of the same tree. + +### Changed + +- Default `tree_line_width` bumped from `1.5` to `2`, default + `tree_node_size` from `28` to `45`. Tree branch lines and tip + circles are now drawn at full opacity. The thicker / fuller + defaults read better when the tree is colored (the prior values + were tuned for unicolor black trees). ## [0.1.0] - 2026-05-04 diff --git a/docs/examples.md b/docs/examples.md index eeee4ba..c159454 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -103,6 +103,11 @@ chart-builder wrote them (fonts, ticks, axis title, and all), and the tree's dashed leader lines stop at the tree panel's chart-facing edge. +The H3N2 example above is rendered with `color_tree_by="subclade"`, +which colors the tree's branches and tip circles by the +`node_attrs.subclade` value at each node and adds a categorical legend +below the plot. See "Color the tree" below for the full set of options. + ### Optional: connect leaders all the way to the labels If you'd prefer the dashed leaders to run flush into the strain @@ -131,6 +136,40 @@ CLI flags: `--connect-leader-to-label --strain-label-font-size 9 --shift-tree-loc 60`. In Python: `connect_leader_to_label=True, strain_label_font_size=9, shift_tree_loc=60`. 
+### Color the tree + +Pass `color_tree_by` to color the tree's branches and tip circles by +any property the Auspice JSON exposes — broadly, anything that +appears in the "Color By" dropdown on the Nextstrain view of the same +tree. Two forms are supported: + +- A named attribute, e.g. `color_tree_by="subclade"` (used in the + example above). Common alternatives include `"clade_membership"`, + `"region"`, `"country"` — whichever the tree provides. +- A genotype at one or more sites in a gene. For a single site, + `color_tree_by="genotype:HA1:158"` colors each tip by the amino + acid at HA1 site 158. A comma-separated list gives a haplotype: + `"genotype:HA1:158,189"`. Sites that don't vary in the tree are + dropped from the haplotype label. + +Colors match what you'd see on the Nextstrain view of the same tree — +either from the JSON's palette information when the build provides it, +or from the same default palette Auspice uses when it doesn't. +Categories are ordered by descending frequency in both cases. Missing +values render in gray, and the legend is drawn at the bottom of the +combined plot. + +The example below colors the same H3N2 chart by genotype at HA1 +site 158, which has two mutations in the tree (`N158K`, `N158D`) and +so renders three states (N, K, D): + +![H3N2 combined chart, colored by genotype HA1:158](images/h3n2_combined_genotype_158.svg) + +[Open the interactive chart in a new tab →](charts/h3n2_combined_genotype_158.html){target="_blank"} + +CLI flag: `--color-tree-by genotype:HA1:158`. In Python: +`color_tree_by="genotype:HA1:158"`. 
+ ### Reproduce — command line ```bash @@ -145,6 +184,7 @@ tree-annotated-plot \ --tree-size 140 \ --scale-bar \ --branch-length-units substitutions \ + --color-tree-by subclade \ --output examples/data/h3n2_combined.json ``` @@ -162,6 +202,7 @@ out = tree_annotated_plot.plot( tree_size=140, scale_bar=True, branch_length_units="substitutions", + color_tree_by="subclade", ) ``` diff --git a/scripts/generate_docs_assets.py b/scripts/generate_docs_assets.py index d51c13b..976ce75 100644 --- a/scripts/generate_docs_assets.py +++ b/scripts/generate_docs_assets.py @@ -110,9 +110,7 @@ def _render_kikawa() -> None: # Render the bare chart (no tree) so the docs page can show what # the chart looks like before tree-annotated-plot wraps it. _save_pair(chart, f"{basename}_chart_only") - out = tree_annotated_plot.plot( - DATA_DIR / f"flu-seqneut-2025to2026_{subtype}.json", - chart, + plot_kwargs = dict( chart_strain_field="axis_label", tree_strain_field="derived_haplotype", branch_length="div", @@ -120,6 +118,16 @@ def _render_kikawa() -> None: scale_bar=True, branch_length_units="substitutions", ) + if subtype == "H3N2": + # Color H3N2 by subclade so the docs SVG matches what users see + # on Nextstrain. The Auspice JSON's meta.colorings.subclade has + # no `scale` defined, so colors come from the default palette. + plot_kwargs["color_tree_by"] = "subclade" + out = tree_annotated_plot.plot( + DATA_DIR / f"flu-seqneut-2025to2026_{subtype}.json", + chart, + **plot_kwargs, + ) _save_pair(out, f"{basename}_combined") # H3N2 again, with `connect_leader_to_label=True` and a 9-point label @@ -146,9 +154,35 @@ def _render_kikawa() -> None: connect_leader_to_label=True, strain_label_font_size=9, shift_tree_loc=60, + color_tree_by="subclade", ) _save_pair(out, "h3n2_combined_label_connect") + # H3N2 once more, colored by genotype at HA1 site 158: same chart and + # default layout as `h3n2_combined`, with `color_tree_by` switched to + # the genotype form. 
Site 158 has two mutations (N158K, N158D) in the + # tree, so this renders three states (N, K, D). + h3n2_chart_genotype = builder.make_chart( + subtype="H3N2", + chart_type="iqr", + titers=titers, + viruses=viruses, + metadata=metadata, + all_cohorts=all_cohorts, + ) + out = tree_annotated_plot.plot( + DATA_DIR / "flu-seqneut-2025to2026_H3N2.json", + h3n2_chart_genotype, + chart_strain_field="axis_label", + tree_strain_field="derived_haplotype", + branch_length="div", + tree_size=140, + scale_bar=True, + branch_length_units="substitutions", + color_tree_by="genotype:HA1:158", + ) + _save_pair(out, "h3n2_combined_genotype_158") + def main() -> None: """Render every example to SVG + interactive HTML under `docs/`.""" diff --git a/src/tree_annotated_plot/_color.py b/src/tree_annotated_plot/_color.py new file mode 100644 index 0000000..87a681c --- /dev/null +++ b/src/tree_annotated_plot/_color.py @@ -0,0 +1,469 @@ +"""Resolve per-node color categories and the associated `alt.Scale` arrays. + +Public entry point: :func:`compute_node_color_values`. + +Spec mini-syntax (see `PlotConfig.color_tree_by`): + +- ``""`` — look up ``node_attrs[]`` on each node, auto-unwrapping + the Auspice ``{"value": ...}`` convention. Missing → ``"unknown"``. +- ``"genotype::"`` — single-site genotype state inferred by walking + ``branch_attrs.mutations[]`` from the root. Site with zero mutations + in the tree → all nodes are labeled ``""``. +- ``"genotype::,,..."`` — haplotype across sites. Per-node + label is the ``/``-joined letter+site for each *varying* site, in + user-supplied order. Sites that are invariant in the tree are dropped from + the label. If all requested sites are invariant, every node gets + ``""``. + +Color resolution prefers ``meta.colorings[].scale`` from the Auspice JSON +when defined (matches Nextstrain views). 
When unset, partial, or missing for +the requested key, the per-N table in :data:`_AUSPICE_PALETTE` fills in — the +same hand-tuned categorical palette Auspice's frontend uses for trees that +don't ship explicit color scales, so output still matches the Nextstrain +view. Categories are ordered by descending frequency (ties broken +alphabetically) before being mapped positionally onto the palette, again +matching Auspice's behavior. ``"unknown"`` always renders as ``#888888`` +(gray) and is reserved — no entry in :data:`_AUSPICE_PALETTE` is gray, so +a category drawn from the palette can never be confused with missing. +""" + +from __future__ import annotations + +import re +from collections import Counter +from dataclasses import dataclass +from typing import Iterator + +from ._tree import TreeNode + +# Auspice's frontend categorical palette, indexed by category count. +# `_AUSPICE_PALETTE[N]` is a hand-tuned `N`-color palette, perceptually spaced +# along a viridis-like sweep (purple → blue → green → yellow → orange → red); +# entries 0..36 are reproduced from Auspice's `src/util/globals.js`. We cap at +# 36 (matching Auspice's own `colors[colors.length - 1]` fallback) and reuse +# the largest palette for >36-category trees. Reproducing this table lets +# rendered output match Nextstrain's view of the same tree even when the +# Auspice JSON omits an explicit `meta.colorings[].scale`. +# +# Source: https://github.com/nextstrain/auspice/blob/master/src/util/globals.js +# Auspice is licensed under AGPL-3.0; the per-N color tables are reproduced +# here as factual color data with attribution. +# fmt: off +_AUSPICE_PALETTE: tuple[tuple[str, ...], ...] 
= ( + (), + ("#4C90C0",), + ("#4C90C0", "#CBB742"), + ("#4988C5", "#7EB876", "#CBB742"), + ("#4580CA", "#6BB28D", "#AABD52", "#DFA43B"), + ("#4377CD", "#61AB9D", "#94BD61", "#CDB642", "#E68133"), + ("#416DCE", "#59A3AA", "#84BA6F", "#BBBC49", "#E29D39", "#E1502A"), + ("#3F63CF", "#529AB6", "#75B681", "#A6BE55", "#D4B13F", "#E68133", "#DC2F24"), + ("#3E58CF", "#4B8EC1", "#65AE96", "#8CBB69", "#B8BC4A", "#DCAB3C", "#E67932", "#DC2F24"), + ("#3F4DCB", "#4681C9", "#5AA4A8", "#78B67E", "#9EBE5A", "#C5B945", "#E0A23A", "#E67231", "#DC2F24"), + ("#4042C7", "#4274CE", "#5199B7", "#69B091", "#88BB6C", "#ADBD51", "#CEB541", "#E39B39", "#E56C2F", "#DC2F24"), + ("#4137C2", "#4066CF", "#4B8DC2", "#5DA8A3", "#77B67F", "#96BD60", "#B8BC4B", "#D4B13F", "#E59638", "#E4672F", "#DC2F24"), + ("#462EB9", "#3E58CF", "#4580CA", "#549DB2", "#69B091", "#83BA70", "#A2BE57", "#C1BA47", "#D9AD3D", "#E69136", "#E4632E", "#DC2F24"), + ("#4B26B1", "#3F4ACA", "#4272CE", "#4D92BF", "#5DA8A3", "#74B583", "#8EBC66", "#ACBD51", "#C8B944", "#DDA93C", "#E68B35", "#E3602D", "#DC2F24"), + ("#511EA8", "#403DC5", "#4063CF", "#4785C7", "#559EB1", "#67AF94", "#7EB877", "#98BD5E", "#B4BD4C", "#CDB642", "#DFA53B", "#E68735", "#E35D2D", "#DC2F24"), + ("#511EA8", "#403AC4", "#3F5ED0", "#457FCB", "#5098B9", "#60AA9F", "#73B583", "#8BBB6A", "#A4BE56", "#BDBB48", "#D3B240", "#E19F3A", "#E68234", "#E25A2C", "#DC2F24"), + ("#511EA8", "#4138C3", "#3E59CF", "#4379CD", "#4D92BE", "#5AA5A8", "#6BB18E", "#7FB975", "#96BD5F", "#AFBD4F", "#C5B945", "#D8AE3E", "#E39B39", "#E67D33", "#E2572B", "#DC2F24"), + ("#511EA8", "#4236C1", "#3F55CE", "#4273CE", "#4A8CC2", "#569FAF", "#64AD98", "#76B680", "#8BBB6A", "#A1BE58", "#B7BC4B", "#CCB742", "#DCAB3C", "#E59638", "#E67932", "#E1552B", "#DC2F24"), + ("#511EA8", "#4335BF", "#3F51CC", "#416ECE", "#4887C6", "#529BB6", "#5FA9A0", "#6EB389", "#81B973", "#95BD61", "#AABD52", "#BFBB48", "#D1B340", "#DEA63B", "#E69237", "#E67531", "#E1522A", "#DC2F24"), + ("#511EA8", "#4333BE", "#3F4ECB", 
"#4169CF", "#4682C9", "#4F96BB", "#5AA5A8", "#68AF92", "#78B77D", "#8BBB6A", "#9EBE59", "#B3BD4D", "#C5B945", "#D5B03F", "#E0A23A", "#E68D36", "#E67231", "#E1502A", "#DC2F24"), + ("#511EA8", "#4432BD", "#3F4BCA", "#4065CF", "#447ECC", "#4C91BF", "#56A0AE", "#63AC9A", "#71B486", "#81BA72", "#94BD62", "#A7BE54", "#BABC4A", "#CBB742", "#D9AE3E", "#E29E39", "#E68935", "#E56E30", "#E14F2A", "#DC2F24"), + ("#511EA8", "#4531BC", "#3F48C9", "#3F61D0", "#4379CD", "#4A8CC2", "#539CB4", "#5EA9A2", "#6BB18E", "#7AB77B", "#8BBB6A", "#9CBE5B", "#AFBD4F", "#C0BA47", "#CFB541", "#DCAB3C", "#E39B39", "#E68534", "#E56B2F", "#E04D29", "#DC2F24"), + ("#511EA8", "#4530BB", "#3F46C8", "#3F5ED0", "#4375CD", "#4988C5", "#5098B9", "#5AA5A8", "#66AE95", "#73B583", "#82BA71", "#93BC62", "#A4BE56", "#B5BD4C", "#C5B945", "#D3B240", "#DEA73B", "#E59738", "#E68234", "#E4682F", "#E04C29", "#DC2F24"), + ("#511EA8", "#462FBA", "#3F44C8", "#3E5BD0", "#4270CE", "#4784C8", "#4E95BD", "#57A1AD", "#61AB9C", "#6DB38A", "#7BB879", "#8BBB6A", "#9BBE5C", "#ABBD51", "#BBBC49", "#CBB843", "#D6AF3E", "#DFA43B", "#E69537", "#E67F33", "#E4662E", "#E04A29", "#DC2F24"), + ("#511EA8", "#462EB9", "#4042C7", "#3E58CF", "#416DCE", "#4580CA", "#4C90C0", "#549DB2", "#5DA8A3", "#69B091", "#75B681", "#83BA70", "#92BC63", "#A2BE57", "#B2BD4D", "#C1BA47", "#CEB541", "#D9AD3D", "#E1A03A", "#E69136", "#E67C32", "#E4632E", "#E04929", "#DC2F24"), + ("#511EA8", "#462EB9", "#4040C6", "#3F55CE", "#4169CF", "#447DCC", "#4A8CC2", "#529AB7", "#5AA5A8", "#64AD98", "#70B487", "#7DB878", "#8BBB6A", "#99BD5D", "#A9BD53", "#B7BC4B", "#C5B945", "#D1B340", "#DCAB3C", "#E29D39", "#E68D36", "#E67932", "#E3612D", "#E04828", "#DC2F24"), + ("#511EA8", "#472DB8", "#403EC6", "#3F53CD", "#4066CF", "#4379CD", "#4989C5", "#4F97BB", "#57A1AD", "#61AA9E", "#6BB18E", "#77B67F", "#84BA70", "#92BC64", "#A0BE58", "#AFBD4F", "#BCBB49", "#CAB843", "#D4B13F", "#DEA83C", "#E39B39", "#E68A35", "#E67732", "#E35F2D", "#DF4728", "#DC2F24"), + ("#511EA8", 
"#472CB7", "#403DC5", "#3F50CC", "#4063CF", "#4375CD", "#4785C7", "#4D93BE", "#559EB1", "#5DA8A3", "#67AF94", "#72B485", "#7EB877", "#8BBB6A", "#98BD5E", "#A6BE55", "#B4BD4C", "#C1BA47", "#CDB642", "#D7AF3E", "#DFA53B", "#E49838", "#E68735", "#E67431", "#E35D2D", "#DF4628", "#DC2F24"), + ("#511EA8", "#482CB7", "#403BC5", "#3F4ECB", "#3F61D0", "#4272CE", "#4682C9", "#4C90C0", "#529BB5", "#5AA5A8", "#63AC9A", "#6DB28B", "#78B77D", "#84BA6F", "#91BC64", "#9EBE59", "#ACBD51", "#B9BC4A", "#C5B945", "#D0B441", "#DAAD3D", "#E0A23A", "#E59637", "#E68434", "#E67231", "#E35C2C", "#DF4528", "#DC2F24"), + ("#511EA8", "#482BB6", "#403AC4", "#3F4CCB", "#3F5ED0", "#426FCE", "#457FCB", "#4A8CC2", "#5098B9", "#58A2AC", "#60AA9F", "#69B091", "#73B583", "#7FB976", "#8BBB6A", "#97BD5F", "#A4BE56", "#B1BD4E", "#BDBB48", "#C9B843", "#D3B240", "#DCAB3C", "#E19F3A", "#E69337", "#E68234", "#E67030", "#E25A2C", "#DF4428", "#DC2F24"), + ("#511EA8", "#482BB6", "#4039C3", "#3F4ACA", "#3E5CD0", "#416CCE", "#447CCD", "#4989C4", "#4E96BC", "#559FB0", "#5DA8A4", "#66AE96", "#6FB388", "#7AB77C", "#85BA6F", "#91BC64", "#9DBE5A", "#AABD53", "#B6BD4B", "#C2BA46", "#CDB642", "#D6B03F", "#DDA83C", "#E29D39", "#E69036", "#E67F33", "#E56D30", "#E2592C", "#DF4428", "#DC2F24"), + ("#511EA8", "#482AB5", "#4138C3", "#3F48C9", "#3E59CF", "#4169CF", "#4379CD", "#4886C6", "#4D92BE", "#539CB4", "#5AA5A8", "#62AB9B", "#6BB18E", "#75B581", "#7FB975", "#8BBB6A", "#96BD5F", "#A2BE57", "#AFBD4F", "#BABC4A", "#C5B945", "#CFB541", "#D8AE3E", "#DFA63B", "#E39B39", "#E68D36", "#E67D33", "#E56B2F", "#E2572B", "#DF4328", "#DC2F24"), + ("#511EA8", "#492AB5", "#4137C2", "#3F47C9", "#3E57CE", "#4067CF", "#4376CD", "#4783C8", "#4C8FC0", "#519AB7", "#58A2AC", "#5FA9A0", "#68AF93", "#70B486", "#7BB77A", "#85BA6F", "#90BC65", "#9CBE5B", "#A8BE54", "#B3BD4D", "#BEBB48", "#C9B843", "#D2B340", "#DAAD3D", "#E0A33B", "#E49838", "#E68B35", "#E67B32", "#E5692F", "#E2562B", "#DF4227", "#DC2F24"), + ("#511EA8", "#492AB5", "#4236C1", 
"#3F45C8", "#3F55CE", "#4064CF", "#4273CE", "#4681CA", "#4A8CC2", "#4F97BA", "#569FAF", "#5CA7A4", "#64AD98", "#6DB28B", "#76B680", "#80B974", "#8BBB6A", "#96BD60", "#A1BE58", "#ACBD51", "#B7BC4B", "#C2BA46", "#CCB742", "#D4B13F", "#DCAB3C", "#E1A13A", "#E59638", "#E68835", "#E67932", "#E4672F", "#E1552B", "#DF4227", "#DC2F24"), + ("#511EA8", "#4929B4", "#4235C0", "#3F44C8", "#3F53CD", "#3F62CF", "#4270CE", "#457ECB", "#4989C4", "#4E95BD", "#549DB3", "#5AA5A8", "#61AB9C", "#69B090", "#72B485", "#7BB879", "#85BA6E", "#90BC65", "#9BBE5C", "#A6BE55", "#B1BD4E", "#BBBC49", "#C5B945", "#CEB541", "#D6AF3E", "#DDA93C", "#E29F39", "#E69537", "#E68634", "#E67732", "#E4662E", "#E1532B", "#DF4127", "#DC2F24"), + ("#511EA8", "#4929B4", "#4335BF", "#3F42C7", "#3F51CC", "#3F60D0", "#416ECE", "#447CCD", "#4887C6", "#4D92BF", "#529BB6", "#58A2AB", "#5FA9A0", "#66AE95", "#6EB389", "#77B67E", "#81B973", "#8BBB6A", "#95BD61", "#A0BE59", "#AABD52", "#B5BD4C", "#BFBB48", "#C9B843", "#D1B340", "#D8AE3E", "#DEA63B", "#E29C39", "#E69237", "#E68434", "#E67531", "#E4642E", "#E1522A", "#DF4127", "#DC2F24"), + ("#511EA8", "#4928B4", "#4334BF", "#4041C7", "#3F50CC", "#3F5ED0", "#416CCE", "#4379CD", "#4784C7", "#4B8FC1", "#5098B9", "#56A0AF", "#5CA7A4", "#63AC99", "#6BB18E", "#73B583", "#7CB878", "#86BB6E", "#90BC65", "#9ABD5C", "#A4BE56", "#AFBD4F", "#B9BC4A", "#C2BA46", "#CCB742", "#D3B240", "#DAAC3D", "#DFA43B", "#E39B39", "#E68F36", "#E68234", "#E67431", "#E4632E", "#E1512A", "#DF4027", "#DC2F24"), +) +# fmt: on + +_UNKNOWN = "unknown" +_NO_VARIATION = "" +_GRAY = "#888888" + +# Auspice mutation strings are like "N158K" / "*123A" / "-456N": one non-digit +# char, then digits, then one non-digit char. 
@dataclass(frozen=True)
class ColorMapping:
    """Resolved color information for a single `color_tree_by` invocation.

    Produced by :func:`compute_node_color_values`. The dataclass is frozen,
    so its attributes cannot be rebound after construction; the dict and
    lists it holds are ordinary mutable containers.
    """

    # Category label for every node in the tree, keyed by ``TreeNode.name``.
    values_by_node: dict[str, str]
    # Category labels in legend display order; parallel to ``range_``.
    domain: list[str]
    # Color strings, one per entry of ``domain`` (same index, same length).
    range_: list[str]
    # Header text shown above the legend.
    legend_title: str
    # When None, the legend shows the full domain. When set, it restricts the
    # legend display without altering the scale — used to hide ``"unknown"``
    # when only internal nodes (not tips) lack the attribute, since the gray
    # entry in that case just flags internal-node bookkeeping rather than
    # any missing tip-level data.
    legend_values: list[str] | None = None
+ """ + parsed = _parse_color_spec(color_spec) + if parsed[0] == "attr": + _, key = parsed + values_by_node = _color_by_node_attr(root, key) + else: + _, gene, sites = parsed + values_by_node = _color_by_genotype(root, gene, sites) + + categories = _ordered_categories(values_by_node.values()) + domain, range_ = _resolve_scale(categories, parsed, auspice_meta) + legend_title = _resolve_legend_title(color_spec, parsed, auspice_meta) + legend_values = _resolve_legend_values(domain, values_by_node, root) + return ColorMapping( + values_by_node=values_by_node, + domain=domain, + range_=range_, + legend_title=legend_title, + legend_values=legend_values, + ) + + +def _resolve_legend_values( + domain: list[str], + values_by_node: dict[str, str], + root: TreeNode, +) -> list[str] | None: + """Decide whether to hide ``"unknown"`` from the legend. + + Returns ``None`` when the full domain should appear — either because + ``"unknown"`` isn't a category at all, or because at least one tip + carries it (in which case the user looking at a gray tip needs the + legend to explain it). When ``"unknown"`` is present but only on + internal nodes, returns the domain with ``"unknown"`` filtered out; + internal segments still render gray via the unchanged scale, but the + legend doesn't dangle a misleading entry. + """ + if _UNKNOWN not in domain: + return None + for node in _walk_nodes(root): + if node.is_tip and values_by_node.get(node.name) == _UNKNOWN: + return None + return [c for c in domain if c != _UNKNOWN] + + +def _parse_color_spec( + spec: str, +) -> tuple[str, str] | tuple[str, str, list[int]]: + """Parse the spec mini-syntax. See module docstring for the grammar.""" + if not isinstance(spec, str) or not spec or any(c.isspace() for c in spec): + raise ValueError( + f"color_tree_by={spec!r}: expected either a node_attrs key, " + '"genotype::", or ' + '"genotype::,,...".' 
+ ) + if ":" not in spec: + return ("attr", spec) + parts = spec.split(":") + if len(parts) != 3 or parts[0] != "genotype" or not parts[1] or not parts[2]: + raise ValueError( + f"color_tree_by={spec!r}: expected either a node_attrs key, " + '"genotype::", or ' + '"genotype::,,...".' + ) + _, gene, site_str = parts + site_strs = site_str.split(",") + sites: list[int] = [] + for s in site_strs: + try: + n = int(s) + except ValueError as e: + raise ValueError( + f"color_tree_by={spec!r}: site {s!r} must be a positive integer." + ) from e + if n <= 0: + raise ValueError( + f"color_tree_by={spec!r}: site {s!r} must be a positive integer." + ) + sites.append(n) + seen: set[int] = set() + dups: set[int] = set() + for n in sites: + if n in seen: + dups.add(n) + else: + seen.add(n) + if dups: + raise ValueError( + f"color_tree_by={spec!r}: site list has duplicates ({sorted(dups)}); " + "each site must appear at most once." + ) + return ("genotype", gene, sites) + + +def _walk_nodes(root: TreeNode) -> Iterator[TreeNode]: + """Yield every node in the tree in pre-order.""" + yield root + for c in root.children: + yield from _walk_nodes(c) + + +def _color_by_node_attr(root: TreeNode, key: str) -> dict[str, str]: + """Resolve per-node category by reading `node_attrs[key]`.""" + values: dict[str, str] = {} + found = False + observed_keys: set[str] = set() + for node in _walk_nodes(root): + observed_keys.update(node.node_attrs.keys()) + attr = node.node_attrs.get(key) + if attr is None: + values[node.name] = _UNKNOWN + continue + if isinstance(attr, dict) and "value" in attr: + attr = attr["value"] + if attr is None or attr == "": + values[node.name] = _UNKNOWN + continue + found = True + values[node.name] = str(attr) + if not found: + observed = ", ".join(repr(k) for k in sorted(observed_keys)) + raise ValueError( + f"color_tree_by={key!r} is not a node_attrs key in this tree. 
def _color_by_genotype_single_site(
    root: TreeNode, gene: str, site: int
) -> dict[str, str] | None:
    """Infer the state at (gene, site) for every node with a pre-order walk.

    Returns ``None`` when no branch in the tree carries a mutation at the
    site (the site is invariant, and the caller decides how to label it).
    Otherwise returns ``{node.name: letter}`` for every node, in pre-order.

    Both passes use an explicit stack rather than recursion so that very deep
    trees cannot exhaust the interpreter's recursion limit. A node whose
    ``branch_attrs["mutations"]`` is missing, ``None`` or empty is treated as
    carrying no mutations.
    """

    def site_mutations(node: TreeNode) -> list[tuple[str, str]]:
        """Return ``(from_letter, to_letter)`` for each mutation at `site`."""
        gene_map = node.branch_attrs.get("mutations") or {}
        hits: list[tuple[str, str]] = []
        for mut in gene_map.get(gene) or []:
            m = _MUTATION_RE.match(mut)
            # Strings that are not <char><digits><char> are skipped.
            if m is not None and int(m.group(2)) == site:
                hits.append((m.group(1), m.group(3)))
        return hits

    # Pass 1: find the shallowest mutation at the site. Its "from" letter is
    # taken as the root's state. Ties at equal depth go to the first one met
    # in pre-order, and within one branch to the first listed mutation.
    root_letter = ""
    best_depth: int | None = None
    stack: list[tuple[TreeNode, int]] = [(root, 0)]
    while stack:
        node, depth = stack.pop()
        hits = site_mutations(node)
        if hits and (best_depth is None or depth < best_depth):
            best_depth = depth
            root_letter = hits[0][0]
        # Push children reversed so they are popped in their original order.
        stack.extend((child, depth + 1) for child in reversed(node.children))
    if best_depth is None:
        return None

    # Pass 2: propagate the state downward; the last mutation listed on a
    # branch determines that node's state.
    states: dict[str, str] = {}
    pending: list[tuple[TreeNode, str]] = [(root, root_letter)]
    while pending:
        node, inherited = pending.pop()
        hits = site_mutations(node)
        current = hits[-1][1] if hits else inherited
        states[node.name] = current
        pending.extend((child, current) for child in reversed(node.children))
    return states
+ ``""`` is only ever the sole category (when every + requested site is invariant); in that case it stands alone. + """ + counts = Counter(values) + if set(counts) == {_NO_VARIATION}: + return [_NO_VARIATION] + has_unknown = _UNKNOWN in counts + real = [c for c in counts if c != _UNKNOWN] + real.sort(key=lambda c: (-counts[c], c)) + if has_unknown: + real.append(_UNKNOWN) + return real + + +def _resolve_scale( + categories: list[str], + parsed_spec: tuple, + auspice_meta: dict | None, +) -> tuple[list[str], list[str]]: + """Build (domain, range_) parallel arrays for `alt.Scale`. + + Prefers ``meta.colorings[].scale`` for node-attr specs. Categories + not covered by the auspice scale (and every category for genotype specs + or when no `auspice_meta` is supplied) fall back to + :data:`_AUSPICE_PALETTE[K]`, where K is the number of *unmapped* + categories — so the fallback hues come from the same per-N palette + Auspice uses, capped at 36. ``"unknown"`` always maps to + :data:`_GRAY`. + """ + auspice_map: dict[str, str] = {} + if parsed_spec[0] == "attr" and auspice_meta is not None: + key = parsed_spec[1] + for c in auspice_meta.get("colorings", []) or []: + if c.get("key") == key: + scale = c.get("scale") or [] + for entry in scale: + # Each entry is [value, color]; skip malformed rows. 
+ if ( + isinstance(entry, (list, tuple)) + and len(entry) == 2 + and isinstance(entry[0], str) + and isinstance(entry[1], str) + ): + auspice_map[entry[0]] = entry[1] + break + + real_categories = [c for c in categories if c != _UNKNOWN] + unmapped = [c for c in real_categories if c not in auspice_map] + palette_idx = min(len(unmapped), len(_AUSPICE_PALETTE) - 1) + fallback_palette = _AUSPICE_PALETTE[palette_idx] + + domain: list[str] = [] + range_: list[str] = [] + fallback_pos = 0 + for cat in categories: + if cat == _UNKNOWN: + domain.append(cat) + range_.append(_GRAY) + continue + domain.append(cat) + if cat in auspice_map: + range_.append(auspice_map[cat]) + else: + if fallback_palette: + range_.append(fallback_palette[fallback_pos % len(fallback_palette)]) + else: + range_.append(_GRAY) + fallback_pos += 1 + return domain, range_ + + +def _resolve_legend_title( + color_spec: str, + parsed_spec: tuple, + auspice_meta: dict | None, +) -> str: + """Use `meta.colorings[].title` for attr specs when present; else + the literal spec string.""" + if parsed_spec[0] == "attr" and auspice_meta is not None: + key = parsed_spec[1] + for c in auspice_meta.get("colorings", []) or []: + if c.get("key") == key: + title = c.get("title") + if isinstance(title, str) and title: + return title + break + return color_spec diff --git a/src/tree_annotated_plot/_config.py b/src/tree_annotated_plot/_config.py index 2eaaac4..e827f41 100644 --- a/src/tree_annotated_plot/_config.py +++ b/src/tree_annotated_plot/_config.py @@ -70,15 +70,15 @@ class PlotConfig: tree_line_width: Annotated[ float, - "Stroke width (px) for the tree's branch lines. Default 1.5.", - ] = 1.5 + "Stroke width (px) for the tree's branch lines. Default 2.", + ] = 2.0 tree_node_size: Annotated[ float, "Area (px²) of the small filled circles drawn at each tip. " - "Default 28. Setting tree_node_size=0 disables the tip-circle " + "Default 45. 
Setting tree_node_size=0 disables the tip-circle " "layer entirely.", - ] = 28 + ] = 45 leader_line_width: Annotated[ float, @@ -154,6 +154,17 @@ class PlotConfig: "connect_leader_to_label is off.", ] = 0 + color_tree_by: Annotated[ + str | None, + "Color the tree by an Auspice attribute. Pass a node_attrs key " + '(e.g. "subclade"), or "genotype::" / ' + '"genotype::,,..." (e.g. "genotype:HA1:158" ' + 'or "genotype:HA1:158,189") to color by the inferred genotype ' + "state at a site or haplotype across sites. A categorical legend " + "is drawn below the plot; missing values are gray. None (default) " + "leaves the tree black.", + ] = None + # Sidecar for Python-docstring-only prose, keyed by PlotConfig field name. # Empty by default — add an entry when a field's docstring entry needs more diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index 562d821..25b4a74 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -14,7 +14,7 @@ import altair as alt import pandas as pd -from . import _config, _tree +from . import _color, _config, _tree from ._config import PlotConfig, TreeLocation # Accepted chart input forms for the public `plot` function. 
@@ -35,8 +35,8 @@ def plot( branch_length: Literal["div", "num_date"], tree_size: int = 100, tree_location: TreeLocation | None = None, - tree_line_width: float = 1.5, - tree_node_size: float = 28, + tree_line_width: float = 2.0, + tree_node_size: float = 45, leader_line_width: float = 1.0, scale_bar: bool = False, branch_length_units: str | None = None, @@ -46,6 +46,7 @@ def plot( strain_label_font_size: float = 10.0, strain_label_font_weight: Literal["normal", "bold"] = "normal", shift_tree_loc: int = 0, + color_tree_by: str | None = None, ) -> alt.HConcatChart | alt.VConcatChart: """Return an Altair chart with a phylogenetic tree drawn alongside `chart`.""" return _build( @@ -68,6 +69,7 @@ def plot( strain_label_font_size=strain_label_font_size, strain_label_font_weight=strain_label_font_weight, shift_tree_loc=shift_tree_loc, + color_tree_by=color_tree_by, ), ) @@ -121,7 +123,7 @@ def _build( set on a config object); `config` carries every styling / behavior knob. Both surfaces converge here so they can never disagree. """ - root = _ensure_tree( + root, auspice_meta = _ensure_tree( tree, config.tree_strain_field, branch_length=config.branch_length, @@ -162,6 +164,12 @@ def _build( strain_dim = _coerce_dim( _chart_strain_dim(spec, axis_hits, axis), n_tips=len(tip_names) ) + if config.color_tree_by is not None: + color_mapping = _color.compute_node_color_values( + root, config.color_tree_by, auspice_meta=auspice_meta + ) + else: + color_mapping = None tree_chart = _build_tree_chart( root, n_tips=len(tip_names), @@ -180,6 +188,7 @@ def _build( strain_label_font_weight=config.strain_label_font_weight, shift_tree_loc=config.shift_tree_loc, tip_names=tip_names, + color_mapping=color_mapping, ) new_chart = _apply_tree_order_to_chart_object( @@ -256,10 +265,17 @@ def _ensure_tree( *, branch_length: str, strict_version: bool, -) -> _tree.TreeNode: +) -> tuple[_tree.TreeNode, dict | None]: + """Return ``(root, auspice_meta)``. 
+ + ``auspice_meta`` is the loaded Auspice JSON's top-level ``meta`` dict, or + ``None`` when the caller passed a pre-built ``TreeNode`` (in which case + we have no JSON to read ``meta.colorings`` from, and color resolution + falls back to the default palette). + """ if isinstance(tree, _tree.TreeNode): - return tree - return _tree.load_auspice( + return tree, None + return _tree.load_auspice_with_meta( tree, tree_strain_field=tree_strain_field, branch_length=branch_length, @@ -1200,8 +1216,8 @@ def _build_tree_chart( strain_dim: int | float | alt.Step, strain_axis: str, tree_location: TreeLocation, - tree_line_width: float = 1.5, - tree_node_size: float = 28, + tree_line_width: float = 2.0, + tree_node_size: float = 45, leader_line_width: float = 1.0, scale_bar: bool = False, branch_length: str = "div", @@ -1211,6 +1227,7 @@ def _build_tree_chart( strain_label_font_weight: str = "normal", shift_tree_loc: int = 0, tip_names: list[str] | None = None, + color_mapping: _color.ColorMapping | None = None, ) -> alt.Chart: """Build the tree panel. @@ -1248,6 +1265,37 @@ def _build_tree_chart( branch_max = float(seg_df[["x", "x2"]].max().max()) branch_min = float(seg_df[["x", "x2"]].min().min()) + # When color_tree_by is set, attach the per-node color category to both + # frames. The `mark_rule` for branches (one row per `seg_df` segment) and + # the `mark_circle` for tips share the same color encoding so Altair + # collapses them into a single legend at the bottom. 
+ if color_mapping is not None: + seg_df = seg_df.assign( + color_value=seg_df["color_node"].map(color_mapping.values_by_node) + ) + tips_df = tips_df.assign( + color_value=tips_df["name"].map(color_mapping.values_by_node) + ) + legend_kwargs: dict = { + "title": color_mapping.legend_title, + "orient": "bottom", + } + if color_mapping.legend_values is not None: + # Restrict the legend display without touching the scale, so + # internal-node segments still render gray when "unknown" is on + # the tree but no tip is. + legend_kwargs["values"] = list(color_mapping.legend_values) + color_enc = alt.Color( + "color_value:N", + scale=alt.Scale( + domain=list(color_mapping.domain), + range=list(color_mapping.range_), + ), + legend=alt.Legend(**legend_kwargs), + ) + else: + color_enc = None + # When connect_leader_to_label is on: # - All leaders extend to a single point: the panel's chart-facing edge # (`chart_edge_branch`). @@ -1363,20 +1411,34 @@ def _build_tree_chart( ) .encode(x=branch_enc, x2="x2:Q", y=tip_enc) ) + seg_enc_kwargs: dict = { + "x": branch_enc, + "x2": "x2:Q", + "y": tip_enc, + "y2": "y2:Q", + } + if color_enc is not None: + seg_enc_kwargs["color"] = color_enc layers.append( alt.Chart(seg_df) - .mark_rule(strokeWidth=tree_line_width) - .encode(x=branch_enc, x2="x2:Q", y=tip_enc, y2="y2:Q") + .mark_rule(strokeWidth=tree_line_width, opacity=1.0) + .encode(**seg_enc_kwargs) ) if tree_node_size > 0: + tip_enc_kwargs: dict = { + "x": branch_enc, + "y": tip_enc, + "tooltip": alt.Tooltip("name:N", title="strain"), + } + tip_mark_kwargs: dict = {"size": tree_node_size, "opacity": 1.0} + if color_enc is None: + tip_mark_kwargs["color"] = "black" + else: + tip_enc_kwargs["color"] = color_enc layers.append( alt.Chart(tips_df) - .mark_circle(size=tree_node_size, color="black") - .encode( - x=branch_enc, - y=tip_enc, - tooltip=alt.Tooltip("name:N", title="strain"), - ) + .mark_circle(**tip_mark_kwargs) + .encode(**tip_enc_kwargs) ) # Strain text label, drawn as two 
stacked layers: a white halo # (white fill + thick white stroke) under the visible text. The @@ -1455,20 +1517,34 @@ def _build_tree_chart( ) .encode(y=branch_enc, y2="x2:Q", x=tip_enc) ) + seg_enc_kwargs = { + "y": branch_enc, + "y2": "x2:Q", + "x": tip_enc, + "x2": "y2:Q", + } + if color_enc is not None: + seg_enc_kwargs["color"] = color_enc layers.append( alt.Chart(seg_df) - .mark_rule(strokeWidth=tree_line_width) - .encode(y=branch_enc, y2="x2:Q", x=tip_enc, x2="y2:Q") + .mark_rule(strokeWidth=tree_line_width, opacity=1.0) + .encode(**seg_enc_kwargs) ) if tree_node_size > 0: + tip_enc_kwargs = { + "y": branch_enc, + "x": tip_enc, + "tooltip": alt.Tooltip("name:N", title="strain"), + } + tip_mark_kwargs = {"size": tree_node_size, "opacity": 1.0} + if color_enc is None: + tip_mark_kwargs["color"] = "black" + else: + tip_enc_kwargs["color"] = color_enc layers.append( alt.Chart(tips_df) - .mark_circle(size=tree_node_size, color="black") - .encode( - y=branch_enc, - x=tip_enc, - tooltip=alt.Tooltip("name:N", title="strain"), - ) + .mark_circle(**tip_mark_kwargs) + .encode(**tip_enc_kwargs) ) # Strain text label as halo + visible text (see vertical-layout # comment above). diff --git a/src/tree_annotated_plot/_tree.py b/src/tree_annotated_plot/_tree.py index f499cba..e6cd8d1 100644 --- a/src/tree_annotated_plot/_tree.py +++ b/src/tree_annotated_plot/_tree.py @@ -23,12 +23,20 @@ class TreeNode: `y` is set by :func:`layout` and is the integer index of the node's tip (for tips) or the midpoint of its descendants' tip indices (for internal nodes). + + `node_attrs` and `branch_attrs` carry the raw Auspice JSON dicts unchanged + so downstream code (e.g. `_color`) can read them directly. `node_attrs.div` + has already been consumed by `_resolve_branch_length`; `node_attrs[X]` is + where things like `subclade` live; `branch_attrs.mutations[]` is + where mutation strings like `"N158K"` live. 
""" name: str x: float children: list["TreeNode"] = field(default_factory=list) y: float | None = None + node_attrs: dict = field(default_factory=dict) + branch_attrs: dict = field(default_factory=dict) @property def is_tip(self) -> bool: @@ -66,6 +74,29 @@ def load_auspice( ``False`` the same case becomes a ``warnings.warn``. A missing ``version`` field always warns and proceeds. """ + root, _ = load_auspice_with_meta( + source, + tree_strain_field=tree_strain_field, + branch_length=branch_length, + strict_version=strict_version, + ) + return root + + +def load_auspice_with_meta( + source: str | Path | dict, + *, + tree_strain_field: str, + branch_length: str = "div", + strict_version: bool = True, +) -> tuple[TreeNode, dict]: + """Like :func:`load_auspice`, but also returns the Auspice top-level + ``meta`` dict. + + Used by code paths that need ``meta.colorings`` (e.g. resolving the + color palette for ``color_tree_by``). Returns ``(root, meta)`` where + ``meta`` is ``data.get("meta", {})``. 
+ """ _validate_tree_strain_field(tree_strain_field) if branch_length not in ("div", "num_date"): raise ValueError( @@ -85,7 +116,9 @@ def load_auspice( if "tree" not in data: raise ValueError("Auspice JSON must have a top-level 'tree' field") - return _parse_node(data["tree"], tree_strain_field, branch_length) + root = _parse_node(data["tree"], tree_strain_field, branch_length) + meta = data.get("meta", {}) or {} + return root, meta def _check_auspice_version(data: dict, *, strict_version: bool) -> None: @@ -160,7 +193,13 @@ def _parse_node(d: dict, tree_strain_field: str, branch_length: str) -> TreeNode children = [ _parse_node(c, tree_strain_field, branch_length) for c in children_dicts ] - return TreeNode(name=name, x=float(branch_value), children=children) + return TreeNode( + name=name, + x=float(branch_value), + children=children, + node_attrs=d.get("node_attrs", {}) or {}, + branch_attrs=d.get("branch_attrs", {}) or {}, + ) def _resolve_branch_length(node_dict: dict, branch_length: str) -> float: @@ -225,7 +264,13 @@ def _prune_recursive(node: TreeNode, keep_strains: set[str]) -> TreeNode | None: """ if node.is_tip: if node.name in keep_strains: - return TreeNode(name=node.name, x=node.x, children=[]) + return TreeNode( + name=node.name, + x=node.x, + children=[], + node_attrs=node.node_attrs, + branch_attrs=node.branch_attrs, + ) return None new_children: list[TreeNode] = [] @@ -239,7 +284,13 @@ def _prune_recursive(node: TreeNode, keep_strains: set[str]) -> TreeNode | None: if len(new_children) == 1: # Collapse this single-child internal into its child. 
return new_children[0] - return TreeNode(name=node.name, x=node.x, children=new_children) + return TreeNode( + name=node.name, + x=node.x, + children=new_children, + node_attrs=node.node_attrs, + branch_attrs=node.branch_attrs, + ) def tips(root: TreeNode) -> Iterator[TreeNode]: @@ -278,12 +329,20 @@ def _assign_internal_y(node: TreeNode) -> float: def segments(root: TreeNode) -> pd.DataFrame: """Build a DataFrame of line segments for drawing the tree. - Returns columns `x`, `x2`, `y`, `y2`. Each row is one segment: + Returns columns `x`, `x2`, `y`, `y2`, `color_node`. Each row is one + segment: - For each internal node, one vertical connector from the topmost to the - bottommost child (x == x2 == node.x). + bottommost child (x == x2 == node.x). `color_node` is the parent's + name — the connector takes the parent's color. - For each non-root node, one horizontal branch from its parent's x to its - own x at its own y (y == y2). + own x at its own y (y == y2). `color_node` is the *child*'s name — + the branch into a node is colored by that node, matching how Auspice + colors trees. + + `color_node` is the join key into the per-node color map produced by + `_color.compute_node_color_values`. When no coloring is requested, the + column is harmless (consumers ignore it). Assumes :func:`layout` has already been called on `root`. 
""" @@ -297,11 +356,25 @@ def walk(node: TreeNode) -> None: return child_ys = [c.y for c in node.children] rows.append( - {"x": node.x, "x2": node.x, "y": min(child_ys), "y2": max(child_ys)} + { + "x": node.x, + "x2": node.x, + "y": min(child_ys), + "y2": max(child_ys), + "color_node": node.name, + } ) for c in node.children: - rows.append({"x": node.x, "x2": c.x, "y": c.y, "y2": c.y}) + rows.append( + { + "x": node.x, + "x2": c.x, + "y": c.y, + "y2": c.y, + "color_node": c.name, + } + ) walk(c) walk(root) - return pd.DataFrame(rows, columns=["x", "x2", "y", "y2"]) + return pd.DataFrame(rows, columns=["x", "x2", "y", "y2", "color_node"]) diff --git a/tests/test_color_tree.py b/tests/test_color_tree.py new file mode 100644 index 0000000..0f6e79d --- /dev/null +++ b/tests/test_color_tree.py @@ -0,0 +1,800 @@ +"""Tests for `color_tree_by`: per-node attr / genotype / haplotype coloring, +Auspice-scale preference, default-palette fallback, gray-for-missing, +legend wiring.""" + +from __future__ import annotations + +from typing import Any + +import altair as alt +import pandas as pd +import pytest + +import tree_annotated_plot +from tree_annotated_plot import _color, _tree + + +def _attr_auspice() -> dict: + """Tiny tree with `subclade` on every node and tips A..D in two clades.""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0, "subclade": {"value": "X"}}, + "children": [ + { + "name": "INT_LEFT", + "node_attrs": {"div": 0.02, "subclade": {"value": "X"}}, + "children": [ + { + "name": "A", + "node_attrs": { + "div": 0.04, + "subclade": {"value": "X"}, + }, + }, + { + "name": "B", + "node_attrs": { + "div": 0.05, + "subclade": {"value": "X"}, + }, + }, + ], + }, + { + "name": "INT_RIGHT", + "node_attrs": {"div": 0.03, "subclade": {"value": "Y"}}, + "children": [ + { + "name": "C", + "node_attrs": { + "div": 0.06, + "subclade": {"value": "Y"}, + }, + }, + { + "name": "D", + "node_attrs": { + "div": 0.07, + 
"subclade": {"value": "Z"}, + }, + }, + ], + }, + ], + }, + } + + +def _genotype_auspice(*, mutations_at_158: bool = True) -> dict: + """Tree with HA1 mutations along selected branches. + + With `mutations_at_158=True` (default): + - tip_A inherits N158K via its branch (state K). + - tip_B is on a no-mutation branch (root state N). + - INT1 carries N158D, so its descendants tip_C, tip_D are state D. + + With `mutations_at_158=False`: same topology, but no HA1:158 mutation + anywhere — used for the invariant-site case. + """ + site_158_mut = "N158K" if mutations_at_158 else None + int1_mut = "N158D" if mutations_at_158 else None + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": ( + {"mutations": {"HA1": [site_158_mut]}} if site_158_mut else {} + ), + }, + {"name": "tip_B", "node_attrs": {"div": 0.05}}, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": ( + {"mutations": {"HA1": [int1_mut]}} if int1_mut else {} + ), + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, + ], + }, + } + + +def _haplotype_auspice() -> dict: + """Tree carrying HA1 mutations at sites 158 *and* 189 along independent + branches, so a 2-site haplotype gives several distinct categories.""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + }, + { + "name": "tip_B", + "node_attrs": {"div": 0.05}, + "branch_attrs": {"mutations": {"HA1": ["S189T"]}}, + }, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158K", "S189T"]}}, + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, 
+ ], + }, + } + + +def _load(d: dict) -> _tree.TreeNode: + return _tree.load_auspice(d, tree_strain_field="name", branch_length="div") + + +# ----------------------------------------------------------------------------- +# node_attrs path +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_node_attr_assigns_per_tip(): + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["A"] == "X" + assert m.values_by_node["B"] == "X" + assert m.values_by_node["C"] == "Y" + assert m.values_by_node["D"] == "Z" + + +def test_color_tree_by_node_attr_internal_nodes(): + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["INT_LEFT"] == "X" + assert m.values_by_node["INT_RIGHT"] == "Y" + assert m.values_by_node["ROOT"] == "X" + + +def test_color_tree_by_node_attr_unwraps_value(): + # `node_attrs.subclade = {"value": "X"}` resolves to "X", not the dict. + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert "X" in m.values_by_node.values() + + +def test_color_tree_by_node_attr_missing_marked_unknown(): + d = _attr_auspice() + # Strip subclade off a single tip. + d["tree"]["children"][0]["children"][0]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["A"] == "unknown" + # Domain places "unknown" last and pairs it with #888888. + assert m.domain[-1] == "unknown" + assert m.range_[-1] == "#888888" + # A is a tip, so the legend must keep "unknown" visible. 
+ assert m.legend_values is None + + +def test_color_tree_by_unknown_omitted_from_legend_when_only_internal(): + """When only internal nodes lack the attribute (every tip is annotated), + "unknown" stays in the scale (so internal segments render gray) but is + hidden from the legend.""" + d = _attr_auspice() + # Strip subclade off the ROOT internal node only — all tips remain + # annotated. + d["tree"]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + assert m.values_by_node["ROOT"] == "unknown" + # Domain still includes "unknown" so the seg_df row for ROOT renders gray. + assert "unknown" in m.domain + # But the legend display drops it. + assert m.legend_values is not None + assert "unknown" not in m.legend_values + # And the rest of the legend matches domain-minus-unknown. + assert m.legend_values == [c for c in m.domain if c != "unknown"] + + +def test_color_tree_by_unknown_absent_legend_unrestricted(): + """Fully annotated tree (no `"unknown"` anywhere) -> legend_values stays + None, i.e. the legend uses the full domain.""" + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + assert "unknown" not in m.values_by_node.values() + assert m.legend_values is None + + +def test_color_tree_by_missing_attr_raises_lists_keys(): + root = _load(_attr_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "nonexistent_field") + msg = str(exc.value) + assert "nonexistent_field" in msg + # Observed keys appear in the message — div and subclade at minimum. 
+ assert "'div'" in msg + assert "'subclade'" in msg + + +# ----------------------------------------------------------------------------- +# genotype path: single-site +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_genotype_root_state_inference(): + root = _load(_genotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + assert m.values_by_node["tip_A"] == "K158" # branch carries N158K + assert m.values_by_node["tip_B"] == "N158" # root state + assert m.values_by_node["tip_C"] == "D158" # inherited from INT1's N158D + assert m.values_by_node["tip_D"] == "D158" + + +def test_color_tree_by_genotype_multiple_mutations_along_path(): + d = { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + {"name": "tip_X", "node_attrs": {"div": 0.05}}, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + "children": [ + { + "name": "tip_Y", + "node_attrs": {"div": 0.04}, + }, + { + "name": "tip_Z", + "node_attrs": {"div": 0.06}, + "branch_attrs": {"mutations": {"HA1": ["K158R"]}}, + }, + ], + }, + ], + }, + } + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + assert m.values_by_node["tip_X"] == "N158" # root, no muts on path + assert m.values_by_node["tip_Y"] == "K158" # parent N158K + assert m.values_by_node["tip_Z"] == "R158" # N158K then K158R + + +def test_color_tree_by_genotype_single_site_invariant_renders_no_variation(): + root = _load(_genotype_auspice(mutations_at_158=False)) + # The JSON has no mutations anywhere, so this fires the "no mutation + # annotations" error rather than the invariant path. Add a stray + # mutation at a different site so mutations *exist* in the tree but + # NOT at site 158. 
+ d = _genotype_auspice(mutations_at_158=False) + d["tree"]["children"][0]["branch_attrs"] = {"mutations": {"HA1": ["S145N"]}} + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158") + # Every node gets the literal "" category. + assert set(m.values_by_node.values()) == {""} + assert m.domain == [""] + + +def test_color_tree_by_genotype_no_mutations_anywhere_raises(): + d = _genotype_auspice(mutations_at_158=False) + root = _load(d) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:158") + assert "no branch_attrs.mutations annotations" in str(exc.value) + + +def test_color_tree_by_genotype_missing_gene_raises_lists_genes(): + root = _load(_genotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:NONEXISTENT:158") + msg = str(exc.value) + assert "'NONEXISTENT'" in msg + assert "'HA1'" in msg + + +# ----------------------------------------------------------------------------- +# genotype path: haplotype +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_haplotype_basic(): + root = _load(_haplotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + # tip_A: branch N158K, no 189 mut -> K158/S189 + # tip_B: branch S189T, no 158 mut -> N158/T189 + # INT1's children inherit both N158K and S189T -> K158/T189 + assert m.values_by_node["tip_A"] == "K158/S189" + assert m.values_by_node["tip_B"] == "N158/T189" + assert m.values_by_node["tip_C"] == "K158/T189" + assert m.values_by_node["tip_D"] == "K158/T189" + + +def test_color_tree_by_haplotype_drops_invariant_sites(): + # Same tree as the single-site test, but ask for a 2-site haplotype + # where only 158 has mutations: the haplotype label collapses to just + # the 158 locus. 
+ d = _genotype_auspice() # only 158 has mutations + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + # 189 is invariant in this tree -> dropped from the label. + assert m.values_by_node["tip_A"] == "K158" + assert m.values_by_node["tip_B"] == "N158" + assert m.values_by_node["tip_C"] == "D158" + + +def test_color_tree_by_haplotype_all_invariant_renders_no_variation(): + d = _genotype_auspice(mutations_at_158=False) + # Need mutations to exist somewhere so we don't trip the + # "no mutations anywhere" guard. Add a HA1 mutation at a third site. + d["tree"]["children"][0]["branch_attrs"] = {"mutations": {"HA1": ["S145N"]}} + root = _load(d) + m = _color.compute_node_color_values(root, "genotype:HA1:158,189") + assert set(m.values_by_node.values()) == {""} + assert m.domain == [""] + + +def test_color_tree_by_haplotype_preserves_user_site_order(): + root = _load(_haplotype_auspice()) + m = _color.compute_node_color_values(root, "genotype:HA1:189,158") + # User wrote 189 first, so the label has 189 first. 
+ assert m.values_by_node["tip_A"] == "S189/K158" + assert m.values_by_node["tip_B"] == "T189/N158" + + +def test_color_tree_by_haplotype_duplicate_sites_raises(): + root = _load(_haplotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:158,158") + assert "duplicates" in str(exc.value) + + +def test_color_tree_by_genotype_site_int_form_required(): + root = _load(_haplotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values(root, "genotype:HA1:foo") + assert "positive integer" in str(exc.value) + + +# ----------------------------------------------------------------------------- +# scale resolution: Auspice meta vs default palette +# ----------------------------------------------------------------------------- + + +def test_color_tree_by_uses_auspice_scale_when_present(): + root = _load(_attr_auspice()) + meta = { + "colorings": [ + { + "key": "subclade", + "type": "categorical", + "scale": [["X", "#ff0000"], ["Y", "#00ff00"], ["Z", "#0000ff"]], + }, + ], + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + assert color_for["X"] == "#ff0000" + assert color_for["Y"] == "#00ff00" + assert color_for["Z"] == "#0000ff" + + +def test_color_tree_by_falls_back_to_default_palette_for_unmapped(): + """Auspice scale only covers X and Z; Y must take the first slot of the + Auspice fallback palette sized to the unmapped count (1 here, so + `_AUSPICE_PALETTE[1][0]`). 
Auspice-mapped slots don't consume + fallback-palette indices.""" + root = _load(_attr_auspice()) + meta = { + "colorings": [ + { + "key": "subclade", + "type": "categorical", + "scale": [["X", "#ff0000"], ["Z", "#0000ff"]], + }, + ], + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + assert color_for["X"] == "#ff0000" + assert color_for["Z"] == "#0000ff" + # 1 unmapped category (Y) -> _AUSPICE_PALETTE[1] -> single color. + assert color_for["Y"] == _color._AUSPICE_PALETTE[1][0] + + +def test_color_tree_by_legend_title_uses_auspice_meta_title(): + root = _load(_attr_auspice()) + meta = { + "colorings": [{"key": "subclade", "type": "categorical", "title": "Subclade"}] + } + m = _color.compute_node_color_values(root, "subclade", auspice_meta=meta) + assert m.legend_title == "Subclade" + + +def test_color_tree_by_genotype_ignores_auspice_scale(): + """Even if meta.colorings happens to have a 'genotype' entry, the + genotype path doesn't consult it — colors come from `_AUSPICE_PALETTE` + sized to the category count.""" + root = _load(_genotype_auspice()) + meta = { + "colorings": [ + { + "key": "genotype", + "type": "categorical", + "scale": [["K158", "#ff0000"]], + }, + ], + } + m = _color.compute_node_color_values(root, "genotype:HA1:158", auspice_meta=meta) + color_for = dict(zip(m.domain, m.range_)) + # K158 should NOT pick up the Auspice color; it gets a palette slot. 
+ assert color_for["K158"] != "#ff0000" + n_real = sum(1 for c in m.domain if c != "unknown") + assert color_for["K158"] in _color._AUSPICE_PALETTE[n_real] + + +def test_color_tree_by_no_auspice_meta_uses_default_palette(): + """No `auspice_meta` and a node-attr spec -> all real categories come + from `_AUSPICE_PALETTE[N]` where N is the number of real categories.""" + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade", auspice_meta=None) + n_real = sum(1 for c in m.domain if c != "unknown") + palette_n = _color._AUSPICE_PALETTE[n_real] + for cat, col in zip(m.domain, m.range_): + if cat == "unknown": + assert col == "#888888" + else: + assert col in palette_n + + +def test_color_tree_by_categories_sorted_by_descending_frequency(): + """Categories should be ordered by descending count so the most common + one ends up at index 0 (where Auspice's per-N palette puts its + deepest-blue start).""" + # Build a tree where subclade frequencies are X=4 (ROOT, INT_LEFT, A, B), + # Y=2 (INT_RIGHT, C), Z=1 (D). After sorting by descending count we + # expect domain[:3] == ["X", "Y", "Z"]. + root = _load(_attr_auspice()) + m = _color.compute_node_color_values(root, "subclade") + real = [c for c in m.domain if c != "unknown"] + assert real == ["X", "Y", "Z"] + + +def test_color_tree_by_default_palette_matches_auspice_for_six_categories(): + """For a 6-category attribute, the resolved range must equal Auspice's + `colors[6]` exactly. Pins the visual match against Nextstrain.""" + # Build a synthetic tree with 6 categories at distinct frequencies so + # ordering is unambiguous. 
+ d = { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0, "clade": {"value": "C1"}}, + "children": [ + { + "name": f"tip_{i}", + "node_attrs": {"div": 0.01 * i, "clade": {"value": v}}, + } + # Frequencies: C1=6, C2=5, C3=4, C4=3, C5=2, C6=1 -> total 21 + for i, v in enumerate( + ["C1"] * 5 # plus the one on ROOT, total C1=6 + + ["C2"] * 5 + + ["C3"] * 4 + + ["C4"] * 3 + + ["C5"] * 2 + + ["C6"] + ) + ], + }, + } + root = _load(d) + m = _color.compute_node_color_values(root, "clade") + assert m.domain == ["C1", "C2", "C3", "C4", "C5", "C6"] + assert tuple(m.range_) == _color._AUSPICE_PALETTE[6] + + +def test_color_tree_by_gray_reserved_for_unknown(): + """For a tree with non-missing categories plus 'unknown', gray (#888888) + must appear *only* at the 'unknown' slot — never inside the per-N + Auspice palette so a fallback-mapped category cannot collide with it.""" + d = _attr_auspice() + # Drop subclade off one tip so "unknown" enters the legend. + d["tree"]["children"][0]["children"][0]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values(root, "subclade") + gray_positions = [i for i, c in enumerate(m.range_) if c == "#888888"] + assert len(gray_positions) == 1 + assert m.domain[gray_positions[0]] == "unknown" + # And gray is not anywhere in the Auspice palette. 
+ for entry in _color._AUSPICE_PALETTE: + assert "#888888" not in entry + assert "#7f7f7f" not in entry + + +# ----------------------------------------------------------------------------- +# plot()-level: encoding placement, legend orient, default-none +# ----------------------------------------------------------------------------- + + +def _vertical_chart(strains: list[str]) -> alt.Chart: + df = pd.DataFrame({"strain": strains, "titer": [1.0, 2.0, 4.0, 8.0]}) + return ( + alt.Chart(df) + .mark_circle() + .encode(x="titer:Q", y=alt.Y("strain:N")) + .properties(width=200, height=200) + ) + + +def _kw(): + return dict( + chart_strain_field="strain", tree_strain_field="name", branch_length="div" + ) + + +def _find_color_encodings(node: Any) -> list[tuple[str, dict]]: + """Recursively find every encoding-block 'color' on the spec, returning + (json-pointer-ish path, encoding dict).""" + hits: list[tuple[str, dict]] = [] + + def walk(o: Any, path: str) -> None: + if isinstance(o, dict): + enc = o.get("encoding") + if isinstance(enc, dict) and "color" in enc: + hits.append((path, enc["color"])) + for k, v in o.items(): + walk(v, f"{path}.{k}") + elif isinstance(o, list): + for i, v in enumerate(o): + walk(v, f"{path}[{i}]") + + walk(node, "") + return hits + + +def _tree_panel_color_encodings(out) -> list[dict]: + """Color encodings on the tree panel only (which is hconcat[0] for our + 'left'-default vertical layout). 
Filtered to those backed by the + 'color_value' field — pre-existing chart colors are not.""" + spec = out.to_dict() + panels = spec.get("hconcat") or spec.get("vconcat") or [] + assert panels + tree_panel = panels[0] + return [ + enc + for _, enc in _find_color_encodings(tree_panel) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + + +def test_color_tree_by_default_none_no_color_encoding(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + ) + encs = _tree_panel_color_encodings(out) + assert encs == [] + + +def test_color_tree_by_legend_orient_bottom(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + assert len(encs) >= 1 + for enc in encs: + legend = enc.get("legend") or {} + assert legend.get("orient") == "bottom" + + +def test_color_tree_by_legend_title_is_spec_string(): + # No auspice_meta.colorings.title -> falls back to the literal spec. + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + assert encs[0]["legend"]["title"] == "subclade" + + +def test_color_tree_by_branches_and_tips_share_field(): + """The mark_rule (branches) and mark_circle (tips) both reference the + same `color_value:N` field, so Altair collapses them into one legend.""" + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + # Two encodings: one on the seg_df rule, one on the tips_df circle. 
+ assert len(encs) == 2 + fields = {enc["field"] for enc in encs} + assert fields == {"color_value"} + + +def test_color_tree_by_legend_hides_internal_only_unknown(): + """When only an internal node lacks the attribute, the rendered spec's + legend.values must be set and must not contain 'unknown'. Internal-node + branches still render gray via the unchanged scale.""" + d = _attr_auspice() + d["tree"]["node_attrs"].pop("subclade") # drop subclade off ROOT only + out = tree_annotated_plot.plot( + d, + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + assert len(encs) >= 1 + for enc in encs: + legend = enc.get("legend") or {} + assert "values" in legend + assert "unknown" not in legend["values"] + # Scale still carries "unknown" so the gray rendering works. + assert "unknown" in enc["scale"]["domain"] + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- + + +def _write_chart_json(path, chart: alt.Chart) -> None: + chart.save(str(path)) + + +def _write_tree_json(path, tree: dict) -> None: + import json + + path.write_text(json.dumps(tree)) + + +def _run_cli(args, expect_success: bool = True): + from click.testing import CliRunner + + from tree_annotated_plot import cli as cli_module + + runner = CliRunner() + result = runner.invoke(cli_module.main, args, catch_exceptions=False) + if expect_success and result.exit_code != 0: + raise AssertionError(f"CLI exit {result.exit_code}\n{result.output}") + return result + + +def _cli_setup(tmp_path, tree_dict: dict, chart: alt.Chart): + tree_path = tmp_path / "tree.json" + chart_path = tmp_path / "chart.json" + out_path = tmp_path / "out.json" + _write_tree_json(tree_path, tree_dict) + _write_chart_json(chart_path, chart) + return tree_path, chart_path, out_path + + +def test_cli_color_tree_by_subclade(tmp_path): + tree_path, 
chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + ] + ) + assert out_path.exists() + import json + + spec = json.loads(out_path.read_text()) + encs = [ + enc + for _, enc in _find_color_encodings(spec.get("hconcat", [{}])[0]) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + assert encs + + +def test_cli_color_tree_by_genotype(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, + _genotype_auspice(), + _vertical_chart(["tip_A", "tip_B", "tip_C", "tip_D"]), + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "genotype:HA1:158", + ] + ) + assert out_path.exists() + + +def test_cli_color_tree_by_haplotype(tmp_path): + """Comma in `genotype:HA1:158,189` must survive click's argument parsing.""" + tree_path, chart_path, out_path = _cli_setup( + tmp_path, + _haplotype_auspice(), + _vertical_chart(["tip_A", "tip_B", "tip_C", "tip_D"]), + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "genotype:HA1:158,189", + ] + ) + assert out_path.exists() + import json + + spec = json.loads(out_path.read_text()) + encs = [ + enc + for _, enc in _find_color_encodings(spec.get("hconcat", [{}])[0]) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + assert encs + # Domain should contain at least one slash-joined haplotype label. 
+ domain = encs[0].get("scale", {}).get("domain", []) + assert any("/" in cat for cat in domain)