Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion kwave/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional

import numpy as np

Expand Down Expand Up @@ -119,3 +119,128 @@ def append(self, val):
assert len(self.data) <= 2
self.data.append(val)
return self


@dataclass
class SimulationResult:
Comment thread
waltsims marked this conversation as resolved.
"""
Structured return type for kWave simulation results.

Contains all possible fields that can be returned by the kWave C++ binaries.
Fields are populated based on the sensor.record configuration.
"""

# Grid information (always present)
Nx: int
Ny: int
Nz: int
Nt: int
pml_x_size: int
pml_y_size: int
pml_z_size: int
axisymmetric_flag: bool

# Pressure fields (optional - based on sensor.record). Field names match
# the user-facing sensor.record keys, not the binary's HDF5 dataset names
# (e.g. sensor.record=("p",) → result.p; binary dataset is "p_raw").
p: Optional[np.ndarray] = None
p_max: Optional[np.ndarray] = None
p_min: Optional[np.ndarray] = None
p_rms: Optional[np.ndarray] = None
p_max_all: Optional[np.ndarray] = None
p_min_all: Optional[np.ndarray] = None
p_final: Optional[np.ndarray] = None

# Velocity fields (optional - based on sensor.record). `u` is the
# collocated (pressure-grid) velocity; `u_staggered` is the mid-cell
# variant. The legacy C++ binary inverts this naming (its "u" is
# mid-cell and "u_non_staggered" is collocated); from_dotdict
# translates so the user-facing API always uses the modern convention.
u: Optional[np.ndarray] = None
u_max: Optional[np.ndarray] = None
u_min: Optional[np.ndarray] = None
u_rms: Optional[np.ndarray] = None
u_max_all: Optional[np.ndarray] = None
u_min_all: Optional[np.ndarray] = None
u_final: Optional[np.ndarray] = None
u_staggered: Optional[np.ndarray] = None

# Intensity fields (optional - based on sensor.record)
I_avg: Optional[np.ndarray] = None
I: Optional[np.ndarray] = None

def __getitem__(self, key: str):
"""
Enable dictionary-style access for backward compatibility.

Args:
key: Field name to access

Returns:
Value of the field

Raises:
KeyError: If the field does not exist
"""
if hasattr(self, key):
return getattr(self, key)
raise KeyError(f"'{key}' field not found in SimulationResult")

def __contains__(self, key: str) -> bool:
"""
Enable dictionary-style membership testing for backward compatibility.

Args:
key: Field name to check

Returns:
True if the field exists, False otherwise
"""
return hasattr(self, key)
Comment on lines +172 to +199
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 __contains__ and __getitem__ use hasattr, which returns True for any attribute on the object — including methods like from_dotdict, __init__, and dunder attributes. A caller doing 'from_dotdict' in result gets True and result['from_dotdict'] returns the classmethod rather than raising KeyError. The intent is backward compatibility for sensor data fields, so the check should be restricted to declared dataclass fields only.

Suggested change
def __getitem__(self, key: str):
"""
Enable dictionary-style access for backward compatibility.
Args:
key: Field name to access
Returns:
Value of the field
Raises:
KeyError: If the field does not exist
"""
if hasattr(self, key):
return getattr(self, key)
raise KeyError(f"'{key}' field not found in SimulationResult")
def __contains__(self, key: str) -> bool:
"""
Enable dictionary-style membership testing for backward compatibility.
Args:
key: Field name to check
Returns:
True if the field exists, False otherwise
"""
return hasattr(self, key)
def __getitem__(self, key: str):
"""
Enable dictionary-style access for backward compatibility.
Args:
key: Field name to access
Returns:
Value of the field
Raises:
KeyError: If the field does not exist
"""
import dataclasses
field_names = {f.name for f in dataclasses.fields(self)}
if key in field_names:
return getattr(self, key)
raise KeyError(f"'{key}' field not found in SimulationResult")
def __contains__(self, key: str) -> bool:
"""
Enable dictionary-style membership testing for backward compatibility.
Args:
key: Field name to check
Returns:
True if the field exists, False otherwise
"""
import dataclasses
return key in {f.name for f in dataclasses.fields(self)}

Comment on lines +172 to +199
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incomplete dict-compatibility bridge breaks .get() callers

SimulationResult implements __getitem__ and __contains__ for backward compatibility, but omits get(), keys(), items(), values(), and __iter__. Any existing code that calls result.get("p", None), iterates for k in result, or calls result.keys() will raise AttributeError at runtime — a silent regression for users migrating from the old dotdict return type. The PR description claims the return-type change is non-breaking, but without these methods the compatibility surface is materially incomplete.


