From 0b7260325a6960945ac170f49ee80fa0d5ca3385 Mon Sep 17 00:00:00 2001 From: samueljwu <56311527+samueljwu@users.noreply.github.com> Date: Thu, 14 May 2026 00:18:44 +0800 Subject: [PATCH] Replace unit handling with Pint OpenwaterHealth/openlifu-python#153 --- pyproject.toml | 1 + src/openlifu/util/units.py | 301 +++++++++++++++++++------------------ tests/test_units.py | 106 +++++++++++++ 3 files changed, 258 insertions(+), 150 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4334255c..f1f6b58b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "numpy<2", "matplotlib", "pandas", + "pint", "scipy", "vtk", "h5py", diff --git a/src/openlifu/util/units.py b/src/openlifu/util/units.py index 6f5ad3e4..f4ca7fef 100644 --- a/src/openlifu/util/units.py +++ b/src/openlifu/util/units.py @@ -1,47 +1,147 @@ from __future__ import annotations +import re + import numpy as np +from pint import UnitRegistry +from pint.errors import DimensionalityError, UndefinedUnitError from xarray import Dataset +ureg = UnitRegistry() +Q_ = ureg.Quantity + +_ANGLE_UNITS = { + "rad", + "radian", + "radians", + "deg", + "degree", + "degrees", + "\u00b0", + "\u00c2\u00b0", +} + +_UNIT_ALIASES = { + "micron": "micrometer", + "microns": "micrometer", + "um": "micrometer", + "\u00b5m": "micrometer", + "\u03bcm": "micrometer", + "sec": "second", + "secs": "second", + "min": "minute", + "mins": "minute", + "hr": "hour", + "hrs": "hour", + "\u00b0": "degree", + "\u00c2\u00b0": "degree", + "cc": "cm^3", + "kgram": "kilogram", + "kgrams": "kilograms", + "amps": "ampere", + "amp": "ampere", +} + +_BASE_UNITS_BY_TYPE = { + "distance": "m", + "area": "m^2", + "volume": "m^3", + "time": "s", + "angle": "rad", + "frequency": "Hz", + "pressure": "Pa", + "watt": "W", +} + + +def _normalize_unit(unit: str) -> str: + unit = unit.strip() + + unit = unit.replace("\u00b2", "^2").replace("\u00b3", "^3") + unit = unit.replace("\u00b5", "u").replace("\u03bc", "u") + + # Fix common typos + unit = re.sub(r"sec(s)?\b", "second", unit, flags=re.IGNORECASE) + unit = re.sub(r"\bmili", "milli", unit, flags=re.IGNORECASE) + unit = re.sub(r"grams?\b", "gram", unit, flags=re.IGNORECASE) + unit = re.sub(r"meters?\b", "meter", unit, flags=re.IGNORECASE) + unit = re.sub(r"\s+", " ", unit) + + normalized_parts = [] + for part in re.split(r"([/*])", unit): + stripped_part = part.strip() + if stripped_part in {"/", "*", ""}: + normalized_parts.append(stripped_part) + continue + + part_key = stripped_part.lower() + normalized_parts.append(_UNIT_ALIASES.get(part_key, _normalize_unit_symbol(stripped_part))) + + normalized = "".join(normalized_parts) + + normalized = re.sub(r"\b([a-zA-Z]+)([23])\b", r"\1^\2", normalized) + return normalized + + +def _normalize_unit_symbol(unit: str) -> str: + unit = re.sub(r"\b([a-zA-Z]+)([23])\b", r"\1^\2", unit) + + for suffix, canonical_suffix in (("hz", "Hz"), ("pa", "Pa")): + suffix_match = re.fullmatch(rf"([A-Za-z]*){suffix}(\^\d+)?", unit, flags=re.IGNORECASE) + if suffix_match: + prefix, power = suffix_match.groups() + return f"{prefix}{canonical_suffix}{power or ''}" + + watt_match = re.fullmatch(r"([A-Za-z]*)w(\^\d+)?", unit, flags=re.IGNORECASE) + if watt_match: + prefix, power = watt_match.groups() + return f"{prefix}W{power or ''}" + + return unit + + +def _quantity(unit: str): + return Q_(1, _normalize_unit(unit)) + def getunittype(unit): - unit = unit.lower() - if unit in ['micron', 'microns']: - return 'distance' - elif unit in ['minute', 'minutes', 'min', 'mins', 'hour', 'hours', 'hr', 'hrs', 'day', 'days', 'd']: - return 'time' - elif unit in ['rad', 'deg', 'radian', 'radians', 'degree', 'degrees', '°']: - return 'angle' - elif 'sec' in unit: - return 'time' - elif 'meter' in unit or 'micron' in unit: - return 'distance' - elif unit.endswith('s'): - return 'time' - elif unit.endswith('m'): - return 'distance' - elif unit.endswith(('m2', 'm^2')): - return 'area' - elif unit.endswith(('m3', 'm^3')): - return 'volume' - elif unit.endswith('hz'): - return 'frequency' - elif unit.endswith('pa'): - return 'pressure' - elif unit.endswith('w'): - return 'watt' - else: - return 'other' + normalized_unit = _normalize_unit(unit) + + if normalized_unit.lower() in _ANGLE_UNITS: + return "angle" + + try: + dim = Q_(1, normalized_unit).dimensionality + except (TypeError, UndefinedUnitError): + return "other" + + if dim == ureg.meter.dimensionality: + return "distance" + if dim == (ureg.meter**2).dimensionality: + return "area" + if dim == (ureg.meter**3).dimensionality: + return "volume" + if dim == ureg.second.dimensionality: + return "time" + if dim == (1 / ureg.second).dimensionality: + return "frequency" + if dim == ureg.pascal.dimensionality: + return "pressure" + if dim == ureg.watt.dimensionality: + return "watt" + + return "other" + def getunitconversion(from_unit, to_unit, unitratio=None, constant=None): if not from_unit: return 1.0 if unitratio is not None and constant is not None: - if '/' not in unitratio: - raise ValueError('Conversion unit ratio must have a \'/\' symbol') + if "/" not in unitratio: + raise ValueError("Conversion unit ratio must have a '/' symbol") - unitn, unitd = unitratio.split('/') + unitn, unitd = unitratio.split("/") type0 = getunittype(from_unit) type1 = getunittype(to_unit) typen = getunittype(unitn) @@ -54,129 +154,30 @@ def getunitconversion(from_unit, to_unit, unitratio=None, constant=None): elif type0 == type1: scl = getunitconversion(from_unit, to_unit) else: - raise ValueError(f'Unit type mismatch {type0} -> ({typen}/{typed}) -> {type1}') + raise ValueError(f"Unit type mismatch {type0} -> ({typen}/{typed}) -> {type1}") else: - slash0 = from_unit.find('/') - slash1 = to_unit.find('/') - - if slash0 != -1 and slash1 != -1: - num0 = from_unit[:slash0] - denom0 = from_unit[slash0+1:] - num1 = to_unit[:slash1] - denom1 = to_unit[slash1+1:] - scl = getunitconversion(num0, num1) / getunitconversion(denom0, denom1) - elif slash0 == -1 and slash1 == -1: + try: + scl = _quantity(from_unit).to(_normalize_unit(to_unit)).magnitude + except DimensionalityError as exc: type0 = getunittype(from_unit) type1 = getunittype(to_unit) - - if type0 != type1: - raise ValueError(f'Unit type mismatch ({type0}) vs ({type1})') - - if type0 == 'other': - if from_unit[-1] != to_unit[-1]: - raise ValueError(f'Cannot convert {from_unit} to {to_unit}') - - i = 0 - while i < min(len(from_unit), len(to_unit)) and from_unit[-i:] == to_unit[-i:]: - type = from_unit[-i:] - i += 1 - - scl0 = getsiscale(from_unit, type) - scl1 = getsiscale(to_unit, type) - scl = scl0 / scl1 - else: - scl0 = getsiscale(from_unit, type0) - scl1 = getsiscale(to_unit, type0) - scl = scl0 / scl1 - else: - raise ValueError(f'Unit ratio mismatch ({from_unit} vs {to_unit})') + raise ValueError(f"Unit type mismatch ({type0}) vs ({type1})") from exc + except UndefinedUnitError as exc: + raise ValueError(f"Cannot convert {from_unit} to {to_unit}") from exc return scl + def getsiscale(unit, type): type = type.lower() - if type in ['distance', 'area', 'volume']: - idx = unit.find('meters') - if idx == -1: - idx = unit.find('meter') - if idx == -1: - if unit.lower() == 'micron': - idx = 6 - else: - idx = unit.rfind('m') - if idx == -1: - idx = len(unit) - - elif type == 'time': - idx = unit.find('seconds') - if idx == -1: - idx = unit.find('second') - if idx == -1: - idx = unit.find('sec') - if idx == -1: - idx = unit.rfind('s') - if idx == -1: - idx = len(unit) - - elif type == 'angle': - idx = len(unit) - - elif type == 'frequency' or type == "pressure": - idx = len(unit) - 2 - - elif type == "watt": - idx = len(unit) - 1 + if type not in _BASE_UNITS_BY_TYPE: + raise ValueError(f"Unknown unit type {type}") - else: - idx = len(unit) - len(type) + 1 - - prefix = unit[:idx] - - if not prefix: - scl = 1.0 - else: - scl = 1.0 - - if prefix == 'pico' or prefix == 'p': - scl = 1.0e-12 - elif prefix == 'nano' or prefix == 'n': - scl = 1.0e-9 - elif prefix == 'micro' or prefix == 'u' or prefix == '\u00b5' or prefix == '\u03bc': - scl = 1.0e-6 - elif prefix == 'milli' or prefix == 'm': - scl = 1.0e-3 - elif prefix == 'centi' or prefix == 'c': - scl = 1.0e-2 - elif prefix == '': - scl = 1.0 - elif prefix == 'kilo' or prefix == 'k': - scl = 1.0e3 - elif prefix == 'mega' or prefix == 'M': - scl = 1.0e6 - elif prefix == 'giga' or prefix == 'G': - scl = 1.0e9 - elif prefix == 'tera' or prefix == 'T': - scl = 1.0e12 - elif prefix == 'min' or prefix == 'minute': - scl = 60.0 - elif prefix == 'hour' or prefix == 'hr': - scl = 60.0 * 60.0 - elif prefix == 'day' or prefix == 'd': - scl = 60.0 * 60.0 * 24.0 - elif prefix == 'rad' or prefix == 'radian' or prefix == 'radians': - scl = 1.0 - elif prefix == 'deg' or prefix == 'degree' or prefix == 'degrees' or prefix == '\u00b0': - scl = 2 * 3.14159265358979323846 / 360 - elif prefix: - raise ValueError(f'Unknown prefix {prefix}') - - if type == 'area': - scl = scl ** 2.0 - elif type == 'volume': - scl = scl ** 3.0 - - return scl + try: + return getunitconversion(unit, _BASE_UNITS_BY_TYPE[type]) + except ValueError as exc: + raise ValueError(f"Unknown prefix {unit}") from exc def rescale_data_arr(data_arr: Dataset, units: str) -> Dataset: @@ -191,9 +192,9 @@ def rescale_data_arr(data_arr: Dataset, units: str) -> Dataset: rescaled: The rescaled xarray to new units. """ rescaled = data_arr.copy(deep=True) - scale = getunitconversion(data_arr.attrs['units'], units) + scale = getunitconversion(data_arr.attrs["units"], units) rescaled.data *= scale - rescaled.attrs['units'] = units + rescaled.attrs["units"] = units return rescaled @@ -212,12 +213,12 @@ def rescale_coords(data_arr: Dataset, units: str) -> Dataset: rescaled = data_arr.copy(deep=True) for coord_key in data_arr.coords: curr_coord_attrs = rescaled[coord_key].attrs - if 'units' in curr_coord_attrs: - curr_coord_units = curr_coord_attrs['units'] + if "units" in curr_coord_attrs: + curr_coord_units = curr_coord_attrs["units"] scale = getunitconversion(curr_coord_units, units) - curr_coord_rescaled = scale*rescaled[coord_key].data + curr_coord_rescaled = scale * rescaled[coord_key].data rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)}) - rescaled[coord_key].attrs['units'] = units + rescaled[coord_key].attrs["units"] = units return rescaled @@ -237,7 +238,7 @@ def get_ndgrid_from_arr(data_arr: Dataset) -> np.ndarray: ordered_key = data_arr[first_data_key].dims all_coord = [] for coord_key in ordered_key: - if 'units' in data_arr[coord_key].attrs: + if "units" in data_arr[coord_key].attrs: all_coord += [data_arr.coords[coord_key].data] ndgrid = np.stack(np.meshgrid(*all_coord, indexing="ij"), axis=-1) diff --git a/tests/test_units.py b/tests/test_units.py index de6942e9..d2ee0814 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -7,6 +7,8 @@ from openlifu.util.units import ( get_ndgrid_from_arr, getsiscale, + getunitconversion, + getunittype, rescale_coords, rescale_data_arr, ) @@ -53,6 +55,110 @@ def test_getsiscale(): assert getsiscale('THz', 'frequency') == 1e12 +@pytest.mark.parametrize( + ("from_unit", "to_unit", "expected"), + [ + ("m", "mm", 1e3), + ("mm", "m", 1e-3), + ("m", "cm", 1e2), + ("cm", "m", 1e-2), + ("m2", "cm2", 1e4), + ("cm2", "m2", 1e-4), + ("m^2", "cm^2", 1e4), + ("cm^2", "m^2", 1e-4), + ("m3", "cm3", 1e6), + ("cm3", "m3", 1e-6), + ("s", "ms", 1e3), + ("ms", "s", 1e-3), + ("sec", "s", 1.0), + ("min", "s", 60.0), + ("hr", "s", 3600.0), + ("day", "s", 86400.0), + ("Hz", "kHz", 1e-3), + ("kHz", "Hz", 1e3), + ("Pa", "kPa", 1e-3), + ("kPa", "Pa", 1e3), + ("W", "mW", 1e3), + ("mW", "W", 1e-3), + ("deg", "rad", np.pi / 180), + ("rad", "deg", 180 / np.pi), + ("micron", "m", 1e-6), + ("um", "m", 1e-6), + ("Pa", "MPa", 1e-6), + ("W/cm^2", "W/m^2", 1e4), + ("mW/cm^2", "W/m^2", 10), + ("MHz", "Hz", 1e6), + ], +) +def test_getunitconversion(from_unit, to_unit, expected): + assert np.allclose(getunitconversion(from_unit, to_unit), expected) + + +@pytest.mark.parametrize( + ("unit", "expected"), + [ + ("mm", "distance"), + ("mm^2", "area"), + ("mm^3", "volume"), + ("s", "time"), + ("MHz", "frequency"), + ("MPa", "pressure"), + ("W", "watt"), + ("deg", "angle"), + ("micron", "distance"), + ("microns", "distance"), + ("meter", "distance"), + ("meters", "distance"), + ("m2", "area"), + ("m3", "volume"), + ("sec", "time"), + ("minute", "time"), + ("minutes", "time"), + ("min", "time"), + ("mins", "time"), + ("hour", "time"), + ("hours", "time"), + ("hr", "time"), + ("hrs", "time"), + ("day", "time"), + ("days", "time"), + ("d", "time"), + ("rad", "angle"), + ("\u00b0", "angle"), + ], +) +def test_getunittype(unit, expected): + assert getunittype(unit) == expected + + +@pytest.mark.parametrize( + ("from_unit", "to_unit", "expected"), + [ + ("microns", "m", 1e-6), + ("cc", "m^3", 1e-6), + ("msec", "s", 1e-3), + ("usec", "s", 1e-6), + ("miliPa", "Pa", 1e-3), + ("W/cm^2", "W/m^2", 1e4), + ], +) +def test_getunitconversion_pint_improvements(from_unit, to_unit, expected): + assert np.allclose(getunitconversion(from_unit, to_unit), expected) + + +@pytest.mark.parametrize( + ("unit", "expected"), + [ + ("cc", "volume"), + ("m/s", "other"), + ("W/cm^2", "other"), + ("amps", "other"), + ], +) +def test_getunittype_pint_improvements(unit, expected): + assert getunittype(unit) == expected + + def test_rescale_data_arr(example_xarr: Dataset): """Test that an xarray data can be correctly rescaled.""" expected_p = 1e-6 * example_xarr['p'].data