Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import astropy.units as u

from uploader.app.lib.expression import parse


def _sample_values() -> tuple[dict[str, float], dict[str, str]]:
values = {
"logd25": 1.5,
"logr25": 0.3,
"e_logd25": 0.05,
"e_logr25": 0.04,
"pa": 190.0,
}
units = {
"logd25": "",
"logr25": "",
"e_logd25": "",
"e_logr25": "",
"pa": "deg",
}
return values, units


def test_isophotal_axis_expressions() -> None:
values, units = _sample_values()
a = parse('3 * 10 ** col("logd25") * arcsec').evaluate(values, units)
assert a.unit == u.arcsec
assert abs(a.value - 94.86832980505137) < 1e-6

e_a = parse('3 * 10 ** col("logd25") * 2.302585093 * e_logd25 * arcsec').evaluate(values, units)
assert e_a.unit == u.arcsec
assert e_a.value > 0

b = parse('3 * 10 ** (col("logd25") - col("logr25")) * arcsec').evaluate(values, units)
assert b.unit == u.arcsec
assert b.value > 0

e_b = parse(
'3 * 10 ** (col("logd25") - col("logr25")) * 2.302585093 '
'* (col("e_logd25") ** 2 + col("e_logr25") ** 2) ** 0.5 * arcsec',
).evaluate(values, units)
assert e_b.unit == u.arcsec
assert e_b.value > 0


def test_position_angle_modulo() -> None:
values, units = _sample_values()
pa = parse('col("pa") % 180.0').evaluate(values, units)
assert pa.unit == u.deg
assert pa.value == 10.0


def test_isophotal_axis_expressions_with_logarithmic_column_units() -> None:
values, units = _sample_values()
for log_unit in ("mag", "dex"):
units_with_log = {**units, "logd25": log_unit, "logr25": log_unit, "e_logd25": log_unit, "e_logr25": log_unit}
a = parse('3 * 10 ** col("logd25") * arcsec').evaluate(values, units_with_log)
assert a.unit == u.arcsec
assert abs(a.value - 94.86832980505137) < 1e-6

e_a = parse('3 * 10 ** col("logd25") * 2.302585093 * e_logd25 * arcsec').evaluate(values, units_with_log)
assert e_a.unit == u.arcsec
assert e_a.value > 0

e_b = parse(
'3 * 10 ** (col("logd25") - col("logr25")) * 2.302585093 '
'* (col("e_logd25") ** 2 + col("e_logr25") ** 2) ** 0.5 * arcsec',
).evaluate(values, units_with_log)
assert e_b.unit == u.arcsec
assert e_b.value > 0


def test_isophotal_axis_expressions_with_hyperleda_units() -> None:
values = {
"logd25": 0.697,
"logr25": 0.13,
"e_logd25": 0.079,
"e_logr25": 0.028,
"pa": 161.14,
}
units = {
"logd25": "dex(0.1 arcmin)",
"logr25": "dex",
"e_logd25": "dex(0.1 arcmin)",
"e_logr25": "dex",
"pa": "deg",
}
a = parse('3 * 10 ** col("logd25") * arcsec').evaluate(values, units)
assert a.unit == u.arcsec
assert a.value > 0

e_a = parse('3 * 10 ** col("logd25") * 2.302585093 * e_logd25 * arcsec').evaluate(values, units)
assert e_a.unit == u.arcsec
assert e_a.value > 0

b = parse('3 * 10 ** (col("logd25") - col("logr25")) * arcsec').evaluate(values, units)
assert b.unit == u.arcsec
assert b.value > 0

e_b = parse(
'3 * 10 ** (col("logd25") - col("logr25")) * 2.302585093 '
'* (col("e_logd25") ** 2 + col("e_logr25") ** 2) ** 0.5 * arcsec',
).evaluate(values, units)
assert e_b.unit == u.arcsec
assert e_b.value > 0


def test_surface_brightness_column_keeps_units() -> None:
values = {"bri25": 23.162}
units = {"bri25": "mag / arcsec2"}
bri25 = parse('col("bri25")').evaluate(values, units)
assert bri25.unit == u.Unit("mag/arcsec2")