@classmethod
def from_dotdict(cls, data: dict) -> "SimulationResult":
"""
Create SimulationResult from dotdict returned by parse_executable_output.

Args:
data: Dictionary containing simulation results from HDF5 file

Returns:
SimulationResult instance with all available fields populated
"""
return cls(
# Grid information
Nx=int(data.get("Nx", 0)),
Ny=int(data.get("Ny", 0)),
Nz=int(data.get("Nz", 0)),
Nt=int(data.get("Nt", 0)),
pml_x_size=int(data.get("pml_x_size", 0)),
pml_y_size=int(data.get("pml_y_size", 0)),
pml_z_size=int(data.get("pml_z_size", 0)),
axisymmetric_flag=bool(data.get("axisymmetric_flag", False)),
# Pressure fields — the binary writes HDF5 datasets named after
# sensor.record keys (e.g. /p, /p_max, /p_final), not the --p_raw
# CLI flag name, so we look up by the user-facing names.
p=data.get("p"),
p_max=data.get("p_max"),
p_min=data.get("p_min"),
p_rms=data.get("p_rms"),
p_max_all=data.get("p_max_all"),
p_min_all=data.get("p_min_all"),
p_final=data.get("p_final"),
# Velocity fields. Legacy C++ binary names are inverted vs the
# modern user-facing convention: binary's `u` is mid-cell
# (staggered), binary's `u_non_staggered` is collocated.
u=data.get("u_non_staggered"),
u_max=data.get("u_max"),
u_min=data.get("u_min"),
u_rms=data.get("u_rms"),
u_max_all=data.get("u_max_all"),
u_min_all=data.get("u_min_all"),
u_final=data.get("u_final"),
u_staggered=data.get("u"),
# Intensity fields
I_avg=data.get("I_avg"),
I=data.get("I"),
)
Comment on lines +201 to +246
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Velocity component fields silently dropped

_crop_pml in executor.py iterates over individual velocity-component field names (ux_final, uy_final, uz_final, ux_max_all, uy_max_all, uz_max_all, ux_min_all, uy_min_all, uz_min_all) because those are the keys the kWave C++ binary actually writes to the HDF5 output. from_dotdict maps u_final, u_max_all, and u_min_all instead — keys that are absent from the real output. data.get("u_final") therefore always returns None, and the cropped velocity data (already modified in the dotdict by _crop_pml) is silently discarded. Any user recording velocity fields will receive a SimulationResult where all those fields are None.

5 changes: 3 additions & 2 deletions kwave/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import h5py

import kwave
from kwave.data import SimulationResult
from kwave.options.simulation_execution_options import SimulationExecutionOptions
from kwave.utils.dotdictionary import dotdict

Expand All @@ -32,7 +33,7 @@ def _make_binary_executable(self):
raise FileNotFoundError(f"Binary not found at {binary_path}")
binary_path.chmod(binary_path.stat().st_mode | stat.S_IEXEC)

def run_simulation(self, input_filename: str, output_filename: str, options: list[str]) -> dotdict:
def run_simulation(self, input_filename: str, output_filename: str, options: list[str]) -> SimulationResult:
command = [str(self.execution_options.binary_path), "-i", input_filename, "-o", output_filename] + options

try:
Expand Down Expand Up @@ -68,7 +69,7 @@ def run_simulation(self, input_filename: str, output_filename: str, options: lis
if not self.simulation_options.pml_inside:
self._crop_pml(sensor_data)

return sensor_data
return SimulationResult.from_dotdict(sensor_data)

def _crop_pml(self, sensor_data: dotdict):
Nx = sensor_data["Nx"].item()
Expand Down
7 changes: 4 additions & 3 deletions kwave/kspaceFirstOrder2D.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
from typing import Union
from typing import Optional, Union

import numpy as np
from beartype import beartype as typechecker

from kwave.data import SimulationResult
from kwave.executor import Executor
from kwave.kgrid import kWaveGrid
from kwave.kmedium import kWaveMedium
Expand Down Expand Up @@ -86,7 +87,7 @@ def kspaceFirstOrder2DC(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
):
) -> Optional[SimulationResult]:
"""
2D time-domain simulation of wave propagation using C++ code.

Expand Down Expand Up @@ -146,7 +147,7 @@ def kspaceFirstOrder2D(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
):
) -> Optional[SimulationResult]:
"""
2D time-domain simulation of wave propagation using k-space pseudospectral method.

Expand Down
10 changes: 5 additions & 5 deletions kwave/kspaceFirstOrder3D.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings
from typing import Union
from typing import Optional, Union

import numpy as np
from deprecated import deprecated

