diff --git a/ultraplot/axes/_formatting.py b/ultraplot/axes/_formatting.py new file mode 100644 index 000000000..2bc292498 --- /dev/null +++ b/ultraplot/axes/_formatting.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Shared metadata for axis formatting keyword routing and persistence. +""" + +import inspect + +_AXIS_STYLE_FIELD_TEMPLATES = { + "color": ("{axis}color", "color"), + "linewidth": ("{axis}linewidth", "linewidth"), + "rotation": ("{axis}rotation", "rotation"), + "spineloc": ("{axis}spineloc", "{axis}loc"), + "tickloc": ("{axis}tickloc",), + "ticklabelloc": ("{axis}ticklabelloc",), + "labelloc": ("{axis}labelloc",), + "offsetloc": ("{axis}offsetloc",), + "grid": ("{axis}grid",), + "gridminor": ("{axis}gridminor",), + "gridcolor": ("{axis}gridcolor", "gridcolor"), + "tickdir": ("{axis}tickdir", "tickdir"), + "tickcolor": ("{axis}tickcolor", "tickcolor"), + "ticklen": ("{axis}ticklen", "ticklen"), + "ticklenratio": ("{axis}ticklenratio", "ticklenratio"), + "tickwidth": ("{axis}tickwidth", "tickwidth"), + "tickwidthratio": ("{axis}tickwidthratio", "tickwidthratio"), + "ticklabeldir": ("{axis}ticklabeldir", "ticklabeldir"), + "ticklabelpad": ("{axis}ticklabelpad",), + "ticklabelcolor": ("{axis}ticklabelcolor", "ticklabelcolor"), + "ticklabelsize": ("{axis}ticklabelsize", "ticklabelsize"), + "ticklabelweight": ("{axis}ticklabelweight", "ticklabelweight"), + "labelpad": ("{axis}labelpad",), + "labelcolor": ("{axis}labelcolor", "labelcolor"), + "labelsize": ("{axis}labelsize", "labelsize"), + "labelweight": ("{axis}labelweight", "labelweight"), +} + + +def _dedupe(items): + return tuple(dict.fromkeys(items)) + + +GENERIC_AXIS_FORMAT_KEYS = _dedupe( + name + for names in _AXIS_STYLE_FIELD_TEMPLATES.values() + for name in names + if "{axis}" not in name +) + + +CARTESIAN_PARENT_FILTER_KEYS = GENERIC_AXIS_FORMAT_KEYS + ( + "label_kw", + "scale_kw", + "locator_kw", + "formatter_kw", + "minorlocator_kw", +) + + +def get_axis_style_fields(axis): + """ + Return the parameter names used to store explicit style overrides. + """ + return { + field: tuple(name.format(axis=axis) for name in names) + for field, names in _AXIS_STYLE_FIELD_TEMPLATES.items() + } + + +def _signature_param_names(*funcs): + names = [] + for func in funcs: + if isinstance(func, inspect.Signature): + sig = func + elif callable(func): + sig = inspect.signature(func) + elif func is None: + continue + else: + raise RuntimeError(f"Internal error. Invalid function {func!r}.") + names.extend(sig.parameters) + return set(names) + + +def pop_axis_format_kwargs(kwargs, *funcs): + """ + Pop axis-format kwargs so they survive rc parsing. + + Returns + ------- + tuple(dict, dict) + The signature-defined keyword arguments and the generic alias keyword + arguments that are not represented in the stored signatures. + """ + signature_keys = _signature_param_names(*funcs) + signature_kwargs = {} + generic_kwargs = {} + for key in tuple(kwargs): + if key in GENERIC_AXIS_FORMAT_KEYS: + generic_kwargs[key] = kwargs.pop(key) + elif key in signature_keys: + signature_kwargs[key] = kwargs.pop(key) + return signature_kwargs, generic_kwargs diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 696639beb..28b493eba 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -4,6 +4,7 @@ """ import copy +import functools import inspect from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple, Union @@ -20,6 +21,7 @@ from ..config import rc from ..internals import ( _not_none, + _pop_params, _pop_rc, _version_mpl, docstring, @@ -28,6 +30,11 @@ warnings, ) from ..utils import units +from ._formatting import ( + CARTESIAN_PARENT_FILTER_KEYS, + get_axis_style_fields, + pop_axis_format_kwargs, +) from . import plot, shared __all__ = ["CartesianAxes"] @@ -431,6 +438,8 @@ def __init__(self, *args, **kwargs): self._yaxis_current_rotation = "horizontal" self._xaxis_isdefault_rotation = True # whether to auto rotate the axis self._yaxis_isdefault_rotation = True + self._xaxis_style_state = {} + self._yaxis_style_state = {} super().__init__(*args, **kwargs) # Apply default formatter @@ -447,6 +456,37 @@ def __init__(self, *args, **kwargs): self._dualy_funcscale = None self._dualy_prevstate = None + def _get_axis_style_state(self, axis): + """ + Return the cached explicit style overrides for this axis. + """ + return getattr(self, f"_{axis}axis_style_state") + + def _merge_axis_style_state(self, axis, params): + """ + Merge the current explicit style overrides with the cached overrides. + """ + state = self._get_axis_style_state(axis).copy() + explicit_keys = set(params.get("_explicit_format_keys", ())) + for field, names in get_axis_style_fields(axis).items(): + if any(name in explicit_keys for name in names) and all( + params.get(name, None) is None for name in names + ): + state.pop(field, None) + continue + value = _not_none(*(params.get(name) for name in names)) + if value is not None: + state[field] = value + return state + + def _set_axis_style_state(self, axis, params): + """ + Cache the explicit style overrides for this axis. + """ + setattr( + self, f"_{axis}axis_style_state", self._merge_axis_style_state(axis, params) + ) + def _apply_axis_sharing(self): """ Enforce the "shared" axis labels and axis tick labels. If this is not @@ -1204,13 +1244,12 @@ def _format_axis(self, s: str, config: _AxisFormatConfig, fixticks: bool): self.margins(**{s: config.margin}) # Axis spine settings - # NOTE: This sets spine-specific color and linewidth settings. For - # non-specific settings _update_background is called in Axes.format() self._update_spines(s, loc=config.spineloc, bounds=config.bounds) - self._update_background( + self._update_frame( s, edgecolor=config.color, linewidth=config.linewidth, + tickcolor=config.tickcolor, tickwidth=tickwidth, tickwidthratio=config.tickwidthratio, ) @@ -1297,27 +1336,84 @@ def _resolve_axis_format(self, axis, params, rc_kw): Resolve formatting parameters for a single axis (x or y). """ p = params - - # Color resolution - color = p.get("color") - axis_color = _not_none(p.get(f"{axis}color"), color) + prev = self._merge_axis_style_state(axis, p) # Helper to get axis-specific or generic param def get(name): - return p.get(f"{axis}{name}") + return _not_none(p.get(f"{axis}{name}"), p.get(name)) + + # Color resolution + axis_color_arg = prev.get("color", None) + axis_color = _not_none( + axis_color_arg, + rc.find("axes.edgecolor", context=True), + rc["axes.edgecolor"], + ) + linewidth = _not_none( + prev.get("linewidth", None), + rc.find("axes.linewidth", context=True), + rc["axes.linewidth"], + ) # Resolve colors tickcolor = get("tickcolor") if "tick.color" not in rc_kw: - tickcolor = _not_none(tickcolor, axis_color) + tickcolor = _not_none( + prev.get("tickcolor", None), + axis_color_arg, + rc.find(f"{axis}tick.color", context=True), + rc[f"{axis}tick.color"], + ) ticklabelcolor = get("ticklabelcolor") if "tick.labelcolor" not in rc_kw: - ticklabelcolor = _not_none(ticklabelcolor, axis_color) + ticklabelcolor = _not_none( + prev.get("ticklabelcolor", None), + axis_color_arg, + ) labelcolor = get("labelcolor") if "label.color" not in rc_kw: - labelcolor = _not_none(labelcolor, axis_color) + labelcolor = _not_none( + prev.get("labelcolor", None), + axis_color_arg, + ) + + ticklen = _not_none( + get("ticklen"), + prev.get("ticklen", None), + rc.find("tick.len", context=True), + rc["tick.len"], + ) + ticklenratio = _not_none( + get("ticklenratio"), + prev.get("ticklenratio", None), + rc.find("tick.lenratio", context=True), + rc["tick.lenratio"], + ) + tickwidth = get("tickwidth") + tickwidth = _not_none( + prev.get("tickwidth", None), + prev.get("linewidth", None), + rc.find("tick.width", context=True), + rc["tick.width"], + ) + tickwidthratio = _not_none( + get("tickwidthratio"), + prev.get("tickwidthratio", None), + rc.find("tick.widthratio", context=True), + rc["tick.widthratio"], + ) + ticklabelsize = prev.get("ticklabelsize", None) + ticklabelweight = prev.get("ticklabelweight", None) + labelsize = prev.get("labelsize", None) + labelweight = prev.get("labelweight", None) + grid = prev.get("grid", None) + gridminor = prev.get("gridminor", None) + gridcolor = prev.get("gridcolor", None) + rotation = prev.get("rotation", None) + ticklabelpad = prev.get("ticklabelpad", None) + labelpad = prev.get("labelpad", None) # Flexible keyword args margin = _not_none( @@ -1325,7 +1421,8 @@ def get(name): ) tickdir = _not_none( - get("tickdir"), rc.find(f"{axis}tick.direction", context=True) + prev.get("tickdir", None), + rc.find(f"{axis}tick.direction", context=True), ) locator = _not_none(get("locator"), p.get(f"{axis}ticks")) @@ -1345,31 +1442,32 @@ def get(name): tickminor = _not_none( tickminor, + prev.get("tickminor", None), tickminor_default, rc.find(f"{axis}tick.minor.visible", context=True), ) # Tick label dir logic - ticklabeldir = p.get("ticklabeldir") + ticklabeldir = prev.get("ticklabeldir", None) axis_ticklabeldir = _not_none(get("ticklabeldir"), ticklabeldir) tickdir = _not_none(tickdir, axis_ticklabeldir) # Spine locations loc = get("loc") - spineloc = get("spineloc") + spineloc = prev.get("spineloc", None) spineloc = _not_none(loc, spineloc) # Spine side inference side = self._get_spine_side(axis, spineloc) - tickloc = get("tickloc") + tickloc = prev.get("tickloc", None) if side is not None and side not in ("zero", "center", "both"): tickloc = _not_none(tickloc, side) # Infer other locations - ticklabelloc = get("ticklabelloc") - labelloc = get("labelloc") - offsetloc = get("offsetloc") + ticklabelloc = prev.get("ticklabelloc", None) + labelloc = prev.get("labelloc", None) + offsetloc = prev.get("offsetloc", None) if tickloc != "both": ticklabelloc = _not_none(ticklabelloc, tickloc) @@ -1396,16 +1494,42 @@ def get(name): val = p.get(f"{axis}max") case "color": val = axis_color + case "linewidth": + val = linewidth case "tickcolor": val = tickcolor + case "ticklen": + val = ticklen + case "ticklenratio": + val = ticklenratio + case "tickwidth": + val = tickwidth + case "tickwidthratio": + val = tickwidthratio case "ticklabelcolor": val = ticklabelcolor + case "ticklabelsize": + val = ticklabelsize + case "ticklabelweight": + val = ticklabelweight case "labelcolor": val = labelcolor + case "labelsize": + val = labelsize + case "labelweight": + val = labelweight case "margin": val = margin case "tickdir": val = tickdir + case "grid": + val = grid + case "gridminor": + val = gridminor + case "gridcolor": + val = gridcolor + case "rotation": + val = rotation case "locator": val = locator case "minorlocator": @@ -1416,6 +1540,8 @@ def get(name): val = tickminor case "ticklabeldir": val = axis_ticklabeldir + case "ticklabelpad": + val = ticklabelpad case "spineloc": val = spineloc case "tickloc": @@ -1426,6 +1552,8 @@ def get(name): val = labelloc case "offsetloc": val = offsetloc + case "labelpad": + val = labelpad case _: # Direct mapping (e.g. xlinewidth -> linewidth) val = get(field) @@ -1569,11 +1697,26 @@ def format( or `datetime.datetime` array as the x or y axis coordinate, the axis ticks and tick labels will be automatically formatted as dates. """ + explicit_format_keys = set(kwargs) + explicit_format_keys.update(kwargs.pop("_explicit_format_keys", ())) + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, self._format_signatures[CartesianAxes] + ) + explicit_format_keys.update(signature_axis_kwargs) + explicit_format_keys.update(generic_axis_kwargs) rc_kw, rc_mode = _pop_rc(kwargs) + kwargs.update(signature_axis_kwargs) + kwargs.update(generic_axis_kwargs) + base_kwargs = kwargs.copy() + _pop_params(base_kwargs, self._format_signatures[CartesianAxes]) + for key in CARTESIAN_PARENT_FILTER_KEYS: + base_kwargs.pop(key, None) + with rc.context(rc_kw, mode=rc_mode): # Resolve parameters for x and y axes # We capture locals() to pass all named arguments to the helper params = locals() + params["_explicit_format_keys"] = explicit_format_keys params.update(kwargs) # Include any extras in kwargs x_config = self._resolve_axis_format("x", params, rc_kw) @@ -1582,6 +1725,8 @@ def format( # Format axes self._format_axis("x", x_config, fixticks=fixticks) self._format_axis("y", y_config, fixticks=fixticks) + self._set_axis_style_state("x", params) + self._set_axis_style_state("y", params) if rc.find("formatter.log", context=True): if ( @@ -1603,10 +1748,9 @@ def format( ): self._update_formatter("y", "log") - # Parent format method if aspect is not None: self.set_aspect(aspect) - super().format(rc_kw=rc_kw, rc_mode=rc_mode, **kwargs) + super().format(rc_kw=rc_kw, rc_mode=rc_mode, **base_kwargs) @docstring._snippet_manager def altx(self, **kwargs): @@ -1678,10 +1822,24 @@ def get_tightbbox(self, renderer, *args, **kwargs): return super().get_tightbbox(renderer, *args, **kwargs) +def _capture_explicit_format_keys(func): + """ + Preserve raw keyword names before Python binds them to the format signature. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + kwargs.setdefault("_explicit_format_keys", set(kwargs)) + return func(self, *args, **kwargs) + + return wrapper + + # tmp # Apply signature obfuscation after storing previous signature # NOTE: This is needed for __init__, altx, and alty CartesianAxes._format_signatures[CartesianAxes] = inspect.signature( CartesianAxes.format ) # noqa: E501 +CartesianAxes.format = _capture_explicit_format_keys(CartesianAxes.format) CartesianAxes.format = docstring._obfuscate_kwargs(CartesianAxes.format) diff --git a/ultraplot/axes/polar.py b/ultraplot/axes/polar.py index bf62e010c..4ced34d36 100644 --- a/ultraplot/axes/polar.py +++ b/ultraplot/axes/polar.py @@ -282,6 +282,34 @@ def format( rc_kw, rc_mode = _pop_rc(kwargs) labelcolor = _not_none(labelcolor, kwargs.get("color", None)) with rc.context(rc_kw, mode=rc_mode): + edgecolor = _not_none( + kwargs.get("color", None), + rc.find("axes.edgecolor", context=True), + rc["axes.edgecolor"], + ) + linewidth = _not_none( + kwargs.get("linewidth", None), + rc.find("axes.linewidth", context=True), + rc["axes.linewidth"], + ) + tickcolor = _not_none( + kwargs.get("tickcolor", None), + kwargs.get("color", None), + rc.find("xtick.color", context=True), + rc["xtick.color"], + ) + tickwidth = _not_none( + kwargs.get("tickwidth", None), + kwargs.get("linewidth", None) and linewidth, + rc.find("tick.width", context=True), + rc["tick.width"], + ) + tickwidthratio = _not_none( + kwargs.get("tickwidthratio", None), + rc.find("tick.widthratio", context=True), + rc["tick.widthratio"], + ) + # Not mutable default args thetalocator_kw = thetalocator_kw or {} thetaminorlocator_kw = thetaminorlocator_kw or {} @@ -320,6 +348,23 @@ def format( if thetadir is not None: self.set_theta_direction(thetadir) + # Polar frame styling used to come from the shared background helper. + # Apply it explicitly now that patch and frame styling are separated. + self._update_frame( + "x", + edgecolor=edgecolor, + linewidth=linewidth, + tickcolor=tickcolor, + tickwidth=tickwidth, + tickwidthratio=tickwidthratio, + ) + self._update_frame( + "y", + tickcolor=tickcolor, + tickwidth=tickwidth, + tickwidthratio=tickwidthratio, + ) + # Loop over axes for ( x, diff --git a/ultraplot/axes/shared.py b/ultraplot/axes/shared.py index 6b66c6219..dadc33b77 100644 --- a/ultraplot/axes/shared.py +++ b/ultraplot/axes/shared.py @@ -40,41 +40,55 @@ def _min_max_lim(key, min_=None, max_=None, lim=None): max_ = _not_none(**{f"{key}max": max_, f"{key}lim_1": lim[1]}) return min_, max_ - def _update_background(self, x=None, tickwidth=None, tickwidthratio=None, **kwargs): + def _update_background(self, **kwargs): """ - Update the background patch and spines. + Update the background patch. """ - # Update the background patch kw_face, kw_edge = rc._get_background_props(**kwargs) self.patch.update(kw_face) - if x is None: - opts = self.spines - elif x == "x": - opts = ("bottom", "top", "inner", "polar") - else: - opts = ("left", "right", "start", "end") - for opt in opts: - self.spines.get(opt, {}).update(kw_edge) + return kw_face, kw_edge - # Update the tick colors - axis = "both" if x is None else x - x = _not_none(x, "x") - obj = getattr(self, x + "axis") - edgecolor = kw_edge.get("edgecolor", None) + def _update_frame( + self, + x, + *, + edgecolor=None, + linewidth=None, + tickcolor=None, + tickwidth=None, + tickwidthratio=None, + ): + """ + Update the axis frame, including spines and tick line appearance. + """ + opts = ( + ("bottom", "top", "inner", "polar") + if x == "x" + else ( + "left", + "right", + "start", + "end", + ) + ) + kw_edge = {"capstyle": "projecting"} if edgecolor is not None: - self.tick_params(axis=axis, which="both", color=edgecolor) + kw_edge["edgecolor"] = edgecolor + if linewidth is not None: + kw_edge["linewidth"] = linewidth + if len(kw_edge) > 1: + for opt in opts: + self.spines.get(opt, {}).update(kw_edge) + + obj = getattr(self, x + "axis") + if tickcolor is None: + tickcolor = edgecolor + if tickcolor is not None: + self.tick_params(axis=x, which="both", color=tickcolor) # Update the tick widths - # NOTE: Only use 'linewidth' if it was explicitly passed. Do not - # include 'linewidth' inferred from rc['axes.linewidth'] setting. kwmajor = getattr(obj, "_major_tick_kw", {}) # graceful fallback if API changes kwminor = getattr(obj, "_minor_tick_kw", {}) - if "linewidth" in kwargs: - tickwidth = _not_none(tickwidth, kwargs["linewidth"]) - tickwidth = _not_none(tickwidth, rc.find("tick.width", context=True)) - tickwidthratio = _not_none( - tickwidthratio, rc.find("tick.widthratio", context=True) - ) # noqa: E501 tickwidth_prev = kwmajor.get("width", rc[x + "tick.major.width"]) if tickwidth_prev == 0: tickwidthratio_prev = rc["tick.widthratio"] # no other way of knowing @@ -92,7 +106,7 @@ def _update_background(self, x=None, tickwidth=None, tickwidthratio=None, **kwar elif which == "minor": tickwidthratio = _not_none(tickwidthratio, tickwidthratio_prev) kwticks["width"] *= tickwidthratio - self.tick_params(axis=axis, which=which, **kwticks) + self.tick_params(axis=x, which=which, **kwticks) def _update_ticks( self, diff --git a/ultraplot/figure.py b/ultraplot/figure.py index c99adb0ab..cff2be655 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -29,6 +29,7 @@ from typing_extensions import override from . import axes as paxes +from .axes._formatting import pop_axis_format_kwargs from . import constructor from . import gridspec as pgridspec from . import legend as plegend @@ -3623,19 +3624,14 @@ def format( # Initiate context block axs = axs or self._subplot_dict.values() skip_axes = kwargs.pop("skip_axes", False) # internal keyword arg - # Preserve explicit projection-specific format keywords that also happen to - # be valid rc aliases (e.g. GeoAxes/PolarAxes `labelsize`). Otherwise - # `_pop_rc()` removes them before the per-axes format dispatch below. - original_kwargs = kwargs.copy() - axis_param_names = set() - for ax in axs: - for cls, sig in paxes.Axes._format_signatures.items(): - if isinstance(ax, cls): - axis_param_names.update(sig.parameters) - axis_param_names.discard("self") + explicit_format_keys = set(kwargs) + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, *paxes.Axes._format_signatures.values() + ) + explicit_format_keys.update(signature_axis_kwargs) + explicit_format_keys.update(generic_axis_kwargs) rc_kw, rc_mode = _pop_rc(kwargs) - for key in axis_param_names & original_kwargs.keys(): - kwargs.setdefault(key, original_kwargs[key]) + kwargs.update(signature_axis_kwargs) with rc.context(rc_kw, mode=rc_mode): # Update background patch kw = rc.fill({"facecolor": "figure.facecolor"}, context=True) @@ -3722,7 +3718,18 @@ def _axis_has_label_text(ax, axis): if kw.get("ylabel") is not None and self._has_share_label_groups("y"): if _axis_has_share_label_text(ax, "y") or _axis_has_label_text(ax, "y"): kw.pop("ylabel", None) - ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) + explicit_kw = {} + if isinstance(ax, paxes.CartesianAxes): + explicit_kw["_explicit_format_keys"] = explicit_format_keys + ax.format( + rc_kw=rc_kw, + rc_mode=rc_mode, + skip_figure=True, + **explicit_kw, + **kw, + **kwargs, + **generic_axis_kwargs, + ) ax.number = store_old_number # Warn unused keyword argument(s) kw = { diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 1c7cd3860..29ee8c5bd 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -17,6 +17,7 @@ import numpy as np from . import axes as paxes +from .axes._formatting import pop_axis_format_kwargs from .config import rc from .internals import ( _not_none, @@ -2114,19 +2115,12 @@ def _supports_implicit_label_share(target): else: shared_title_loc = None shared_title_pad = None - # Preserve explicit projection-specific format keywords that also happen to - # be valid rc aliases (e.g. GeoAxes/PloarAxes `labelsize`). Otherwise - # `_pop_rc()` removes them before Figure.format() can delegate to axes. - original_kwargs = kwargs.copy() - axis_param_names = set() - for ax in axes: - for cls, sig in paxes.Axes._format_signatures.items(): - if isinstance(ax, cls): - axis_param_names.update(sig.parameters) - axis_param_names.discard("self") + signature_axis_kwargs, generic_axis_kwargs = pop_axis_format_kwargs( + kwargs, *paxes.Axes._format_signatures.values() + ) rc_kw, rc_mode = _pop_rc(kwargs) - for key in axis_param_names & original_kwargs.keys(): - kwargs.setdefault(key, original_kwargs[key]) + kwargs.update(signature_axis_kwargs) + kwargs.update(generic_axis_kwargs) with rc.context(rc_kw, mode=rc_mode): implicit_share_xlabels = ( is_subset diff --git a/ultraplot/internals/__init__.py b/ultraplot/internals/__init__.py index 487f73c60..16bdd4501 100644 --- a/ultraplot/internals/__init__.py +++ b/ultraplot/internals/__init__.py @@ -340,6 +340,7 @@ def _pop_rc(src, *, ignore_conflicts=True): "tight", "span", ) + kw = src.pop("rc_kw", None) or {} if "mode" in src: src["rc_mode"] = src.pop("mode") diff --git a/ultraplot/tests/test_axes_alt_styles.py b/ultraplot/tests/test_axes_alt_styles.py index 3031168d9..492872db2 100644 --- a/ultraplot/tests/test_axes_alt_styles.py +++ b/ultraplot/tests/test_axes_alt_styles.py @@ -1,7 +1,13 @@ +import matplotlib.colors as mcolors import pytest import ultraplot as uplt +def _all_match_color(colors, expected): + expected = mcolors.to_rgba(expected) + return all(mcolors.to_rgba(color) == expected for color in colors) + + def test_alt_axes_styling_dark_background(): """ Test that applying dark_background style does not leak tick visibility @@ -43,3 +49,208 @@ def test_alt_axes_styling_dark_background(): assert right_ax_left_ticks == 0, "Right axis should NOT have left ticks" assert right_ax_right_ticks > 0, "Right axis should have right ticks" + + assert _all_match_color( + [ + line.get_color() + for line in ax2.yaxis.get_ticklines() + if line.get_visible() + ], + "C1", + ) + assert { + mcolors.to_rgba(ax2.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax2.spines[side].get_visible() + } == {mcolors.to_rgba("C1")} + + +@pytest.mark.parametrize( + ("setup", "format_kwargs", "expected_color", "expected_linewidth"), + [ + ( + lambda ax: ax, + {"ycolor": "C0", "ylinewidth": 3, "ylabel": "Left Axis"}, + "C0", + 3, + ), + ( + lambda ax: ax.alty(color="C1", linewidth=3), + {"ylabel": "Right Axis", "ylim": (0, 1)}, + "C1", + 3, + ), + ], +) +def test_dark_background_preserves_axis_colors_on_reformat( + setup, format_kwargs, expected_color, expected_linewidth +): + with uplt.rc.context(style="dark_background"): + fig, ax = uplt.subplots() + target = setup(ax) + target.format(**format_kwargs) + target.format(ylabel="Updated Label") + + assert _all_match_color( + [label.get_color() for label in target.get_yticklabels()], expected_color + ) + assert mcolors.to_rgba(target.yaxis.label.get_color()) == mcolors.to_rgba( + expected_color + ) + assert _all_match_color( + [ + line.get_color() + for line in target.yaxis.get_ticklines() + if line.get_visible() + ], + expected_color, + ) + assert { + mcolors.to_rgba(target.spines[side].get_edgecolor()) + for side in ("left", "right") + if target.spines[side].get_visible() + } == {mcolors.to_rgba(expected_color)} + assert { + target.spines[side].get_linewidth() + for side in ("left", "right") + if target.spines[side].get_visible() + } == {expected_linewidth} + + +def test_dark_background_updates_unspecified_axis_frame_style(): + fig, ax = uplt.subplots() + + with uplt.rc.context(style="dark_background"): + ax.format(ylabel="Updated Label") + + expected = mcolors.to_rgba(uplt.rc["axes.edgecolor"]) + assert { + mcolors.to_rgba(ax.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax.spines[side].get_visible() + } == {expected} + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + expected, + ) + + +@pytest.mark.parametrize( + ("format_kwargs", "getter", "expected_color"), + [ + ( + {"ytickcolor": "red"}, + lambda ax: [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + "red", + ), + ( + {"yticklabelcolor": "blue"}, + lambda ax: [label.get_color() for label in ax.get_yticklabels()], + "blue", + ), + ( + {"ylabelcolor": "green"}, + lambda ax: [ax.yaxis.label.get_color()], + "green", + ), + ], +) +def test_subplots_preserve_explicit_axis_property_overrides_on_reformat( + format_kwargs, getter, expected_color +): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(**format_kwargs) + axs.format(ylabel="Updated Label") + + assert _all_match_color(getter(ax), expected_color) + + +def test_subplots_preserve_generic_tickcolor_across_later_axis_color(): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(tickcolor="red") + axs.format(ycolor="C1") + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + "red", + ) + assert { + mcolors.to_rgba(ax.spines[side].get_edgecolor()) + for side in ("left", "right") + if ax.spines[side].get_visible() + } == {mcolors.to_rgba("C1")} + + +def test_subplots_apply_generic_labelcolor(): + fig, axs = uplt.subplots() + ax = axs[0] + + axs.format(labelcolor="green") + + assert _all_match_color( + [ax.xaxis.label.get_color(), ax.yaxis.label.get_color()], "green" + ) + + +@pytest.mark.parametrize("format_kwargs", [{"ytickcolor": "red"}, {"tickcolor": "red"}]) +def test_subplots_can_clear_explicit_tickcolor_override(format_kwargs): + with uplt.rc.context(style="dark_background"): + fig, axs = uplt.subplots() + ax = axs[0] + axs.format(**format_kwargs) + clear_kwargs = {key: None for key in format_kwargs} + axs.format(**clear_kwargs) + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + uplt.rc["ytick.color"], + ) + + +@pytest.mark.parametrize("format_kwargs", [{"ytickcolor": "red"}, {"tickcolor": "red"}]) +def test_direct_axes_can_clear_explicit_tickcolor_override(format_kwargs): + with uplt.rc.context(style="dark_background"): + fig = uplt.figure() + ax = fig.subplot(111) + ax.format(**format_kwargs) + clear_kwargs = {key: None for key in format_kwargs} + ax.format(ylabel="Updated Label", **clear_kwargs) + + assert _all_match_color( + [ + line.get_color() + for line in ax.yaxis.get_ticklines() + if line.get_visible() + ], + uplt.rc["ytick.color"], + ) + + +def test_polar_format_updates_frame_style(): + fig = uplt.figure() + ax = fig.subplot(111, proj="polar") + + ax.format(color="C3", linewidth=3) + + assert mcolors.to_rgba(ax.spines["polar"].get_edgecolor()) == mcolors.to_rgba("C3") + assert ax.spines["polar"].get_linewidth() == 3