def test_bare_column_referenced_in_parse() -> None:
expr = parse("3 * 10 ** col('logd25') * 2.302585093 * e_logd25 * arcsec")
assert expr.referenced_columns == {"logd25", "e_logd25"}
61 changes: 49 additions & 12 deletions uploader/app/lib/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import astropy.constants as const
import astropy.units as u
import numpy as np
from astropy.units.function.core import FunctionUnitBase

COL_FUNCTION = "col"

Expand All @@ -24,10 +25,10 @@
def expression_syntax_help() -> str:
constants = ", ".join(sorted(NAMED_CONSTANTS))
return (
f'Use {COL_FUNCTION}("name") to refer to rawdata columns '
'(e.g. col("a"), col("SMASB22.5"), col("PA-LEDA")).\n'
"Bare identifiers refer to predefined constants.\n"
"Operators: + - * /.\n"
f'Use {COL_FUNCTION}("name") or bare identifiers to refer to rawdata columns '
'(e.g. col("a"), e_logd25).\n'
"Bare identifiers that match predefined constants use those constants.\n"
"Operators: + - * / ** %.\n"
"Functions: sin(x), cos(x) (argument must be an angle).\n"
"Numbers are dimensionless.\n"
f"Available constants: {constants}."
Expand All @@ -38,11 +39,44 @@ def expression_syntax_help() -> str:
type _QuantityUnaryOp = Callable[[u.Quantity], u.Quantity]
type _QuantityFunc = Callable[[u.Quantity], u.Quantity | float]


def _mod(left: u.Quantity, right: u.Quantity) -> u.Quantity:
if not right.unit.is_equivalent(u.dimensionless_unscaled):
raise ValueError("modulo divisor must be dimensionless")
return (left.value % right.value) * left.unit


def _is_logarithmic_column_unit(unit: u.Unit) -> bool:
return unit == u.mag or unit == u.dex or isinstance(unit, FunctionUnitBase)


def _column_quantity(value: float, unit_str: str) -> u.Quantity:
if not unit_str:
return value * u.dimensionless_unscaled
unit = u.Unit(unit_str)
if _is_logarithmic_column_unit(unit):
return value * u.dimensionless_unscaled
return value * unit


def _pow(base: u.Quantity, exp: u.Quantity) -> u.Quantity:
if not exp.unit.is_equivalent(u.dimensionless_unscaled):
exp = float(exp.value) * u.dimensionless_unscaled
try:
return operator.pow(base, exp)
except u.UnitTypeError:
if base.unit.is_equivalent(u.dimensionless_unscaled):
return float(base.value**exp.value) * u.dimensionless_unscaled
raise


_BINOPS: dict[type[ast.operator], _QuantityBinOp] = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: _pow,
ast.Mod: _mod,
}

_UNARYOPS: dict[type[ast.unaryop], _QuantityUnaryOp] = {
Expand Down Expand Up @@ -104,6 +138,11 @@ def visit_Call(self, node: ast.Call) -> None:
for arg in node.args:
self.visit(arg)

def visit_Name(self, node: ast.Name) -> None:
if node.id in NAMED_CONSTANTS or node.id in _FUNCTIONS:
return
self.columns.add(node.id)


@final
class _Evaluator(ast.NodeVisitor):
Expand All @@ -120,7 +159,7 @@ def visit(self, node: ast.AST) -> u.Quantity:
case ast.Call() as call:
return self._call(call)
case ast.Name(id=name):
return self._lookup_constant(name)
return self._lookup_name(name)
case ast.Constant(value=value):
return self._constant(value)
case _:
Expand Down Expand Up @@ -157,18 +196,16 @@ def _call(self, node: ast.Call) -> u.Quantity:
return result
return float(result) * u.dimensionless_unscaled

def _lookup_constant(self, name: str) -> u.Quantity:
def _lookup_name(self, name: str) -> u.Quantity:
constant = NAMED_CONSTANTS.get(name)
if constant is None:
raise ValueError(f"unknown constant {name!r}")
return constant
if constant is not None:
return constant
return self._lookup_column(name)

def _lookup_column(self, name: str) -> u.Quantity:
if name not in self._values:
raise ValueError(f"unknown column {name!r}")
unit_str = self._units.get(name, "")
unit = u.Unit(unit_str) if unit_str else u.dimensionless_unscaled
return self._values[name] * unit
return _column_quantity(self._values[name], self._units.get(name, ""))

def _constant(self, value: object) -> u.Quantity:
if isinstance(value, bool):
Expand Down
Loading
Loading