from kwave.data import SimulationResult
from kwave.executor import Executor
from kwave.kgrid import kWaveGrid
from kwave.kmedium import kWaveMedium
Expand All @@ -27,7 +27,7 @@ def kspaceFirstOrder3DG(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
) -> Union[np.ndarray, dict]:
) -> Optional[SimulationResult]:
"""
3D time-domain simulation of wave propagation on a GPU using C++ CUDA code.

Expand Down Expand Up @@ -81,7 +81,7 @@ def kspaceFirstOrder3DC(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
):
) -> Optional[SimulationResult]:
"""
3D time-domain simulation of wave propagation using C++ code.

Expand Down Expand Up @@ -138,7 +138,7 @@ def kspaceFirstOrder3D(
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
time_rev: bool = False, # deprecated parameter
):
) -> Optional[SimulationResult]:
"""
3D time-domain simulation of wave propagation using k-space pseudospectral method.

Expand Down
11 changes: 6 additions & 5 deletions kwave/kspaceFirstOrderAS.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from typing import Union
from typing import Optional, Union

import numpy as np
from numpy.fft import ifftshift

from kwave.data import SimulationResult
from kwave.enums import DiscreteCosine
from kwave.executor import Executor
from kwave.kgrid import kWaveGrid
Expand All @@ -30,7 +31,7 @@ def kspaceFirstOrderASC(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
):
) -> Optional[SimulationResult]:
"""
Axisymmetric time-domain simulation of wave propagation using C++ code.

Expand Down Expand Up @@ -89,7 +90,7 @@ def kspaceFirstOrderAS(
medium: kWaveMedium,
simulation_options: SimulationOptions,
execution_options: SimulationExecutionOptions,
):
) -> Optional[SimulationResult]:
"""
Axisymmetric time-domain simulation of wave propagation.

Expand Down Expand Up @@ -158,7 +159,7 @@ def kspaceFirstOrderAS(
if simulation_options.simulation_type is not SimulationType.AXISYMMETRIC:
logging.log(
logging.WARN,
"simulation type is not set to axisymmetric while using kSapceFirstOrderAS. " "Setting simulation type to axisymmetric.",
"simulation type is not set to axisymmetric while using kSapceFirstOrderAS. Setting simulation type to axisymmetric.",
)
simulation_options.simulation_type = SimulationType.AXISYMMETRIC

Expand Down Expand Up @@ -296,7 +297,7 @@ def kspaceFirstOrderAS(

# option to run simulations without the spatial staggered grid is not
# supported for the axisymmetric code
assert options.use_sg, "Optional input " "UseSG" " is not supported for axisymmetric simulations."
assert options.use_sg, "Optional input UseSG is not supported for axisymmetric simulations."

# =========================================================================
# SAVE DATA TO DISK FOR RUNNING SIMULATION EXTERNAL TO MATLAB
Expand Down
8 changes: 0 additions & 8 deletions kwave/options/simulation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ class SimulationOptions(object):
"""
Args:
axisymmetric: Flag that indicates whether axisymmetric simulation is used
cart_interp: Interpolation mode used to extract the pressure when a Cartesian sensor mask is given.
If set to 'nearest' and more than one Cartesian point maps to the same grid point,
duplicated data points are discarded and sensor_data will be returned
with less points than that specified by sensor.mask (default = 'linear').
pml_inside: put the PML inside the grid defined by the user
pml_alpha: Absorption within the perfectly matched layer in Nepers per grid point (default = 2).
save_to_disk: save the input data to a HDF5 file
Expand Down Expand Up @@ -85,7 +81,6 @@ class SimulationOptions(object):
"""

simulation_type: SimulationType = SimulationType.FLUID
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 cart_interp removal is a hard breaking change

Removing cart_interp as a named constructor argument means any existing caller passing SimulationOptions(cart_interp="nearest") will get a TypeError at runtime. The PR description's claim of "non-breaking at call sites" applies only to the return-type change; the option rename is explicitly a breaking change for anyone not on the new API. Documenting a migration path or providing a __post_init__ deprecation path (accepting cart_interp and forwarding it with a warning) would be safer for library consumers.

cart_interp: str = "linear"
pml_inside: bool = True
pml_alpha: float = 2.0
save_to_disk: bool = False
Expand Down Expand Up @@ -203,9 +198,6 @@ def option_factory(kgrid: "kWaveGrid", options: SimulationOptions):
elastic_code: Flag that indicates whether elastic simulation is used
**kwargs: Dictionary that holds following optional simulation properties:

* cart_interp: Interpolation mode used to extract the pressure when a Cartesian sensor mask is given.
If set to 'nearest', duplicated data points are discarded and sensor_data
will be returned with fewer points than specified by sensor.mask (default = 'linear').
Comment on lines -206 to -208
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we document Cartesian interp with the old docstring?

* create_log: Boolean controlling whether the command line output is saved using the diary function
with a date and time stamped filename (default = false).
* data_cast: String input of the data type that variables are cast to before computation.
Expand Down
Loading
Loading