diff --git a/pixi.toml b/pixi.toml index 385b9f52..ba024968 100644 --- a/pixi.toml +++ b/pixi.toml @@ -104,6 +104,7 @@ test-py310 = { features = ["test", "py310"] } test-py311 = { features = ["test", "py311"] } test-py312 = { features = ["test", "py312"] } test-notebooks = { features = ["test", "notebooks"], solve-group = "test" } +analysis = { features = ["analysis"], solve-group = "analysis" } docs = { features = ["docs"], solve-group = "docs" } typing = { features = ["typing"], solve-group = "typing" } pre-commit = { features = ["pre-commit"], no-default-feature = true } diff --git a/src/virtualship/instruments/adcp.py b/src/virtualship/instruments/adcp.py index 17797a41..b2da6582 100644 --- a/src/virtualship/instruments/adcp.py +++ b/src/virtualship/instruments/adcp.py @@ -1,14 +1,14 @@ +from collections.abc import Callable from dataclasses import dataclass from typing import ClassVar import numpy as np -from parcels import ParticleSet, ScipyParticle, Variable +from parcels import ParticleSet, ScipyParticle from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType -from virtualship.utils import ( - register_instrument, -) +from virtualship.utils import build_particle_class_from_sensors, register_instrument # ===================================================== # SECTION: Dataclass @@ -23,16 +23,12 @@ class ADCP: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== +# ADCP has no non-sensor variables, only sensor variables. +_ADCP_NONSENSOR_VARIABLES: list = [] -_ADCPParticle = ScipyParticle.add_variables( - [ - Variable("U", dtype=np.float32, initial=np.nan), - Variable("V", dtype=np.float32, initial=np.nan), - ] -) # ===================================================== # SECTION: Kernels @@ -54,9 +50,13 @@ def _sample_velocity(particle, fieldset, time): class ADCPInstrument(Instrument): """ADCP instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.VELOCITY: _sample_velocity, + } + def __init__(self, expedition, from_data): """Initialize ADCPInstrument.""" - variables = {"U": "uo", "V": "vo"} + variables = expedition.instruments_config.adcp_config.active_variables() limit_spec = { "spatial": True } # spatial limits; lat/lon constrained to waypoint locations + buffer @@ -93,6 +93,12 @@ def simulate(self, measurements, out_path) -> None: fieldset = self.load_input_data() + # build dynamic particle class from the active sensors + adcp_config = self.expedition.instruments_config.adcp_config + _ADCPParticle = build_particle_class_from_sensors( + adcp_config.sensors, _ADCP_NONSENSOR_VARIABLES, ScipyParticle + ) + bins = np.linspace(MAX_DEPTH, MIN_DEPTH, NUM_BINS) num_particles = len(bins) particleset = ParticleSet.from_list( @@ -108,6 +114,13 @@ def simulate(self, measurements, out_path) -> None: out_file = particleset.ParticleFile(name=out_path, outputdt=np.inf) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in adcp_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + for point in measurements: particleset.lon_nextloop[:] = point.location.lon particleset.lat_nextloop[:] = point.location.lat @@ -116,7 +129,7 @@ def simulate(self, measurements, out_path) -> None: ) particleset.execute( - [_sample_velocity], + sampling_kernels, dt=1, runtime=1, verbose_progress=self.verbose_progress, diff --git a/src/virtualship/instruments/argo_float.py b/src/virtualship/instruments/argo_float.py index 1c697852..8c90cfb2 100644 --- a/src/virtualship/instruments/argo_float.py +++ b/src/virtualship/instruments/argo_float.py @@ -1,21 +1,17 @@ import math +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import ClassVar import numpy as np -from parcels import ( - AdvectionRK4, - JITParticle, - ParticleSet, - StatusCode, - Variable, -) +from parcels import AdvectionRK4, JITParticle, ParticleSet, StatusCode, Variable from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models.spacetime import Spacetime -from virtualship.utils import register_instrument +from virtualship.utils import build_particle_class_from_sensors, register_instrument # ===================================================== # SECTION: Dataclass @@ -37,25 +33,21 @@ class ArgoFloat: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_ArgoParticle = JITParticle.add_variables( - [ - Variable("cycle_phase", dtype=np.int32, initial=0.0), - Variable("cycle_age", dtype=np.float32, initial=0.0), - Variable("drift_age", dtype=np.float32, initial=0.0), - Variable("salinity", dtype=np.float32, initial=np.nan), - Variable("temperature", dtype=np.float32, initial=np.nan), - Variable("min_depth", dtype=np.float32), - Variable("max_depth", dtype=np.float32), - Variable("drift_depth", dtype=np.float32), - Variable("vertical_speed", dtype=np.float32), - Variable("cycle_days", dtype=np.int32), - Variable("drift_days", dtype=np.int32), - Variable("grounded", dtype=np.int32, initial=0), - ] -) +_ARGO_NONSENSOR_VARIABLES = [ + Variable("cycle_phase", dtype=np.int32, initial=0.0), + Variable("cycle_age", dtype=np.float32, initial=0.0), + Variable("drift_age", dtype=np.float32, initial=0.0), + Variable("min_depth", dtype=np.float32), + Variable("max_depth", dtype=np.float32), + Variable("drift_depth", dtype=np.float32), + Variable("vertical_speed", dtype=np.float32), + Variable("cycle_days", dtype=np.int32), + Variable("drift_days", dtype=np.int32), + Variable("grounded", dtype=np.int32, initial=0), +] # ===================================================== # SECTION: Kernels @@ -118,18 +110,7 @@ def _argo_float_vertical_movement(particle, fieldset, time): particle.grounded = 0 if particle.depth + particle_ddepth >= particle.min_depth: particle_ddepth = particle.min_depth - particle.depth - particle.temperature = ( - math.nan - ) # reset temperature to NaN at end of sampling cycle - particle.salinity = math.nan # idem particle.cycle_phase = 4 - else: - particle.temperature = fieldset.T[ - time, particle.depth, particle.lat, particle.lon - ] - particle.salinity = fieldset.S[ - time, particle.depth, particle.lat, particle.lon - ] elif particle.cycle_phase == 4: # Phase 4: Transmitting at surface until cycletime is reached @@ -153,6 +134,24 @@ def _check_error(particle, fieldset, time): particle.delete() +def _argo_sample_temperature(particle, fieldset, time): + # Phase 3: ascending — sample temperature; NaN otherwise + if particle.cycle_phase == 3 and particle.depth < particle.min_depth: + particle.temperature = fieldset.T[ + time, particle.depth, particle.lat, particle.lon + ] + else: + particle.temperature = math.nan + + +def _argo_sample_salinity(particle, fieldset, time): + # Phase 3: ascending — sample salinity; NaN otherwise + if particle.cycle_phase == 3 and particle.depth < particle.min_depth: + particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] + else: + particle.salinity = math.nan + + # ===================================================== # SECTION: Instrument Class # ===================================================== @@ -162,9 +161,21 @@ def _check_error(particle, fieldset, time): class ArgoFloatInstrument(Instrument): """ArgoFloat instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.TEMPERATURE: _argo_sample_temperature, + SensorType.SALINITY: _argo_sample_salinity, + } + def __init__(self, expedition, from_data): """Initialize ArgoFloatInstrument.""" - variables = {"U": "uo", "V": "vo", "S": "so", "T": "thetao"} + sensor_variables = ( + expedition.instruments_config.argo_float_config.active_variables() + ) + variables = { + "U": "uo", + "V": "vo", + **sensor_variables, + } # advection variables (U and V) are always required for argo float simulation; sensor variables come from config spacetime_buffer_size = { "latlon": 3.0, # [degrees] "time": expedition.instruments_config.argo_float_config.lifetime.total_seconds() @@ -215,6 +226,14 @@ def simulate(self, measurements, out_path) -> None: f"{self.__class__.__name__} cannot be deployed in waters shallower than 50m. The following waypoints are too shallow: {shallow_waypoints}." ) + # build dynamic particle class from the active sensors + argo_float_config = self.expedition.instruments_config.argo_float_config + _ArgoParticle = build_particle_class_from_sensors( + argo_float_config.sensors, + _ARGO_NONSENSOR_VARIABLES, + JITParticle, + ) + # define parcel particles argo_float_particleset = ParticleSet( fieldset=fieldset, @@ -241,10 +260,18 @@ def simulate(self, measurements, out_path) -> None: # endtime endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in argo_float_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + # execute simulation argo_float_particleset.execute( [ _argo_float_vertical_movement, + *sampling_kernels, AdvectionRK4, _keep_at_surface, _check_error, diff --git a/src/virtualship/instruments/base.py b/src/virtualship/instruments/base.py index 2ca1b783..d4e078e6 100644 --- a/src/virtualship/instruments/base.py +++ b/src/virtualship/instruments/base.py @@ -1,11 +1,11 @@ from __future__ import annotations import abc -from collections import OrderedDict +import collections from datetime import timedelta from itertools import pairwise from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import copernicusmarine import xarray as xr @@ -24,12 +24,24 @@ ) if TYPE_CHECKING: + from virtualship.instruments.sensors import SensorType from virtualship.models import Expedition class Instrument(abc.ABC): """Base class for instruments and their simulation.""" + # all instruments have sensor_kernels dict, mapping SensorType to sampling kernel + sensor_kernels: ClassVar[dict[SensorType, collections.abc.Callable]] + + def __init_subclass__(cls, **kwargs: object) -> None: + """Ensure subclasses define sensor_kernels as class attribute.""" + super().__init_subclass__(**kwargs) + if "sensor_kernels" not in cls.__dict__: + raise TypeError( + f"Instrument subclass '{cls.__name__}' must define 'sensor_kernels' as a class attribute." + ) + def __init__( self, expedition: Expedition, @@ -45,7 +57,7 @@ def __init__( self.expedition = expedition self.from_data = from_data - self.variables = OrderedDict(variables) + self.variables = collections.OrderedDict(variables) self.dimensions = { "lon": "longitude", "lat": "latitude", diff --git a/src/virtualship/instruments/ctd.py b/src/virtualship/instruments/ctd.py index eb780d3e..122e1461 100644 --- a/src/virtualship/instruments/ctd.py +++ b/src/virtualship/instruments/ctd.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import TYPE_CHECKING, ClassVar @@ -6,11 +7,16 @@ from parcels import JITParticle, ParticleSet, Variable from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType +from virtualship.utils import ( + add_dummy_UV, + build_particle_class_from_sensors, + register_instrument, +) if TYPE_CHECKING: from virtualship.models.spacetime import Spacetime -from virtualship.utils import add_dummy_UV, register_instrument # ===================================================== # SECTION: Dataclass @@ -28,19 +34,15 @@ class CTD: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_CTDParticle = JITParticle.add_variables( - [ - Variable("salinity", dtype=np.float32, initial=np.nan), - Variable("temperature", dtype=np.float32, initial=np.nan), - Variable("raising", dtype=np.int8, initial=0.0), # bool. 0 is False, 1 is True. - Variable("max_depth", dtype=np.float32), - Variable("min_depth", dtype=np.float32), - Variable("winch_speed", dtype=np.float32), - ] -) +_CTD_NONSENSOR_VARIABLES = [ + Variable("raising", dtype=np.int8, initial=0.0), # bool. 0 is False, 1 is True. + Variable("max_depth", dtype=np.float32), + Variable("min_depth", dtype=np.float32), + Variable("winch_speed", dtype=np.float32), +] # ===================================================== @@ -79,9 +81,14 @@ def _ctd_cast(particle, fieldset, time): class CTDInstrument(Instrument): """CTD instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.TEMPERATURE: _sample_temperature, + SensorType.SALINITY: _sample_salinity, + } + def __init__(self, expedition, from_data): """Initialize CTDInstrument.""" - variables = {"S": "so", "T": "thetao"} + variables = expedition.instruments_config.ctd_config.active_variables() limit_spec = { "spatial": True } # spatial limits; lat/lon constrained to waypoint locations + buffer @@ -115,11 +122,14 @@ def simulate(self, measurements, out_path) -> None: # add dummy U add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - fieldset_starttime = fieldset.T.grid.time_origin.fulltime( - fieldset.T.grid.time_full[0] + # use first active field for time reference + _time_ref_key = next(iter(self.variables)) + _time_ref_field = getattr(fieldset, _time_ref_key) + fieldset_starttime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[0] ) - fieldset_endtime = fieldset.T.grid.time_origin.fulltime( - fieldset.T.grid.time_full[-1] + fieldset_endtime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[-1] ) # deploy time for all ctds should be later than fieldset start time @@ -152,6 +162,12 @@ def simulate(self, measurements, out_path) -> None: f"CTD max_depth or bathymetry shallower than maximum {-DT * WINCH_SPEED}" ) + # build dynamic particle class from the active sensors + ctd_config = self.expedition.instruments_config.ctd_config + _CTDParticle = build_particle_class_from_sensors( + ctd_config.sensors, _CTD_NONSENSOR_VARIABLES, JITParticle + ) + # define parcel particles ctd_particleset = ParticleSet( fieldset=fieldset, @@ -168,9 +184,16 @@ def simulate(self, measurements, out_path) -> None: # define output file for the simulation out_file = ctd_particleset.ParticleFile(name=out_path, outputdt=OUTPUT_DT) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in ctd_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + # execute simulation ctd_particleset.execute( - [_sample_salinity, _sample_temperature, _ctd_cast], + [*sampling_kernels, _ctd_cast], endtime=fieldset_endtime, dt=DT, verbose_progress=self.verbose_progress, diff --git a/src/virtualship/instruments/ctd_bgc.py b/src/virtualship/instruments/ctd_bgc.py index 221cfa12..3568d0a8 100644 --- a/src/virtualship/instruments/ctd_bgc.py +++ b/src/virtualship/instruments/ctd_bgc.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import ClassVar @@ -6,9 +7,14 @@ from parcels import JITParticle, ParticleSet, Variable from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models.spacetime import Spacetime -from virtualship.utils import add_dummy_UV, register_instrument +from virtualship.utils import ( + add_dummy_UV, + build_particle_class_from_sensors, + register_instrument, +) # ===================================================== # SECTION: Dataclass @@ -26,24 +32,15 @@ class CTD_BGC: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_CTD_BGCParticle = JITParticle.add_variables( - [ - Variable("o2", dtype=np.float32, initial=np.nan), - Variable("chl", dtype=np.float32, initial=np.nan), - Variable("no3", dtype=np.float32, initial=np.nan), - Variable("po4", dtype=np.float32, initial=np.nan), - Variable("ph", dtype=np.float32, initial=np.nan), - Variable("phyc", dtype=np.float32, initial=np.nan), - Variable("nppv", dtype=np.float32, initial=np.nan), - Variable("raising", dtype=np.int8, initial=0.0), # bool. 0 is False, 1 is True. - Variable("max_depth", dtype=np.float32), - Variable("min_depth", dtype=np.float32), - Variable("winch_speed", dtype=np.float32), - ] -) +_CTD_BGC_NONSENSOR_VARIABLES = [ + Variable("raising", dtype=np.int8, initial=0.0), # bool. 0 is False, 1 is True. + Variable("max_depth", dtype=np.float32), + Variable("min_depth", dtype=np.float32), + Variable("winch_speed", dtype=np.float32), +] # ===================================================== # SECTION: Kernels @@ -101,17 +98,19 @@ def _ctd_bgc_cast(particle, fieldset, time): class CTD_BGCInstrument(Instrument): """CTD_BGC instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.OXYGEN: _sample_o2, + SensorType.CHLOROPHYLL: _sample_chlorophyll, + SensorType.NITRATE: _sample_nitrate, + SensorType.PHOSPHATE: _sample_phosphate, + SensorType.PH: _sample_ph, + SensorType.PHYTOPLANKTON: _sample_phytoplankton, + SensorType.PRIMARY_PRODUCTION: _sample_primary_production, + } + def __init__(self, expedition, from_data): """Initialize CTD_BGCInstrument.""" - variables = { - "o2": "o2", - "chl": "chl", - "no3": "no3", - "po4": "po4", - "ph": "ph", - "phyc": "phyc", - "nppv": "nppv", - } + variables = expedition.instruments_config.ctd_bgc_config.active_variables() limit_spec = { "spatial": True } # spatial limits; lat/lon constrained to waypoint locations + buffer @@ -145,11 +144,14 @@ def simulate(self, measurements, out_path) -> None: # add dummy U add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - fieldset_starttime = fieldset.o2.grid.time_origin.fulltime( - fieldset.o2.grid.time_full[0] + # use first active field for time reference + _time_ref_key = next(iter(self.variables)) + _time_ref_field = getattr(fieldset, _time_ref_key) + fieldset_starttime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[0] ) - fieldset_endtime = fieldset.o2.grid.time_origin.fulltime( - fieldset.o2.grid.time_full[-1] + fieldset_endtime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[-1] ) # deploy time for all ctds should be later than fieldset start time @@ -182,6 +184,12 @@ def simulate(self, measurements, out_path) -> None: f"BGC CTD max_depth or bathymetry shallower than maximum {-DT * WINCH_SPEED}" ) + # build dynamic particle class from the active sensors + ctd_bgc_config = self.expedition.instruments_config.ctd_bgc_config + _CTD_BGCParticle = build_particle_class_from_sensors( + ctd_bgc_config.sensors, _CTD_BGC_NONSENSOR_VARIABLES, JITParticle + ) + # define parcel particles ctd_bgc_particleset = ParticleSet( fieldset=fieldset, @@ -198,18 +206,16 @@ def simulate(self, measurements, out_path) -> None: # define output file for the simulation out_file = ctd_bgc_particleset.ParticleFile(name=out_path, outputdt=OUTPUT_DT) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in ctd_bgc_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + # execute simulation ctd_bgc_particleset.execute( - [ - _sample_o2, - _sample_chlorophyll, - _sample_nitrate, - _sample_phosphate, - _sample_ph, - _sample_phytoplankton, - _sample_primary_production, - _ctd_bgc_cast, - ], + [*sampling_kernels, _ctd_bgc_cast], endtime=fieldset_endtime, dt=DT, verbose_progress=self.verbose_progress, diff --git a/src/virtualship/instruments/drifter.py b/src/virtualship/instruments/drifter.py index a58c4bef..379334b3 100644 --- a/src/virtualship/instruments/drifter.py +++ b/src/virtualship/instruments/drifter.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import ClassVar @@ -6,9 +7,14 @@ from parcels import AdvectionRK4, JITParticle, ParticleSet, Variable from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models.spacetime import Spacetime -from virtualship.utils import _random_noise, register_instrument +from virtualship.utils import ( + _random_noise, + build_particle_class_from_sensors, + register_instrument, +) # ===================================================== # SECTION: Dataclass @@ -26,17 +32,14 @@ class Drifter: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_DrifterParticle = JITParticle.add_variables( - [ - Variable("temperature", dtype=np.float32, initial=np.nan), - Variable("has_lifetime", dtype=np.int8), # bool - Variable("age", dtype=np.float32, initial=0.0), - Variable("lifetime", dtype=np.float32), - ] -) +_DRIFTER_NONSENSOR_VARIABLES = [ + Variable("has_lifetime", dtype=np.int8), # bool + Variable("age", dtype=np.float32, initial=0.0), + Variable("lifetime", dtype=np.float32), +] # ===================================================== # SECTION: Kernels @@ -63,9 +66,20 @@ def _check_lifetime(particle, fieldset, time): class DrifterInstrument(Instrument): """Drifter instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.TEMPERATURE: _sample_temperature, + } + def __init__(self, expedition, from_data): """Initialize DrifterInstrument.""" - variables = {"U": "uo", "V": "vo", "T": "thetao"} + sensor_variables = ( + expedition.instruments_config.drifter_config.active_variables() + ) + variables = { + "U": "uo", + "V": "vo", + **sensor_variables, + } # advection variables (U and V) are always required for drifter simulation; sensor variables come from config spacetime_buffer_size = { "latlon": None, "time": expedition.instruments_config.drifter_config.lifetime.total_seconds() @@ -106,6 +120,12 @@ def simulate(self, measurements, out_path) -> None: fieldset = self.load_input_data() + # build dynamic particle class from the active sensors + drifter_config = self.expedition.instruments_config.drifter_config + _DrifterParticle = build_particle_class_from_sensors( + drifter_config.sensors, _DRIFTER_NONSENSOR_VARIABLES, JITParticle + ) + # define parcel particles lat_release = [ drifter.spacetime.location.lat + _random_noise() for drifter in measurements @@ -140,9 +160,16 @@ def simulate(self, measurements, out_path) -> None: # determine end time for simulation, from fieldset (which itself is controlled by drifter lifetimes) endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in drifter_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + # execute simulation drifter_particleset.execute( - [AdvectionRK4, _sample_temperature, _check_lifetime], + [AdvectionRK4, *sampling_kernels, _check_lifetime], endtime=endtime, dt=DT, output_file=out_file, diff --git a/src/virtualship/instruments/sensors.py b/src/virtualship/instruments/sensors.py new file mode 100644 index 00000000..2db148d8 --- /dev/null +++ b/src/virtualship/instruments/sensors.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from typing import Literal + +import numpy as np +from parcels import Variable + + +class SensorType(str, Enum): + """Sensors available. Different intstruments mix and match these sensors as needed.""" + + TEMPERATURE = "TEMPERATURE" + SALINITY = "SALINITY" + VELOCITY = "VELOCITY" + OXYGEN = "OXYGEN" + CHLOROPHYLL = "CHLOROPHYLL" + NITRATE = "NITRATE" + PHOSPHATE = "PHOSPHATE" + PH = "PH" + PHYTOPLANKTON = "PHYTOPLANKTON" + PRIMARY_PRODUCTION = "PRIMARY_PRODUCTION" + + +@dataclass(frozen=True) +class _Sensor: + type_: SensorType + fs_key: str # map to Parcels fieldset variables + copernicus_var: str # map to Copernicus Marine Service variable names + category: Literal[ + "phys", "bgc" + ] # physical vs. biogeochemical variable, used for product ID selection logic + particle_vars: tuple[Variable, ...] # parcels.Variable(s) produced by this sensor + + +@lru_cache(maxsize=1) # cache here so same dict is not rebuilt on every access +def SENSOR_REGISTRY() -> dict[SensorType, _Sensor]: + """Cached accessor for the sensor registry (lazily via _build_sensor_registry, avoids circular import errors).""" + return _build_sensor_registry() + + +# the copernicus_var field below is the bridge between this registry the Copernicus product-ID selection logic (PRODUCT_IDS, BGC_ANALYSIS_IDS, MONTHLY_BGC_REANALYSIS_IDS, etc.) +def _build_sensor_registry() -> dict[SensorType, _Sensor]: + return { + s.type_: s + for s in [ + _Sensor( + type_=SensorType.TEMPERATURE, + fs_key="T", + copernicus_var="thetao", + category="phys", + particle_vars=( + Variable("temperature", dtype=np.float32, initial=np.nan), + ), + ), + _Sensor( + type_=SensorType.SALINITY, + fs_key="S", + copernicus_var="so", + category="phys", + particle_vars=(Variable("salinity", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.VELOCITY, + fs_key="UV", + copernicus_var="uo", # uo is primary var here... active_variables() in ADCPConfig expands to both uo and vo + category="phys", + particle_vars=( + Variable("U", dtype=np.float32, initial=np.nan), + Variable("V", dtype=np.float32, initial=np.nan), + ), # two particle variables associated with one sensor + ), + _Sensor( + type_=SensorType.OXYGEN, + fs_key="o2", + copernicus_var="o2", + category="bgc", + particle_vars=(Variable("o2", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.CHLOROPHYLL, + fs_key="chl", + copernicus_var="chl", + category="bgc", + particle_vars=(Variable("chl", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.NITRATE, + fs_key="no3", + copernicus_var="no3", + category="bgc", + particle_vars=(Variable("no3", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.PHOSPHATE, + fs_key="po4", + copernicus_var="po4", + category="bgc", + particle_vars=(Variable("po4", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.PH, + fs_key="ph", + copernicus_var="ph", + category="bgc", + particle_vars=(Variable("ph", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.PHYTOPLANKTON, + fs_key="phyc", + copernicus_var="phyc", + category="bgc", + particle_vars=(Variable("phyc", dtype=np.float32, initial=np.nan),), + ), + _Sensor( + type_=SensorType.PRIMARY_PRODUCTION, + fs_key="nppv", + copernicus_var="nppv", + category="bgc", + particle_vars=(Variable("nppv", dtype=np.float32, initial=np.nan),), + ), + ] + } diff --git a/src/virtualship/instruments/ship_underwater_st.py b/src/virtualship/instruments/ship_underwater_st.py index 8b7ef96d..6a564cc0 100644 --- a/src/virtualship/instruments/ship_underwater_st.py +++ b/src/virtualship/instruments/ship_underwater_st.py @@ -1,12 +1,18 @@ +from collections.abc import Callable from dataclasses import dataclass from typing import ClassVar import numpy as np -from parcels import ParticleSet, ScipyParticle, Variable +from parcels import ParticleSet, ScipyParticle from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType -from virtualship.utils import add_dummy_UV, register_instrument +from virtualship.utils import ( + add_dummy_UV, + build_particle_class_from_sensors, + register_instrument, +) # ===================================================== # SECTION: Dataclass @@ -21,15 +27,12 @@ class Underwater_ST: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_ShipSTParticle = ScipyParticle.add_variables( - [ - Variable("S", dtype=np.float32, initial=np.nan), - Variable("T", dtype=np.float32, initial=np.nan), - ] -) +# Underwater ST has no non-sensor variables, only sensor variables. +_ST_NONSENSOR_VARIABLES: list = [] + # ===================================================== # SECTION: Kernels @@ -38,12 +41,12 @@ class Underwater_ST: # define function sampling Salinity def _sample_salinity(particle, fieldset, time): - particle.S = fieldset.S[time, particle.depth, particle.lat, particle.lon] + particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] # define function sampling Temperature def _sample_temperature(particle, fieldset, time): - particle.T = fieldset.T[time, particle.depth, particle.lat, particle.lon] + particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] # ===================================================== @@ -55,9 +58,16 @@ def _sample_temperature(particle, fieldset, time): class Underwater_STInstrument(Instrument): """Underwater_ST instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.TEMPERATURE: _sample_temperature, + SensorType.SALINITY: _sample_salinity, + } + def __init__(self, expedition, from_data): """Initialize Underwater_STInstrument.""" - variables = {"S": "so", "T": "thetao"} + variables = ( + expedition.instruments_config.ship_underwater_st_config.active_variables() + ) spacetime_buffer_size = { "latlon": 0.25, # [degrees] "time": 0.0, # [days] @@ -88,6 +98,12 @@ def simulate(self, measurements, out_path) -> None: # add dummy U add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used + # build dynamic particle class from the active sensors + st_config = self.expedition.instruments_config.ship_underwater_st_config + _ShipSTParticle = build_particle_class_from_sensors( + st_config.sensors, _ST_NONSENSOR_VARIABLES, ScipyParticle + ) + particleset = ParticleSet.from_list( fieldset=fieldset, pclass=_ShipSTParticle, @@ -99,6 +115,13 @@ def simulate(self, measurements, out_path) -> None: out_file = particleset.ParticleFile(name=out_path, outputdt=np.inf) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in st_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + for point in measurements: particleset.lon_nextloop[:] = point.location.lon particleset.lat_nextloop[:] = point.location.lat @@ -107,7 +130,7 @@ def simulate(self, measurements, out_path) -> None: ) particleset.execute( - [_sample_salinity, _sample_temperature], + sampling_kernels, dt=1, runtime=1, verbose_progress=self.verbose_progress, diff --git a/src/virtualship/instruments/types.py b/src/virtualship/instruments/types.py index 9ae221e9..489a331f 100644 --- a/src/virtualship/instruments/types.py +++ b/src/virtualship/instruments/types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum diff --git a/src/virtualship/instruments/xbt.py b/src/virtualship/instruments/xbt.py index 2412306f..051bf1fa 100644 --- a/src/virtualship/instruments/xbt.py +++ b/src/virtualship/instruments/xbt.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import ClassVar @@ -6,9 +7,14 @@ from parcels import JITParticle, ParticleSet, Variable from virtualship.instruments.base import Instrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models.spacetime import Spacetime -from virtualship.utils import add_dummy_UV, register_instrument +from virtualship.utils import ( + add_dummy_UV, + build_particle_class_from_sensors, + register_instrument, +) # ===================================================== # SECTION: Dataclass @@ -28,18 +34,16 @@ class XBT: # ===================================================== -# SECTION: Particle Class +# SECTION: non-sensor Particle Variables (non-sampling) # ===================================================== -_XBTParticle = JITParticle.add_variables( - [ - Variable("temperature", dtype=np.float32, initial=np.nan), - Variable("max_depth", dtype=np.float32), - Variable("min_depth", dtype=np.float32), - Variable("fall_speed", dtype=np.float32), - Variable("deceleration_coefficient", dtype=np.float32), - ] -) +_XBT_NONSENSOR_VARIABLES = [ + Variable("max_depth", dtype=np.float32), + Variable("min_depth", dtype=np.float32), + Variable("fall_speed", dtype=np.float32), + Variable("deceleration_coefficient", dtype=np.float32), +] + # ===================================================== # SECTION: Kernels @@ -77,9 +81,13 @@ def _xbt_cast(particle, fieldset, time): class XBTInstrument(Instrument): """XBT instrument class.""" + sensor_kernels: ClassVar[dict[SensorType, Callable]] = { + SensorType.TEMPERATURE: _sample_temperature, + } + def __init__(self, expedition, from_data): """Initialize XBTInstrument.""" - variables = {"T": "thetao"} + variables = expedition.instruments_config.xbt_config.active_variables() limit_spec = { "spatial": True } # spatial limits; lat/lon constrained to waypoint locations + buffer @@ -112,11 +120,14 @@ def simulate(self, measurements, out_path) -> None: # add dummy U add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - fieldset_starttime = fieldset.T.grid.time_origin.fulltime( - fieldset.T.grid.time_full[0] + # use first active field for time reference + _time_ref_key = next(iter(self.variables)) + _time_ref_field = getattr(fieldset, _time_ref_key) + fieldset_starttime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[0] ) - fieldset_endtime = fieldset.T.grid.time_origin.fulltime( - fieldset.T.grid.time_full[-1] + fieldset_endtime = _time_ref_field.grid.time_origin.fulltime( + _time_ref_field.grid.time_full[-1] ) # deploy time for all xbts should be later than fieldset start time @@ -152,6 +163,12 @@ def simulate(self, measurements, out_path) -> None: f"XBT max_depth or bathymetry shallower than minimum {-DT * fall_speed}. It is likely the XBT cannot be deployed in this area, which is too shallow." ) + # build dynamic particle class from the active sensors + xbt_config = self.expedition.instruments_config.xbt_config + _XBTParticle = build_particle_class_from_sensors( + xbt_config.sensors, _XBT_NONSENSOR_VARIABLES, JITParticle + ) + # define xbt particles xbt_particleset = ParticleSet( fieldset=fieldset, @@ -167,8 +184,15 @@ def simulate(self, measurements, out_path) -> None: out_file = xbt_particleset.ParticleFile(name=out_path, outputdt=OUTPUT_DT) + # build kernel list from active sensors only + sampling_kernels = [ + self.sensor_kernels[sc.sensor_type] + for sc in xbt_config.sensors + if sc.enabled and sc.sensor_type in self.sensor_kernels + ] + xbt_particleset.execute( - [_sample_temperature, _xbt_cast], + [*sampling_kernels, _xbt_cast], endtime=fieldset_endtime, dt=DT, verbose_progress=self.verbose_progress, diff --git a/src/virtualship/make_realistic/ctd_make_realistic.py b/src/virtualship/make_realistic/ctd_make_realistic.py index aeae2ab7..21c85cc8 100644 --- a/src/virtualship/make_realistic/ctd_make_realistic.py +++ b/src/virtualship/make_realistic/ctd_make_realistic.py @@ -18,7 +18,7 @@ def ctd_make_realistic( :param zarr_path: Input simulated data. :param out_dir: Output directory for CNV file. - :param prefix: Prefix for CNV files. Will be postfixed with '_{ctd_num}'. + :param prefix: Prefix for CNV files. Will be postnonsensor with '_{ctd_num}'. :returns: Paths to created file. """ original = xr.open_zarr(zarr_path) diff --git a/src/virtualship/models/expedition.py b/src/virtualship/models/expedition.py index 5d16ecf5..32855fc9 100644 --- a/src/virtualship/models/expedition.py +++ b/src/virtualship/models/expedition.py @@ -3,6 +3,7 @@ import itertools from datetime import datetime, timedelta from pathlib import Path +from typing import ClassVar import numpy as np import pydantic @@ -10,6 +11,7 @@ import yaml from virtualship.errors import InstrumentsConfigError, ScheduleError +from virtualship.instruments.sensors import SENSOR_REGISTRY, SensorType, _Sensor from virtualship.instruments.types import InstrumentType from virtualship.utils import ( _calc_sail_time, @@ -17,6 +19,7 @@ _get_bathy_data, _get_waypoint_latlons, _validate_numeric_to_timedelta, + get_supported_sensors, register_instrument_config, ) @@ -208,10 +211,58 @@ def serialize_instrument(self, instrument): return instrument.value if instrument else None +## + + +class _InstrumentConfigMixin(pydantic.BaseModel): + """Serialisation, validation and variable mapping inheritance across instrument configs.""" + + _instrument_type: ClassVar[InstrumentType] + _instrument_name: ClassVar[str] + + @pydantic.field_validator("sensors", mode="after", check_fields=False) + @classmethod + def _check_sensors(cls, value) -> list[SensorConfig]: + return SensorConfig.check_compatibility( + value, cls._instrument_type, cls._instrument_name + ) + + @pydantic.field_serializer("sensors", check_fields=False) + def _serialize_sensors(self, value: list[SensorConfig], _info): + return SensorConfig.serialize_list(value) + + def active_variables(self) -> dict[str, str]: + """FieldSet-key → Copernicus-variable mapping for enabled sensors.""" + return SensorConfig.build_variables(self.sensors) + + @pydantic.field_serializer("stationkeeping_time", "period", check_fields=False) + def _serialize_minutes(self, value: timedelta, _info) -> float: + return value.total_seconds() / 60.0 + + @pydantic.field_validator( + "stationkeeping_time", "period", mode="before", check_fields=False + ) + @classmethod + def _validate_minutes(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_to_timedelta(value, "minutes") + + @pydantic.field_serializer("lifetime", check_fields=False) + def _serialize_lifetime(self, value: timedelta, _info) -> float: + return value.total_seconds() / 86400.0 # [days] + + @pydantic.field_validator("lifetime", mode="before", check_fields=False) + @classmethod + def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_to_timedelta(value, "days") + + @register_instrument_config(InstrumentType.ARGO_FLOAT) -class ArgoFloatConfig(pydantic.BaseModel): +class ArgoFloatConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for argos floats.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.ARGO_FLOAT + _instrument_name: ClassVar[str] = "ArgoFloat" + min_depth_meter: float = pydantic.Field(le=0.0) max_depth_meter: float = pydantic.Field(le=0.0) drift_depth_meter: float = pydantic.Field(le=0.0) @@ -230,29 +281,26 @@ class ArgoFloatConfig(pydantic.BaseModel): gt=timedelta(), ) - @pydantic.field_serializer("lifetime") - def _serialize_lifetime(self, value: timedelta, _info): - return value.total_seconds() / 86400.0 # [days] - - @pydantic.field_validator("lifetime", mode="before") - def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "days") - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 - - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + description=( + "Sensors fitted to the Argo float. Supported: TEMPERATURE, SALINITY. " + ), + ) model_config = pydantic.ConfigDict(populate_by_name=True) @register_instrument_config(InstrumentType.ADCP) -class ADCPConfig(pydantic.BaseModel): +class ADCPConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for ADCP instrument.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.ADCP + _instrument_name: ClassVar[str] = "ADCP" + max_depth_meter: float = pydantic.Field(le=0.0) num_bins: int = pydantic.Field(gt=0.0) period: timedelta = pydantic.Field( @@ -261,21 +309,37 @@ class ADCPConfig(pydantic.BaseModel): gt=timedelta(), ) + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [SensorConfig(sensor_type=SensorType.VELOCITY)], + description=( + "Sensors fitted to the ADCP. " + "Supported: VELOCITY (samples both U and V components in one go)." + ), + ) + model_config = pydantic.ConfigDict(populate_by_name=True) - @pydantic.field_serializer("period") - def _serialize_period(self, value: timedelta, _info): - return value.total_seconds() / 60.0 + def active_variables(self) -> dict[str, str]: + """ + FieldSet-key → Copernicus-variable mapping for enabled sensors. - @pydantic.field_validator("period", mode="before") - def _validate_period(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + VELOCITY is a special case: one sensor provides two FieldSet variables (U and V). + """ + variables = {} + for sc in self.sensors: + if sc.enabled and sc.sensor_type == SensorType.VELOCITY: + variables["U"] = "uo" + variables["V"] = "vo" + return variables @register_instrument_config(InstrumentType.CTD) -class CTDConfig(pydantic.BaseModel): +class CTDConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for CTD instrument.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.CTD + _instrument_name: ClassVar[str] = "CTD" + stationkeeping_time: timedelta = pydantic.Field( serialization_alias="stationkeeping_time_minutes", validation_alias="stationkeeping_time_minutes", @@ -284,21 +348,24 @@ class CTDConfig(pydantic.BaseModel): min_depth_meter: float = pydantic.Field(le=0.0) max_depth_meter: float = pydantic.Field(le=0.0) - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + description=("Sensors fitted to the CTD. Supported: TEMPERATURE, SALINITY. "), + ) - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + model_config = pydantic.ConfigDict(populate_by_name=True) @register_instrument_config(InstrumentType.CTD_BGC) -class CTD_BGCConfig(pydantic.BaseModel): +class CTD_BGCConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for CTD_BGC instrument.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.CTD_BGC + _instrument_name: ClassVar[str] = "CTD_BGC" + stationkeeping_time: timedelta = pydantic.Field( serialization_alias="stationkeeping_time_minutes", validation_alias="stationkeeping_time_minutes", @@ -307,42 +374,58 @@ class CTD_BGCConfig(pydantic.BaseModel): min_depth_meter: float = pydantic.Field(le=0.0) max_depth_meter: float = pydantic.Field(le=0.0) - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [ + SensorConfig(sensor_type=SensorType.OXYGEN), + SensorConfig(sensor_type=SensorType.CHLOROPHYLL), + SensorConfig(sensor_type=SensorType.NITRATE), + SensorConfig(sensor_type=SensorType.PHOSPHATE), + SensorConfig(sensor_type=SensorType.PH), + SensorConfig(sensor_type=SensorType.PHYTOPLANKTON), + SensorConfig(sensor_type=SensorType.PRIMARY_PRODUCTION), + ], + description=( + "Sensors fitted to the BGC CTD. " + "Supported: OXYGEN, CHLOROPHYLL, NITRATE, PHOSPHATE, PH, PHYTOPLANKTON, PRIMARY_PRODUCTION. " + ), + ) - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + model_config = pydantic.ConfigDict(populate_by_name=True) @register_instrument_config(InstrumentType.UNDERWATER_ST) -class ShipUnderwaterSTConfig(pydantic.BaseModel): +class ShipUnderwaterSTConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for underwater ST.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.UNDERWATER_ST + _instrument_name: ClassVar[str] = "Underwater ST" + period: timedelta = pydantic.Field( serialization_alias="period_minutes", validation_alias="period_minutes", gt=timedelta(), ) - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("period") - def _serialize_period(self, value: timedelta, _info): - return value.total_seconds() / 60.0 + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + description=( + "Sensors fitted to the underway ST. Supported: TEMPERATURE, SALINITY. " + ), + ) - @pydantic.field_validator("period", mode="before") - def _validate_period(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + model_config = pydantic.ConfigDict(populate_by_name=True) @register_instrument_config(InstrumentType.DRIFTER) -class DrifterConfig(pydantic.BaseModel): +class DrifterConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for drifters.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.DRIFTER + _instrument_name: ClassVar[str] = "Drifter" + depth_meter: float = pydantic.Field(le=0.0) lifetime: timedelta = pydantic.Field( serialization_alias="lifetime_days", @@ -355,34 +438,31 @@ class DrifterConfig(pydantic.BaseModel): gt=timedelta(), ) - model_config = pydantic.ConfigDict(populate_by_name=True) - - @pydantic.field_serializer("lifetime") - def _serialize_lifetime(self, value: timedelta, _info): - return value.total_seconds() / 86400.0 # [days] - - @pydantic.field_validator("lifetime", mode="before") - def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "days") - - @pydantic.field_serializer("stationkeeping_time") - def _serialize_stationkeeping_time(self, value: timedelta, _info): - return value.total_seconds() / 60.0 + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [SensorConfig(sensor_type=SensorType.TEMPERATURE)], + description=("Sensors fitted to the drifter. Supported: TEMPERATURE. "), + ) - @pydantic.field_validator("stationkeeping_time", mode="before") - def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: - return _validate_numeric_to_timedelta(value, "minutes") + model_config = pydantic.ConfigDict(populate_by_name=True) @register_instrument_config(InstrumentType.XBT) -class XBTConfig(pydantic.BaseModel): +class XBTConfig(_InstrumentConfigMixin, pydantic.BaseModel): """Configuration for xbt instrument.""" + _instrument_type: ClassVar[InstrumentType] = InstrumentType.XBT + _instrument_name: ClassVar[str] = "XBT" + min_depth_meter: float = pydantic.Field(le=0.0) max_depth_meter: float = pydantic.Field(le=0.0) fall_speed_meter_per_second: float = pydantic.Field(gt=0.0) deceleration_coefficient: float = pydantic.Field(gt=0.0) + sensors: list[SensorConfig] = pydantic.Field( + default_factory=lambda: [SensorConfig(sensor_type=SensorType.TEMPERATURE)], + description=("Sensors fitted to the XBT. Supported: TEMPERATURE. "), + ) + class InstrumentsConfig(pydantic.BaseModel): """Configuration of instruments.""" @@ -473,3 +553,65 @@ def verify(self, expedition: Expedition) -> None: raise InstrumentsConfigError( f"Expedition includes instrument '{inst_type.value}', but instruments_config does not provide configuration for it." ) + + +class SensorConfig(pydantic.BaseModel): + """Configuration for a single sensor fitted to an instrument.""" + + sensor_type: SensorType + enabled: bool = True + + # validator/serialiser for allowing the compact, single-string notation for sensors in YAML (e.g. "TEMPERATURE" instead of sensor_type: TEMPERATURE in each instance + @pydantic.model_validator(mode="before") + @classmethod + def _from_string(cls, value): + """Allow a bare sensor-type string (e.g. "TEMPERATURE") as shorthand for {"sensor_type": "TEMPERATURE"}.""" + if isinstance(value, str): + return {"sensor_type": value} + return value + + @pydantic.field_validator("sensor_type", mode="before") + @classmethod + def _take_sensor_type(cls, value: str | SensorType) -> SensorType: + """Accept a sensor-type string or SensorType class.""" + if isinstance(value, SensorType): + return value + return SensorType(value) + + @property + def meta(self) -> _Sensor: + """Metadata for this sensor.""" + return SENSOR_REGISTRY()[self.sensor_type] + + @staticmethod + def serialize_list(sensors: list[SensorConfig]) -> list[str]: + """Serialise enabled sensors to a list of sensor-type strings.""" + return [sc.sensor_type.value for sc in sensors if sc.enabled] + + @staticmethod + def check_compatibility( + sensors: list[SensorConfig], + instrument_type: InstrumentType, + instrument_name: str, + ) -> list[SensorConfig]: + """Error if any sensor is unsupported for the given instrument, or none are enabled.""" + supported = get_supported_sensors(instrument_type) + unsupported = {sc.sensor_type for sc in sensors} - supported + if unsupported: + names = ", ".join(sorted(s.value for s in unsupported)) + valid = ", ".join(sorted(s.value for s in supported)) + raise ValueError( + f"{instrument_name} does not support sensor(s): {names}. " + f"Supported sensors: {valid}." + ) + if not any(sc.enabled for sc in sensors): + raise ValueError( + f"{instrument_name} has no enabled sensors. " + f"At least one sensor must be enabled." + ) + return sensors + + @staticmethod + def build_variables(sensors: list[SensorConfig]) -> dict[str, str]: + """Build a FieldSet-key → Copernicus-variable mapping for enabled sensors.""" + return {sc.meta.fs_key: sc.meta.copernicus_var for sc in sensors if sc.enabled} diff --git a/src/virtualship/static/expedition.yaml b/src/virtualship/static/expedition.yaml index c6db10bd..8ab72f8a 100644 --- a/src/virtualship/static/expedition.yaml +++ b/src/virtualship/static/expedition.yaml @@ -1,5 +1,7 @@ # see https://virtualship.readthedocs.io/en/latest/user-guide/tutorials/working_with_expedition_yaml.html for more details on how to edit this file # +# TODO: add a link to docs where lists what sensors are supported for each instrument +# schedule: waypoints: - instrument: @@ -37,6 +39,8 @@ instruments_config: num_bins: 40 max_depth_meter: -1000.0 period_minutes: 5.0 + sensors: + - VELOCITY argo_float_config: cycle_days: 10.0 drift_days: 9.0 @@ -46,23 +50,45 @@ instruments_config: vertical_speed_meter_per_second: -0.1 stationkeeping_time_minutes: 20.0 lifetime_days: 63.0 + sensors: + - TEMPERATURE + - SALINITY ctd_config: max_depth_meter: -2000.0 min_depth_meter: -11.0 stationkeeping_time_minutes: 50.0 + sensors: + - TEMPERATURE + - SALINITY ctd_bgc_config: max_depth_meter: -2000.0 min_depth_meter: -11.0 stationkeeping_time_minutes: 50.0 + sensors: + - OXYGEN + - CHLOROPHYLL + - NITRATE + - PHOSPHATE + - PH + - PHYTOPLANKTON + - PRIMARY_PRODUCTION drifter_config: depth_meter: -1.0 lifetime_days: 42.0 stationkeeping_time_minutes: 20.0 + sensors: + - TEMPERATURE xbt_config: max_depth_meter: -285.0 min_depth_meter: -2.0 fall_speed_meter_per_second: 6.7 deceleration_coefficient: 0.00225 - ship_underwater_st_config: null + sensors: + - TEMPERATURE + ship_underwater_st_config: + period_minutes: 5.0 + sensors: + - TEMPERATURE + - SALINITY ship_config: ship_speed_knots: 10.0 diff --git a/src/virtualship/utils.py b/src/virtualship/utils.py index 204e9e8f..37bb44c4 100644 --- a/src/virtualship/utils.py +++ b/src/virtualship/utils.py @@ -15,7 +15,7 @@ import numpy as np import pyproj import xarray as xr -from parcels import FieldSet +from parcels import FieldSet, Variable from virtualship.errors import CopernicusCatalogueError @@ -25,6 +25,7 @@ ) from virtualship.models import Expedition, InstrumentsConfig, Location from virtualship.models.checkpoint import Checkpoint + from virtualship.models.expedition import SensorConfig import pandas as pd import yaml @@ -52,6 +53,7 @@ EXPEDITION_ORIGINAL = "expedition_original.yaml" EXPEDITION_LATEST = "expedition_latest.yaml" + # ===================================================== # SECTION: Copernicus Marine Service constants # ===================================================== @@ -106,10 +108,17 @@ # main instrument (simulation) class registry and registration utilities INSTRUMENT_CLASS_MAP = {} +# maps InstrumentType to frozenset[SensorType], to set which sensors each instrument suppors, auto-populated by @register_instrument +SUPPORTED_SENSORS_MAP: dict = {} + def register_instrument(instrument_type): def decorator(cls): INSTRUMENT_CLASS_MAP[instrument_type] = cls + if hasattr(cls, "sensor_kernels"): # derive supported kernels from class attr + SUPPORTED_SENSORS_MAP[instrument_type] = frozenset( + cls.sensor_kernels.keys() + ) return cls return decorator @@ -119,6 +128,17 @@ def get_instrument_class(instrument_type): return INSTRUMENT_CLASS_MAP.get(instrument_type) +def get_supported_sensors(instrument_type): + """Return the frozenset of SensorTypes supported by the given InstrumentType.""" + supported = SUPPORTED_SENSORS_MAP.get(instrument_type) + if supported is None: + raise KeyError( + f"No supported sensors registered for {instrument_type!r}. " + f"Does the instrument class define a `sensor_kernels` attribute?" + ) + return supported + + # map for instrument type to instrument config (pydantic basemodel) names INSTRUMENT_CONFIG_MAP = {} @@ -617,13 +637,13 @@ def _calc_wp_stationkeeping_time( instrument_config_map: dict = INSTRUMENT_CONFIG_MAP, ) -> timedelta: """For a given waypoint (and the instruments present at this waypoint), calculate how much time is required to carry out all instrument deployments.""" - from virtualship.instruments.types import InstrumentType # avoid circular imports - # to empty list if wp instruments set to 'null' if not wp_instrument_types: wp_instrument_types = [] # TODO: this can be removed if/when CTD and CTD_BGC are merged to a single instrument + from virtualship.instruments.types import InstrumentType + both_ctd_and_bgc = ( InstrumentType.CTD in wp_instrument_types and InstrumentType.CTD_BGC in wp_instrument_types @@ -639,7 +659,7 @@ def _calc_wp_stationkeeping_time( for iconfig in valid_instrument_configs: for itype in wp_instrument_types: if ( - instrument_config_map[itype] == iconfig.__class__.__name__ + instrument_config_map.get(itype) == iconfig.__class__.__name__ and ( iconfig not in wp_instrument_configs ) # avoid duplicates (would happen when multiple drifter deployments at same waypoint) @@ -649,13 +669,10 @@ def _calc_wp_stationkeeping_time( # get wp total stationkeeping time cumulative_stationkeeping_time = timedelta() for iconfig in wp_instrument_configs: - if ( - both_ctd_and_bgc - and iconfig.__class__.__name__ - == INSTRUMENT_CONFIG_MAP[InstrumentType.CTD_BGC] + if both_ctd_and_bgc and iconfig.__class__.__name__ == instrument_config_map.get( + InstrumentType.CTD_BGC ): - continue # only need to add time cost once if both CTD and CTD_BGC are being taken; in reality they would be done on the same instrument - + continue # only count stationkeeping once when both CTD and CTD_BGC are present; in reality they would be done on the same instrument if hasattr(iconfig, "stationkeeping_time"): cumulative_stationkeeping_time += iconfig.stationkeeping_time @@ -669,6 +686,19 @@ def _make_hash(s: str, length: int) -> str: return hashlib.shake_128(s.encode("utf-8")).hexdigest(half_length) +def build_particle_class_from_sensors( + sensors: list[SensorConfig], + nonsensor_variables: list[Variable], + particle_class: type, # generic type annotation needed for v3 particle class behaviour # TODO: Update with Parcels v4 +) -> type: + """Build a Particle class (JITParticle or ScipyParticle) from nonsensor variables and active sensors.""" + sensor_variables = [ + variable for sc in sensors if sc.enabled for variable in sc.meta.particle_vars + ] + + return particle_class.add_variables(nonsensor_variables + sensor_variables) + + # ===================================================== # SECTION: misc. # ===================================================== diff --git a/tests/instruments/test_adcp.py b/tests/instruments/test_adcp.py index 0a88b206..48e9e17c 100644 --- a/tests/instruments/test_adcp.py +++ b/tests/instruments/test_adcp.py @@ -3,20 +3,21 @@ import datetime import numpy as np +import pydantic +import pytest import xarray as xr from parcels import FieldSet from virtualship.instruments.adcp import ADCPInstrument +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models import Location, Spacetime, Waypoint +from virtualship.models.expedition import ADCPConfig, InstrumentsConfig, SensorConfig def test_simulate_adcp(tmpdir) -> None: - # maximum depth the ADCP can measure - MAX_DEPTH = -1000 # -1000 - # minimum depth the ADCP can measure - MIN_DEPTH = -5 # -5 - # How many samples to take in the complete range between max_depth and min_depth. + MAX_DEPTH = -1000 + MIN_DEPTH = -5 NUM_BINS = 40 # arbitrary time offset for the dummy fieldset @@ -90,10 +91,14 @@ class schedule: ), ] - class instruments_config: - class adcp_config: - max_depth_meter = MAX_DEPTH - num_bins = NUM_BINS + instruments_config = InstrumentsConfig( + adcp_config=ADCPConfig( + max_depth_meter=MAX_DEPTH, + num_bins=NUM_BINS, + period_minutes=5.0, + sensors=[SensorConfig(sensor_type=SensorType.VELOCITY)], + ) + ) expedition = DummyExpedition() from_data = None @@ -131,3 +136,50 @@ class adcp_config: assert np.isclose(obs_value, exp_value), ( f"Observation incorrect {vert_loc=} {i=} {var=} {obs_value=} {exp_value=}." ) + + +def test_adcp_sensor_config_active_variables() -> None: + """active_variables() returns both U and V when VELOCITY is enabled.""" + config_with = ADCPConfig( + max_depth_meter=-1000.0, + num_bins=40, + period_minutes=5.0, + sensors=[SensorConfig(sensor_type=SensorType.VELOCITY)], + ) + assert config_with.active_variables() == {"U": "uo", "V": "vo"} + + +def test_adcp_sensor_config_yaml() -> None: + """ADCPConfig sensors survive YAML serialisation.""" + config = ADCPConfig( + max_depth_meter=-1000.0, + num_bins=40, + period_minutes=5.0, + sensors=[SensorConfig(sensor_type=SensorType.VELOCITY)], + ) + dumped = config.model_dump(by_alias=True) + loaded = ADCPConfig.model_validate(dumped) + assert len(loaded.sensors) == 1 + assert loaded.sensors[0].sensor_type == SensorType.VELOCITY + assert loaded.sensors[0].enabled is True + + +def test_adcp_config_default_sensors(): + """ADCPConfig defaults to VELOCITY.""" + config = ADCPConfig( + max_depth_meter=-500.0, + num_bins=30, + period_minutes=30.0, + ) + assert config.sensors[0].sensor_type is SensorType.VELOCITY + + +def test_adcp_config_unsupported_sensor_rejected(): + """Unsupported sensor on ADCP is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + ADCPConfig( + max_depth_meter=-500.0, + num_bins=30, + period_minutes=30.0, + sensors=[SensorConfig(sensor_type=SensorType.TEMPERATURE)], + ) diff --git a/tests/instruments/test_argo_float.py b/tests/instruments/test_argo_float.py index 66331d64..403e0457 100644 --- a/tests/instruments/test_argo_float.py +++ b/tests/instruments/test_argo_float.py @@ -3,12 +3,20 @@ from datetime import datetime, timedelta import numpy as np +import pydantic +import pytest import xarray as xr from parcels import FieldSet from virtualship.instruments.argo_float import ArgoFloat, ArgoFloatInstrument +from virtualship.instruments.sensors import SensorType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + ArgoFloatConfig, + InstrumentsConfig, + SensorConfig, + Waypoint, +) def test_simulate_argo_floats(tmpdir) -> None: @@ -76,9 +84,22 @@ class schedule: ), ] - class instruments_config: - class argo_float_config: - lifetime = LIFETIME + instruments_config = InstrumentsConfig( + argo_float_config=ArgoFloatConfig( + min_depth_meter=0.0, + max_depth_meter=MAX_DEPTH, + drift_depth_meter=DRIFT_DEPTH, + vertical_speed_meter_per_second=VERTICAL_SPEED, + cycle_days=CYCLE_DAYS, + drift_days=DRIFT_DAYS, + lifetime=LIFETIME, + stationkeeping_time_minutes=10, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + ) + ) expedition = DummyExpedition() from_data = None @@ -96,3 +117,115 @@ class argo_float_config: assert len(results.trajectory) == len(argo_floats) for var in ["lon", "lat", "z", "temperature", "salinity"]: assert var in results, f"Results don't contain {var}" + + +def test_argo_float_disabled_sensor(tmpdir) -> None: + """Variables for disabled sensors must not appear in the zarr output.""" + base_time = datetime.strptime("1950-01-01", "%Y-%m-%d") + + DRIFT_DEPTH = -1000 + MAX_DEPTH = -2000 + VERTICAL_SPEED = -0.10 + CYCLE_DAYS = 10 + DRIFT_DAYS = 9 + LIFETIME = timedelta(days=1) + + v = np.full((2, 2, 2), 1.0) + u = np.full((2, 2, 2), 1.0) + t = np.full((2, 2, 2), 1.0) + bathy = np.full((2, 2), -5000.0) + + # only temperature fieldset, no salinity + fieldset = FieldSet.from_data( + {"V": v, "U": u, "T": t}, + { + "lon": np.array([0.0, 10.0]), + "lat": np.array([0.0, 10.0]), + "time": [ + np.datetime64(base_time + timedelta(seconds=0)), + np.datetime64(base_time + timedelta(hours=4)), + ], + }, + ) + fieldset.add_field( + FieldSet.from_data( + {"bathymetry": bathy}, + {"lon": np.array([0.0, 10.0]), "lat": np.array([0.0, 10.0])}, + ).bathymetry + ) + + argo_floats = [ + ArgoFloat( + spacetime=Spacetime(location=Location(latitude=0, longitude=0), time=0), + min_depth=0.0, + max_depth=MAX_DEPTH, + drift_depth=DRIFT_DEPTH, + vertical_speed=VERTICAL_SPEED, + cycle_days=CYCLE_DAYS, + drift_days=DRIFT_DAYS, + ) + ] + + class DummyExpedition: + class schedule: + waypoints = [Waypoint(location=Location(1, 2), time=base_time)] + + instruments_config = InstrumentsConfig( + argo_float_config=ArgoFloatConfig( + min_depth_meter=0.0, + max_depth_meter=MAX_DEPTH, + drift_depth_meter=DRIFT_DEPTH, + vertical_speed_meter_per_second=VERTICAL_SPEED, + cycle_days=CYCLE_DAYS, + drift_days=DRIFT_DAYS, + lifetime=LIFETIME, + stationkeeping_time_minutes=10, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY omitted = disabled + ) + ) + + expedition = DummyExpedition() + argo_instrument = ArgoFloatInstrument(expedition, None) + out_path = tmpdir.join("out_disabled.zarr") + argo_instrument.load_input_data = lambda: fieldset + argo_instrument.simulate(argo_floats, out_path) + + results = xr.open_zarr(out_path) + assert "temperature" in results, "Enabled sensor variable must be present" + assert "salinity" not in results, ( + "Disabled sensor variable must be absent from output" + ) + + +def test_argo_config_default_sensors(): + """ArgoFloatConfig defaults to TEMPERATURE + SALINITY.""" + config = ArgoFloatConfig( + min_depth_meter=0.0, + max_depth_meter=-2000, + drift_depth_meter=-1000, + vertical_speed_meter_per_second=-0.10, + cycle_days=10, + drift_days=9, + lifetime=timedelta(days=30), + stationkeeping_time_minutes=10, + ) + types = {sc.sensor_type for sc in config.sensors} + assert types == {SensorType.TEMPERATURE, SensorType.SALINITY} + + +def test_argo_config_unsupported_sensor_rejected(): + """Unsupported sensor on ArgoFloat is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + ArgoFloatConfig( + min_depth_meter=0.0, + max_depth_meter=-2000, + drift_depth_meter=-1000, + vertical_speed_meter_per_second=-0.10, + cycle_days=10, + drift_days=9, + lifetime=timedelta(days=30), + stationkeeping_time_minutes=10, + sensors=[SensorConfig(sensor_type=SensorType.OXYGEN)], + ) diff --git a/tests/instruments/test_base.py b/tests/instruments/test_base.py index fea43f02..bbcfea44 100644 --- a/tests/instruments/test_base.py +++ b/tests/instruments/test_base.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +import pytest + from virtualship.instruments.base import Instrument from virtualship.instruments.types import InstrumentType from virtualship.utils import get_instrument_class @@ -14,6 +16,8 @@ def test_all_instruments_have_instrument_class(): class DummyInstrument(Instrument): """Minimal concrete Instrument for testing.""" + sensor_kernels = {} # noqa + def simulate(self, data_dir, measurements, out_path): """Dummy simulate implementation for test.""" self.simulate_called = True @@ -147,3 +151,12 @@ def test_load_input_data_error(monkeypatch): dummy.load_input_data() except virtualship.errors.CopernicusCatalogueError as e: assert "Failed to load input data" in str(e) + + +def test_instrument_subclass_without_sensor_kernels_error(): + """Defining a concrete Instrument subclass without sensor_kernels raises TypeError.""" + with pytest.raises(TypeError, match="sensor_kernels"): + + class ErrorInstrument(Instrument): + def simulate(self, data_dir, measurements, out_path): + pass diff --git a/tests/instruments/test_ctd.py b/tests/instruments/test_ctd.py index 954d0b78..4fd6e28d 100644 --- a/tests/instruments/test_ctd.py +++ b/tests/instruments/test_ctd.py @@ -7,12 +7,21 @@ import datetime import numpy as np +import pydantic +import pytest import xarray as xr from parcels import Field, FieldSet from virtualship.instruments.ctd import CTD, CTDInstrument +from virtualship.instruments.sensors import SensorType +from virtualship.instruments.types import InstrumentType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + CTDConfig, + InstrumentsConfig, + SensorConfig, + Waypoint, +) def test_simulate_ctds(tmpdir) -> None: @@ -113,6 +122,18 @@ class schedule: ), ] + instruments_config = InstrumentsConfig( + ctd_config=CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + ) + ) + expedition = DummyExpedition() from_data = None @@ -146,3 +167,135 @@ class schedule: assert np.isclose(obs_value, exp_value), ( f"Observation incorrect {ctd_i=} {loc=} {var=} {obs_value=} {exp_value=}." ) + + +def test_ctd_sensor_config_active_variables() -> None: + """active_variables() only returns variables for enabled sensors.""" + config_both = CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + ) + assert config_both.active_variables() == {"T": "thetao", "S": "so"} + + config_temp_only = CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY absent = disabled + ) + assert config_temp_only.active_variables() == {"T": "thetao"} + + +def test_ctd_sensor_config_yaml() -> None: + """CTDConfig sensors survive YAML serialisation.""" + config = CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY omitted = disabled + ) + dumped = config.model_dump(by_alias=True) + loaded = CTDConfig.model_validate(dumped) + + assert len(loaded.sensors) == 1 + assert loaded.sensors[0].sensor_type == SensorType.TEMPERATURE + assert loaded.sensors[0].enabled is True + + +def test_ctd_disabled_sensor_absent(tmpdir) -> None: + """Variables for disabled sensors must not appear in the zarr output.""" + base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") + + ctds = [ + CTD( + spacetime=Spacetime( + location=Location(latitude=0, longitude=0), + time=base_time, + ), + min_depth=0, + max_depth=-20, + ), + ] + + # Only temperature field, no salinty + t = np.full((2, 2, 2), 5.0) + fieldset = FieldSet.from_data( + {"T": t}, + { + "lon": np.array([0.0, 1.0]), + "lat": np.array([0.0, 1.0]), + "time": [ + np.datetime64(base_time + datetime.timedelta(seconds=0)), + np.datetime64(base_time + datetime.timedelta(hours=4)), + ], + }, + ) + fieldset.add_field(Field("bathymetry", [-1000], lon=0, lat=0)) + + class DummyExpedition: + class schedule: + waypoints = [Waypoint(location=Location(1, 2), time=base_time)] + + instruments_config = InstrumentsConfig( + ctd_config=CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY omitted = disabled + ) + ) + + expedition = DummyExpedition() + ctd_instrument = CTDInstrument(expedition, None) + out_path = tmpdir.join("out_disabled.zarr") + ctd_instrument.load_input_data = lambda: fieldset + ctd_instrument.simulate(ctds, out_path) + + results = xr.open_zarr(out_path) + assert "temperature" in results, "Enabled sensor variable must be present" + assert "salinity" not in results, ( + "Disabled sensor variable must be absent from output" + ) + + +def test_ctd_supported_sensors(): + """CTD supports TEMPERATURE and SALINITY.""" + from virtualship.utils import get_supported_sensors + + assert get_supported_sensors(InstrumentType.CTD) == frozenset( + {SensorType.TEMPERATURE, SensorType.SALINITY} + ) + + +def test_ctd_config_default_sensors(): + """CTDConfig defaults to TEMPERATURE + SALINITY.""" + config = CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + ) + types = {sc.sensor_type for sc in config.sensors} + assert types == {SensorType.TEMPERATURE, SensorType.SALINITY} + + +# TODO: may need to be removed if add ADCP to CTDs in future PR... +def test_ctd_config_unsupported_sensor_rejected(): + """Unsupported sensor on CTD is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + CTDConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[SensorConfig(sensor_type=SensorType.VELOCITY)], + ) diff --git a/tests/instruments/test_ctd_bgc.py b/tests/instruments/test_ctd_bgc.py index 39fa6c1f..ad485617 100644 --- a/tests/instruments/test_ctd_bgc.py +++ b/tests/instruments/test_ctd_bgc.py @@ -7,12 +7,20 @@ import datetime import numpy as np +import pydantic +import pytest import xarray as xr from parcels import Field, FieldSet from virtualship.instruments.ctd_bgc import CTD_BGC, CTD_BGCInstrument +from virtualship.instruments.sensors import SensorType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + CTD_BGCConfig, + InstrumentsConfig, + SensorConfig, + Waypoint, +) def test_simulate_ctd_bgcs(tmpdir) -> None: @@ -174,6 +182,23 @@ class schedule: ), ] + instruments_config = InstrumentsConfig( + ctd_bgc_config=CTD_BGCConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.OXYGEN), + SensorConfig(sensor_type=SensorType.CHLOROPHYLL), + SensorConfig(sensor_type=SensorType.NITRATE), + SensorConfig(sensor_type=SensorType.PHOSPHATE), + SensorConfig(sensor_type=SensorType.PH), + SensorConfig(sensor_type=SensorType.PHYTOPLANKTON), + SensorConfig(sensor_type=SensorType.PRIMARY_PRODUCTION), + ], + ) + ) + expedition = DummyExpedition() from_data = None @@ -216,3 +241,57 @@ class schedule: assert np.isclose(obs_value, exp_value), ( f"Observation incorrect {ctd_i=} {loc=} {var=} {obs_value=} {exp_value=}." ) + + +def test_ctd_bgc_sensor_config_active_variables() -> None: + """active_variables() only returns variables for enabled sensors.""" + config_all = CTD_BGCConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.OXYGEN), + SensorConfig(sensor_type=SensorType.CHLOROPHYLL), + SensorConfig(sensor_type=SensorType.NITRATE), + SensorConfig(sensor_type=SensorType.PHOSPHATE), + SensorConfig(sensor_type=SensorType.PH), + SensorConfig(sensor_type=SensorType.PHYTOPLANKTON), + SensorConfig(sensor_type=SensorType.PRIMARY_PRODUCTION), + ], + ) + assert config_all.active_variables() == { + "o2": "o2", + "chl": "chl", + "no3": "no3", + "po4": "po4", + "ph": "ph", + "phyc": "phyc", + "nppv": "nppv", + } + + config_o2_only = CTD_BGCConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.OXYGEN) + ], # all others omitted = disabled + ) + assert config_o2_only.active_variables() == {"o2": "o2"} + + +def test_ctd_bgc_sensor_config_yaml() -> None: + """CTD_BGCConfig sensors survive YAML serialisation.""" + config = CTD_BGCConfig( + stationkeeping_time_minutes=50, + min_depth_meter=-11.0, + max_depth_meter=-2000.0, + sensors=[ + SensorConfig(sensor_type=SensorType.OXYGEN) + ], # CHLOROPHYLL and others omitted = disabled + ) + dumped = config.model_dump(by_alias=True) + loaded = CTD_BGCConfig.model_validate(dumped) + assert len(loaded.sensors) == 1 + assert loaded.sensors[0].sensor_type == SensorType.OXYGEN + assert loaded.sensors[0].enabled is True diff --git a/tests/instruments/test_drifter.py b/tests/instruments/test_drifter.py index 0dc72597..56f3257e 100644 --- a/tests/instruments/test_drifter.py +++ b/tests/instruments/test_drifter.py @@ -4,12 +4,20 @@ from typing import ClassVar import numpy as np +import pydantic +import pytest import xarray as xr from parcels import FieldSet from virtualship.instruments.drifter import Drifter, DrifterInstrument +from virtualship.instruments.sensors import SensorType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + DrifterConfig, + InstrumentsConfig, + SensorConfig, + Waypoint, +) BASE_TIME = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d") LIFETIME = datetime.timedelta(days=1) @@ -31,10 +39,14 @@ class schedule: ), ] - class instruments_config: - class drifter_config: - lifetime = LIFETIME - depth_meter = DEPLOY_DEPTH + instruments_config = InstrumentsConfig( + drifter_config=DrifterConfig( + lifetime=LIFETIME, + depth_meter=DEPLOY_DEPTH, + stationkeeping_time_minutes=10, + sensors=[SensorConfig(sensor_type=SensorType.TEMPERATURE)], + ) + ) return DummyExpedition() @@ -188,3 +200,35 @@ def test_drifter_depths(tmpdir) -> None: assert drifter_surface.temperature[0] != drifter_depth.temperature[0], ( "Surface and deeper drifter should have different temperature measurements" ) + + +def test_drifter_disabled_sensor_absent_from_output(tmpdir) -> None: + """A DrifterConfig with no enabled sensors should be rejected at construction time.""" + with pytest.raises(pydantic.ValidationError, match="no enabled sensors"): + DrifterConfig( + lifetime=LIFETIME, + depth_meter=DEPLOY_DEPTH, + stationkeeping_time_minutes=10, + sensors=[], + ) + + +def test_drifter_config_default_sensors(): + """DrifterConfig defaults to TEMPERATURE.""" + config = DrifterConfig( + lifetime=LIFETIME, + depth_meter=DEPLOY_DEPTH, + stationkeeping_time_minutes=10, + ) + assert config.sensors[0].sensor_type is SensorType.TEMPERATURE + + +def test_drifter_config_unsupported_sensor_rejected(): + """Unsupported sensor on Drifter is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + DrifterConfig( + lifetime=LIFETIME, + depth_meter=DEPLOY_DEPTH, + stationkeeping_time_minutes=10, + sensors=[SensorConfig(sensor_type=SensorType.VELOCITY)], + ) diff --git a/tests/instruments/test_sensors.py b/tests/instruments/test_sensors.py new file mode 100644 index 00000000..b79facb7 --- /dev/null +++ b/tests/instruments/test_sensors.py @@ -0,0 +1,141 @@ +import pydantic +import pytest + +from virtualship.instruments.sensors import ( + SensorType, +) +from virtualship.instruments.types import InstrumentType +from virtualship.models.expedition import SENSOR_REGISTRY, SensorConfig +from virtualship.utils import get_supported_sensors + +EXPECTED_SENSOR_MEMBERS = { + "TEMPERATURE", + "SALINITY", + "VELOCITY", + "OXYGEN", + "CHLOROPHYLL", + "NITRATE", + "PHOSPHATE", + "PH", + "PHYTOPLANKTON", + "PRIMARY_PRODUCTION", +} + + +def test_sensor_registry_keys_match_sensor_type(): + """SENSOR_REGISTRY keys must be exactly the set of SensorType members.""" + assert set(SENSOR_REGISTRY().keys()) == set(SensorType) + + +@pytest.mark.parametrize( + "sensor_type", + [ + SensorType.OXYGEN, + SensorType.CHLOROPHYLL, + SensorType.NITRATE, + SensorType.PHOSPHATE, + SensorType.PH, + SensorType.PHYTOPLANKTON, + SensorType.PRIMARY_PRODUCTION, + ], +) +def test_sensor_registry_bgc_entries_category(sensor_type): + """All BGC sensors must have category 'bgc'.""" + assert SENSOR_REGISTRY()[sensor_type].category == "bgc" + + +def test_sensor_registry_unique_fs_keys(): + """No two sensors should share an fs_key.""" + fs_keys = [meta.fs_key for meta in SENSOR_REGISTRY().values()] + assert len(fs_keys) == len(set(fs_keys)), ( + "Duplicate fs_key found in SENSOR_REGISTRY" + ) + + +def test_sensor_type_all_members_exist(): + """All expected SensorType members are present.""" + actual = {m.name for m in SensorType} + assert actual == EXPECTED_SENSOR_MEMBERS + + +def test_sensor_type_lookup_by_value(): + """Can construct a SensorType from its string value.""" + assert SensorType("SALINITY") is SensorType.SALINITY + + +def test_sensor_type_invalid_value_error(): + """Invalid string raises ValueError.""" + with pytest.raises(ValueError): + SensorType("NOT_A_SENSOR") + + +def test_all_allowlists_are_frozenset(): + """All per-instrument supported sensor sets must be frozensets (immutable).""" + for itype in InstrumentType: + allowlist = get_supported_sensors(itype) + assert isinstance(allowlist, frozenset) + + +def test_sensor_config_basic_construction(): + """Standard construction with SensorType enum.""" + sc = SensorConfig(sensor_type=SensorType.TEMPERATURE) + assert sc.sensor_type is SensorType.TEMPERATURE + assert sc.enabled is True + + +def test_sensor_config_disabled(): + """Can explicitly set enabled=False.""" + sc = SensorConfig(sensor_type=SensorType.SALINITY, enabled=False) + assert sc.enabled is False + + +def test_sensor_config_from_string_shorthand(): + """A bare string should be accepted as shorthand.""" + sc = SensorConfig.model_validate("TEMPERATURE") + assert sc.sensor_type is SensorType.TEMPERATURE + assert sc.enabled is True + + +def test_sensor_config_invalid_string_error(): + """An unknown sensor name should raise error.""" + with pytest.raises(pydantic.ValidationError): + SensorConfig.model_validate("NOT_REAL") + + +def test_serialize_sensor_list_disabled_excluded(): + """Disabled sensors are excluded from serialisation.""" + sensors = [ + SensorConfig(sensor_type=SensorType.TEMPERATURE, enabled=True), + SensorConfig(sensor_type=SensorType.SALINITY, enabled=False), + ] + assert SensorConfig.serialize_list(sensors) == ["TEMPERATURE"] + + +def test_check_sensor_compatibility_unsupported_error(): + """Unsupported sensor fails.""" + sensors = [SensorConfig(sensor_type=SensorType.OXYGEN)] + with pytest.raises(ValueError, match="does not support sensor"): + SensorConfig.check_compatibility(sensors, InstrumentType.DRIFTER, "Drifter") + + +def test_check_sensor_compatibility_all_disabled_error(): + """All sensors disabled fails.""" + sensors = [SensorConfig(sensor_type=SensorType.TEMPERATURE, enabled=False)] + with pytest.raises(ValueError, match="no enabled sensors"): + SensorConfig.check_compatibility(sensors, InstrumentType.DRIFTER, "Drifter") + + +def test_check_sensor_compatibility_empty_error(): + """Empty sensor list fails.""" + with pytest.raises(ValueError, match="no enabled sensors"): + SensorConfig.check_compatibility([], InstrumentType.DRIFTER, "Drifter") + + +def test_check_sensor_compatibility_mixed_error(): + """Mix of valid and invalid sensors fails.""" + sensors = [ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.OXYGEN), + ] + with pytest.raises(ValueError, match="does not support"): + SensorConfig.check_compatibility(sensors, InstrumentType.DRIFTER, "Drifter") diff --git a/tests/instruments/test_ship_underwater_st.py b/tests/instruments/test_ship_underwater_st.py index 3f1aae65..9c879d48 100644 --- a/tests/instruments/test_ship_underwater_st.py +++ b/tests/instruments/test_ship_underwater_st.py @@ -3,12 +3,20 @@ import datetime import numpy as np +import pydantic +import pytest import xarray as xr from parcels import FieldSet from virtualship.instruments.ship_underwater_st import Underwater_STInstrument +from virtualship.instruments.sensors import SensorType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + InstrumentsConfig, + SensorConfig, + ShipUnderwaterSTConfig, + Waypoint, +) def test_simulate_ship_underwater_st(tmpdir) -> None: @@ -24,15 +32,15 @@ def test_simulate_ship_underwater_st(tmpdir) -> None: # expected observations at sample points expected_obs = [ { - "S": 5, - "T": 6, + "salinity": 5, + "temperature": 6, "lat": sample_points[0].location.lat, "lon": sample_points[0].location.lon, "time": base_time + datetime.timedelta(seconds=0), }, { - "S": 7, - "T": 8, + "salinity": 7, + "temperature": 8, "lat": sample_points[1].location.lat, "lon": sample_points[1].location.lon, "time": base_time + datetime.timedelta(seconds=1), @@ -42,12 +50,12 @@ def test_simulate_ship_underwater_st(tmpdir) -> None: # create fieldset based on the expected observations # indices are time, latitude, longitude salinity = np.zeros((2, 2, 2)) - salinity[0, 0, 0] = expected_obs[0]["S"] - salinity[1, 1, 1] = expected_obs[1]["S"] + salinity[0, 0, 0] = expected_obs[0]["salinity"] + salinity[1, 1, 1] = expected_obs[1]["salinity"] temperature = np.zeros((2, 2, 2)) - temperature[0, 0, 0] = expected_obs[0]["T"] - temperature[1, 1, 1] = expected_obs[1]["T"] + temperature[0, 0, 0] = expected_obs[0]["temperature"] + temperature[1, 1, 1] = expected_obs[1]["temperature"] fieldset = FieldSet.from_data( { @@ -79,6 +87,16 @@ class schedule: ), ] + instruments_config = InstrumentsConfig( + ship_underwater_st_config=ShipUnderwaterSTConfig( + period_minutes=5.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + ) + ) + expedition = DummyExpedition() from_data = None @@ -103,9 +121,62 @@ class schedule: zip(results.sel(trajectory=traj).obs, expected_obs, strict=True) ): obs = results.sel(trajectory=traj, obs=obs_i) - for var in ["S", "T", "lat", "lon"]: + for var in ["salinity", "temperature", "lat", "lon"]: obs_value = obs[var].values.item() exp_value = exp[var] assert np.isclose(obs_value, exp_value), ( f"Observation incorrect {i=} {var=} {obs_value=} {exp_value=}." ) + + +def test_ship_underwater_st_sensor_config_active_variables() -> None: + """active_variables() only returns variables for enabled sensors.""" + config_both = ShipUnderwaterSTConfig( + period_minutes=5.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE), + SensorConfig(sensor_type=SensorType.SALINITY), + ], + ) + assert config_both.active_variables() == {"T": "thetao", "S": "so"} + + config_temp_only = ShipUnderwaterSTConfig( + period_minutes=5.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY omitted = disabled + ) + assert config_temp_only.active_variables() == {"T": "thetao"} + + +def test_ship_underwater_st_sensor_config_yaml() -> None: + """ShipUnderwaterSTConfig sensors survive YAML serialisation.""" + config = ShipUnderwaterSTConfig( + period_minutes=5.0, + sensors=[ + SensorConfig(sensor_type=SensorType.TEMPERATURE) + ], # SALINITY omitted = disabled + ) + dumped = config.model_dump(by_alias=True) + loaded = ShipUnderwaterSTConfig.model_validate(dumped) + assert len(loaded.sensors) == 1 + assert loaded.sensors[0].sensor_type == SensorType.TEMPERATURE + assert loaded.sensors[0].enabled is True + + +def test_underwater_st_config_default_sensors(): + """ShipUnderwaterSTConfig defaults to TEMPERATURE + SALINITY.""" + config = ShipUnderwaterSTConfig( + period_minutes=5.0, + ) + types = {sc.sensor_type for sc in config.sensors} + assert types == {SensorType.TEMPERATURE, SensorType.SALINITY} + + +def test_underwater_st_config_unsupported_sensor_rejected(): + """Unsupported sensor on Underwater ST is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + ShipUnderwaterSTConfig( + period_minutes=5.0, + sensors=[SensorConfig(sensor_type=SensorType.OXYGEN)], + ) diff --git a/tests/instruments/test_xbt.py b/tests/instruments/test_xbt.py index c6a36631..0ac3a7cb 100644 --- a/tests/instruments/test_xbt.py +++ b/tests/instruments/test_xbt.py @@ -7,12 +7,20 @@ import datetime import numpy as np +import pydantic +import pytest import xarray as xr from parcels import Field, FieldSet from virtualship.instruments.xbt import XBT, XBTInstrument +from virtualship.instruments.sensors import SensorType from virtualship.models import Location, Spacetime -from virtualship.models.expedition import Waypoint +from virtualship.models.expedition import ( + InstrumentsConfig, + SensorConfig, + Waypoint, + XBTConfig, +) def test_simulate_xbts(tmpdir) -> None: @@ -107,6 +115,16 @@ class schedule: ), ] + instruments_config = InstrumentsConfig( + xbt_config=XBTConfig( + min_depth_meter=-2.0, + max_depth_meter=-285.0, + fall_speed_meter_per_second=6.7, + deceleration_coefficient=0.00225, + sensors=[SensorConfig(sensor_type=SensorType.TEMPERATURE)], + ) + ) + expedition = DummyExpedition() from_data = None @@ -139,3 +157,54 @@ class schedule: assert np.isclose(obs_value, exp_value), ( f"Observation incorrect {xbt_i=} {loc=} {var=} {obs_value=} {exp_value=}." ) + + +def test_xbt_sensor_config_active_variables() -> None: + """active_variables() only returns variables for enabled sensors.""" + config_with_temp = XBTConfig( + min_depth_meter=-2.0, + max_depth_meter=-285.0, + fall_speed_meter_per_second=6.7, + deceleration_coefficient=0.00225, + sensors=[SensorConfig(sensor_type=SensorType.TEMPERATURE)], + ) + assert config_with_temp.active_variables() == {"T": "thetao"} + + +def test_xbt_sensor_config_yaml() -> None: + """XBTConfig sensors survive YAML serialisation.""" + config = XBTConfig( + min_depth_meter=-2.0, + max_depth_meter=-285.0, + fall_speed_meter_per_second=6.7, + deceleration_coefficient=0.00225, + sensors=[SensorConfig(sensor_type=SensorType.TEMPERATURE)], + ) + dumped = config.model_dump(by_alias=True) + loaded = XBTConfig.model_validate(dumped) + assert len(loaded.sensors) == 1 + assert loaded.sensors[0].sensor_type == SensorType.TEMPERATURE + assert loaded.sensors[0].enabled is True + + +def test_xbt_config_default_sensors(): + """XBTConfig defaults to TEMPERATURE.""" + config = XBTConfig( + min_depth_meter=-2.0, + max_depth_meter=-285.0, + fall_speed_meter_per_second=6.7, + deceleration_coefficient=0.00225, + ) + assert config.sensors[0].sensor_type is SensorType.TEMPERATURE + + +def test_xbt_config_unsupported_sensor_rejected(): + """Unsupported sensor on XBT is rejected.""" + with pytest.raises(pydantic.ValidationError, match="does not support"): + XBTConfig( + min_depth_meter=-2.0, + max_depth_meter=-285.0, + fall_speed_meter_per_second=6.7, + deceleration_coefficient=0.00225, + sensors=[SensorConfig(sensor_type=SensorType.SALINITY)], + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8f9ec016..4860f2f6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,11 +4,12 @@ import numpy as np import pytest import xarray as xr -from parcels import FieldSet +from parcels import FieldSet, JITParticle, ScipyParticle, Variable import virtualship.utils +from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType -from virtualship.models.expedition import Expedition +from virtualship.models.expedition import Expedition, SensorConfig from virtualship.models.location import Location from virtualship.utils import ( PROJECTION, @@ -19,6 +20,7 @@ _select_product_id, _start_end_in_product_timerange, add_dummy_UV, + build_particle_class_from_sensors, get_example_expedition, ) @@ -360,3 +362,50 @@ def test_calc_wp_stationkeeping_time_no_instruments(expedition): assert stationkeeping_null == stationkeeping_emptylist # are equivalent assert stationkeeping_null == datetime.timedelta(0) # at least one is 0 time + + +# helper +def _make_sensors(*sensor_types, enabled=True): + """Helper to build a list of SensorConfig from SensorType values.""" + return [SensorConfig(sensor_type=st, enabled=enabled) for st in sensor_types] + + +def test_build_basic_particle_class(): + """Build basic particle class with T+S sensors and nonsensor variables.""" + nonsensor = [Variable("cycle_phase", dtype=np.int32, initial=0)] + sensors = _make_sensors(SensorType.TEMPERATURE, SensorType.SALINITY) + + ParticleClass = build_particle_class_from_sensors(sensors, nonsensor, JITParticle) + assert issubclass(ParticleClass, JITParticle) + + +def test_build_particle_class_disabled_sensors_excluded(): + """Disabled sensors should not contribute variables.""" + nonsensor = [] + sensors = [ + SensorConfig(sensor_type=SensorType.TEMPERATURE, enabled=True), + SensorConfig(sensor_type=SensorType.SALINITY, enabled=False), + ] + + ParticleClass = build_particle_class_from_sensors(sensors, nonsensor, JITParticle) + assert hasattr(ParticleClass, "temperature") + assert not hasattr(ParticleClass, "salinity") + + +def test_build_particle_class_velocity_adds_U_V(): + """VELOCITY sensor should add both U and V particle variables.""" + nonsensor = [] + sensors = _make_sensors(SensorType.VELOCITY) + + ParticleClass = build_particle_class_from_sensors(sensors, nonsensor, JITParticle) + assert hasattr(ParticleClass, "U") + assert hasattr(ParticleClass, "V") + + +def test_build_particle_class_scipy_base(): + """Should also work with ScipyParticle as the base class.""" + nonsensor = [] + sensors = _make_sensors(SensorType.TEMPERATURE) + + ParticleClass = build_particle_class_from_sensors(sensors, nonsensor, ScipyParticle) + assert issubclass(ParticleClass, ScipyParticle)