-
Notifications
You must be signed in to change notification settings - Fork 57
Add dtype parameter to kspaceFirstOrder() (#695) #716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8cad56c
dbfbee5
0414f3d
815dc8b
cd7d4f3
22ec78d
d545952
6920efb
d8537f1
693cb46
87d2a91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | |||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -32,40 +32,55 @@ def _to_cpu(x): | ||||||||||||||||||||||||
| return x.get() if hasattr(x, "get") else x | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _expand_to_grid(val, grid_shape, xp, name="parameter"): | |||||||||||||||||||||||||
| def _array_sum(arrays): | |||||||||||||||||||||||||
| """Sum arrays, preserving dtype. | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| ``sum(arrays)`` starts from Python ``int 0``; under numpy < 2 (NEP 50) | |||||||||||||||||||||||||
| that promotes float32 inputs to float64. Starting from ``arrays[0]`` | |||||||||||||||||||||||||
| keeps the result's dtype equal to the elements'. | |||||||||||||||||||||||||
| """ | |||||||||||||||||||||||||
| out = arrays[0] | |||||||||||||||||||||||||
| for a in arrays[1:]: | |||||||||||||||||||||||||
| out = out + a | |||||||||||||||||||||||||
| return out | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _expand_to_grid(val, grid_shape, xp, name="parameter", dtype=float): | |||||||||||||||||||||||||
| if val is None: | |||||||||||||||||||||||||
| raise ValueError(f"Missing required parameter: {name}") | |||||||||||||||||||||||||
| arr = xp.array(val, dtype=float).ravel() | |||||||||||||||||||||||||
| arr = xp.array(val, dtype=dtype).ravel() | |||||||||||||||||||||||||
| grid_size = int(np.prod(grid_shape)) | |||||||||||||||||||||||||
| if arr.size == 1: | |||||||||||||||||||||||||
| return xp.full(grid_shape, float(arr[0]), dtype=float) | |||||||||||||||||||||||||
| return xp.full(grid_shape, arr[0], dtype=dtype) | |||||||||||||||||||||||||
| if arr.size == grid_size: | |||||||||||||||||||||||||
| return arr.reshape(grid_shape) | |||||||||||||||||||||||||
| raise ValueError(f"{name} size {arr.size} incompatible with grid size {grid_size}") | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_size, source_kappa, diff_fn): | |||||||||||||||||||||||||
| def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_size, source_kappa, diff_fn, dtype=float): | |||||||||||||||||||||||||
| """Build a source injection operator for one field variable. | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| Returns a callable (t, field) → field that injects scaled source values. | |||||||||||||||||||||||||
| ``dtype`` controls the precision of the source signal buffer (matches the | |||||||||||||||||||||||||
| field precision set by ``data_cast``). | |||||||||||||||||||||||||
| """ | |||||||||||||||||||||||||
| mask = xp.array(mask_raw, dtype=bool).ravel() | |||||||||||||||||||||||||
| if mask.size == 1: | |||||||||||||||||||||||||
| mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).ravel() | |||||||||||||||||||||||||
| n_src = int(xp.sum(mask)) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| signal_arr = xp.array(signal_raw, dtype=float) | |||||||||||||||||||||||||
| signal_arr = xp.array(signal_raw, dtype=dtype) | |||||||||||||||||||||||||
| if signal_arr.ndim == 1: | |||||||||||||||||||||||||
| signal = signal_arr.reshape(1, -1) | |||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||
| signal = signal_arr.reshape(-1, signal_arr.shape[-1]) if signal_arr.ndim > 2 else signal_arr | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| scaled = signal * xp.atleast_1d(xp.asarray(scale))[:, None] | |||||||||||||||||||||||||
| scaled = signal * xp.atleast_1d(xp.asarray(scale, dtype=dtype))[:, None] | |||||||||||||||||||||||||
| signal_len = scaled.shape[1] | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def get_val(t): | |||||||||||||||||||||||||
| if scaled.shape[0] == 1: | |||||||||||||||||||||||||
| return xp.full(n_src, float(scaled[0, t])) | |||||||||||||||||||||||||
| return xp.full(n_src, scaled[0, t], dtype=dtype) | |||||||||||||||||||||||||
| return scaled[:, t] | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def dirichlet(t, field): | |||||||||||||||||||||||||
|
|
@@ -76,7 +91,7 @@ def dirichlet(t, field): | ||||||||||||||||||||||||
| return flat.reshape(grid_shape) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Pre-allocate buffer to avoid per-step allocation | |||||||||||||||||||||||||
| _src_buf = xp.zeros(grid_size, dtype=float) | |||||||||||||||||||||||||
| _src_buf = xp.zeros(grid_size, dtype=dtype) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def additive_kspace(t, field): | |||||||||||||||||||||||||
| if t >= signal_len: | |||||||||||||||||||||||||
|
|
@@ -131,6 +146,7 @@ def __init__( | ||||||||||||||||||||||||
| pml_size=None, | |||||||||||||||||||||||||
| pml_alpha=None, | |||||||||||||||||||||||||
| quiet=False, | |||||||||||||||||||||||||
| dtype=None, | |||||||||||||||||||||||||
| ): | |||||||||||||||||||||||||
| self.kgrid = kgrid | |||||||||||||||||||||||||
| self.medium = medium | |||||||||||||||||||||||||
|
|
@@ -142,6 +158,25 @@ def __init__( | ||||||||||||||||||||||||
| self.quiet = quiet | |||||||||||||||||||||||||
| self._pml_size_override = pml_size | |||||||||||||||||||||||||
| self._pml_alpha_override = pml_alpha | |||||||||||||||||||||||||
| # Compute precision for state arrays. ``None`` defaults to float64 (matches | |||||||||||||||||||||||||
| # MATLAB k-Wave). Only float32 / float64 are validated by the solver. | |||||||||||||||||||||||||
| if dtype is None: | |||||||||||||||||||||||||
| self._dtype = np.float64 | |||||||||||||||||||||||||
| elif dtype in (np.float32, np.float64): | |||||||||||||||||||||||||
| self._dtype = dtype | |||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||
| try: | |||||||||||||||||||||||||
| resolved = np.dtype(dtype).type | |||||||||||||||||||||||||
| except TypeError as e: | |||||||||||||||||||||||||
| raise ValueError(f"dtype must be np.float32 or np.float64 (or string equivalent), got {dtype!r}") from e | |||||||||||||||||||||||||
| if resolved not in (np.float32, np.float64): | |||||||||||||||||||||||||
| raise ValueError(f"dtype must resolve to float32 or float64, got {resolved.__name__}") | |||||||||||||||||||||||||
| self._dtype = resolved | |||||||||||||||||||||||||
| # Companion complex dtype for FFT outputs. numpy<2 (np.fft) always upcasts | |||||||||||||||||||||||||
| # to complex128 regardless of input precision; we cast back so the rest of | |||||||||||||||||||||||||
| # the pipeline stays in self._dtype. Harmless on numpy 2+ and on cupy | |||||||||||||||||||||||||
| # (which already respects input precision). | |||||||||||||||||||||||||
| self._complex_dtype = np.complex64 if self._dtype is np.float32 else np.complex128 | |||||||||||||||||||||||||
| # kWaveGrid doesn't have pml_size_x attrs; warn if PML will silently be disabled | |||||||||||||||||||||||||
| if pml_size is None: | |||||||||||||||||||||||||
| from kwave.kgrid import kWaveGrid as _KWG | |||||||||||||||||||||||||
|
|
@@ -190,9 +225,9 @@ def setup(self): | ||||||||||||||||||||||||
| self.Nt = int(self.kgrid.Nt) | |||||||||||||||||||||||||
| self.dt = float(self.kgrid.dt) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| self.c0 = _expand_to_grid(self.medium.sound_speed, self.grid_shape, xp, "sound_speed") | |||||||||||||||||||||||||
| self.c0 = _expand_to_grid(self.medium.sound_speed, self.grid_shape, xp, "sound_speed", dtype=self._dtype) | |||||||||||||||||||||||||
| density = getattr(self.medium, "density", None) | |||||||||||||||||||||||||
| self.rho0 = _expand_to_grid(density if density is not None else 1000.0, self.grid_shape, xp, "density") | |||||||||||||||||||||||||
| self.rho0 = _expand_to_grid(density if density is not None else 1000.0, self.grid_shape, xp, "density", dtype=self._dtype) | |||||||||||||||||||||||||
| self.c_ref = float(xp.max(self.c0)) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| self._setup_sensor_mask() | |||||||||||||||||||||||||
|
|
@@ -356,26 +391,32 @@ def _setup_pml(self): | ||||||||||||||||||||||||
| if pml_size == 0 or pml_alpha == 0: | |||||||||||||||||||||||||
| shape = [1] * self.ndim | |||||||||||||||||||||||||
| shape[axis] = N | |||||||||||||||||||||||||
| self.pml_list.append(xp.ones(shape, dtype=float)) | |||||||||||||||||||||||||
| self.pml_sg_list.append(xp.ones(shape, dtype=float)) | |||||||||||||||||||||||||
| self.pml_list.append(xp.ones(shape, dtype=self._dtype)) | |||||||||||||||||||||||||
| self.pml_sg_list.append(xp.ones(shape, dtype=self._dtype)) | |||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||
| # dimension=2 gives shape (1, N) which we reshape for broadcasting | |||||||||||||||||||||||||
| # dimension=2 gives shape (1, N) which we reshape for broadcasting. | |||||||||||||||||||||||||
| # get_pml returns float64; cast so the per-step PML multiply doesn't | |||||||||||||||||||||||||
| # upcast self.p / self.u (would silently break dtype='single'). | |||||||||||||||||||||||||
| pml = get_pml(N, dx, self.dt, self.c_ref, pml_size, pml_alpha, staggered=False, dimension=2, xp=xp) | |||||||||||||||||||||||||
| pml_sg = get_pml(N, dx, self.dt, self.c_ref, pml_size, pml_alpha, staggered=True, dimension=2, xp=xp) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| shape = [1] * self.ndim | |||||||||||||||||||||||||
| shape[axis] = N | |||||||||||||||||||||||||
| self.pml_list.append(pml.flatten().reshape(shape)) | |||||||||||||||||||||||||
| self.pml_sg_list.append(pml_sg.flatten().reshape(shape)) | |||||||||||||||||||||||||
| self.pml_list.append(pml.flatten().reshape(shape).astype(self._dtype)) | |||||||||||||||||||||||||
| self.pml_sg_list.append(pml_sg.flatten().reshape(shape).astype(self._dtype)) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _setup_kspace_operators(self): | |||||||||||||||||||||||||
| """Build k-space gradient/divergence operators for each dimension.""" | |||||||||||||||||||||||||
| xp = self.xp | |||||||||||||||||||||||||
| self.k_list = [] | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # First pass: build k-vectors for each dimension | |||||||||||||||||||||||||
| # First pass: build k-vectors for each dimension. | |||||||||||||||||||||||||
| # ``fftfreq`` returns float64 by default; cast so the k-space operators | |||||||||||||||||||||||||
| # (kappa, op_grad_list, op_div_list, _k_mag) match self._dtype. Without | |||||||||||||||||||||||||
| # this cast, _diff's FFT round-trip with a float64 op upcasts the | |||||||||||||||||||||||||
| # float32 field back to float64 -- silently breaking dtype='single'. | |||||||||||||||||||||||||
| for axis, (N, dx) in enumerate(zip(self.grid_shape, self.spacing)): | |||||||||||||||||||||||||
| k = 2 * np.pi * xp.fft.fftfreq(N, d=dx) | |||||||||||||||||||||||||
| k = (2 * np.pi * xp.fft.fftfreq(N, d=dx)).astype(self._dtype) | |||||||||||||||||||||||||
| shape = [1] * self.ndim | |||||||||||||||||||||||||
| shape[axis] = N | |||||||||||||||||||||||||
| self.k_list.append(k.reshape(shape)) | |||||||||||||||||||||||||
|
|
@@ -424,7 +465,7 @@ def _alpha_neper_and_power(self): | ||||||||||||||||||||||||
| if not _is_enabled(getattr(self.medium, "alpha_coeff", 0)): | |||||||||||||||||||||||||
| return None, None | |||||||||||||||||||||||||
| alpha_power = float(self.xp.array(getattr(self.medium, "alpha_power", 1.5)).flatten()[0]) | |||||||||||||||||||||||||
| alpha_coeff = _expand_to_grid(self.medium.alpha_coeff, self.grid_shape, self.xp, "alpha_coeff") | |||||||||||||||||||||||||
| alpha_coeff = _expand_to_grid(self.medium.alpha_coeff, self.grid_shape, self.xp, "alpha_coeff", dtype=self._dtype) | |||||||||||||||||||||||||
| return db2neper(alpha_coeff, alpha_power), alpha_power | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _init_absorption(self, alpha_np, alpha_power): | |||||||||||||||||||||||||
|
|
@@ -504,9 +545,9 @@ def _init_nonlinearity(self): | ||||||||||||||||||||||||
| self._nonlinearity = lambda rho: 0 | |||||||||||||||||||||||||
| self._nl_factor = lambda rho_split: 1.0 | |||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||
| self.BonA = _expand_to_grid(BonA_raw, self.grid_shape, self.xp, "BonA") | |||||||||||||||||||||||||
| self.BonA = _expand_to_grid(BonA_raw, self.grid_shape, self.xp, "BonA", dtype=self._dtype) | |||||||||||||||||||||||||
| self._nonlinearity = lambda rho: self.BonA * rho**2 / (2 * self.rho0) | |||||||||||||||||||||||||
| self._nl_factor = lambda rho_split: (2 * sum(rho_split) + self.rho0) / self.rho0 | |||||||||||||||||||||||||
| self._nl_factor = lambda rho_split: (2 * _array_sum(rho_split) + self.rho0) / self.rho0 | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _setup_source_operators(self): | |||||||||||||||||||||||||
| """Build time-varying source injection operators. | |||||||||||||||||||||||||
|
|
@@ -560,6 +601,7 @@ def build_op(mask_raw, signal_raw, mode, scale): | ||||||||||||||||||||||||
| grid_size=grid_size, | |||||||||||||||||||||||||
| source_kappa=self.source_kappa, | |||||||||||||||||||||||||
| diff_fn=self._diff, | |||||||||||||||||||||||||
| dtype=self._dtype, | |||||||||||||||||||||||||
| ) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # --- Pressure source (per-axis spacing for non-isotropic grids) --- | |||||||||||||||||||||||||
|
|
@@ -606,10 +648,10 @@ def _setup_fields(self): | ||||||||||||||||||||||||
| """Initialize pressure, velocity, and density fields.""" | |||||||||||||||||||||||||
| xp = self.xp | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| self.p = xp.zeros(self.grid_shape, dtype=float) | |||||||||||||||||||||||||
| self.u = [xp.zeros(self.grid_shape, dtype=float) for _ in range(self.ndim)] | |||||||||||||||||||||||||
| self.p = xp.zeros(self.grid_shape, dtype=self._dtype) | |||||||||||||||||||||||||
| self.u = [xp.zeros(self.grid_shape, dtype=self._dtype) for _ in range(self.ndim)] | |||||||||||||||||||||||||
| # Split density per dimension enables independent PML absorption in each direction | |||||||||||||||||||||||||
| self.rho_split = [xp.zeros(self.grid_shape, dtype=float) for _ in range(self.ndim)] | |||||||||||||||||||||||||
| self.rho_split = [xp.zeros(self.grid_shape, dtype=self._dtype) for _ in range(self.ndim)] | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| if self.use_sg: | |||||||||||||||||||||||||
| self.rho0_staggered = [self._stagger(self.rho0, axis) for axis in range(self.ndim)] | |||||||||||||||||||||||||
|
|
@@ -623,25 +665,25 @@ def _setup_fields(self): | ||||||||||||||||||||||||
| # Sensor data storage (sized based on record_start_index) | |||||||||||||||||||||||||
| self.sensor_data = {} | |||||||||||||||||||||||||
| if "p" in self.record: | |||||||||||||||||||||||||
| self.sensor_data["p"] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=float) | |||||||||||||||||||||||||
| self.sensor_data["p"] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=self._dtype) | |||||||||||||||||||||||||
| for a in "xyz"[: self.ndim]: | |||||||||||||||||||||||||
| for suffix in ("", "_staggered"): | |||||||||||||||||||||||||
| v = f"u{a}{suffix}" | |||||||||||||||||||||||||
| if v in self.record: | |||||||||||||||||||||||||
| self.sensor_data[v] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=float) | |||||||||||||||||||||||||
| self.sensor_data[v] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=self._dtype) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid | |||||||||||||||||||||||||
| self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)] | |||||||||||||||||||||||||
|
Comment on lines
675
to
676
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The same Python-literal promotion issue as
Suggested change
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verified empirically on numpy 1.26.4: Additionally, the prior commit (22ec78d) already added Added a regression test (d545952) that records @greptile-apps re-review
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The empirical result doesn't match what NEP 50's own migration table documents. The NEP 50 spec explicitly shows the old (numpy < 2) vs new (numpy ≥ 2 / NEP 50) behaviour:
Source: https://numpy.org/neps/nep-0050-scalar-promotion.html Python complex literals are not weak types in numpy < 2 — they behave like It's possible the empirical test ran against a numpy 2.x environment, or that The fix remains: cast the list comprehension result to
Suggested change
Tip: You can give me more context with .greptile/rules.md so I get this right next time. |
|||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Initial pressure source (p0) | |||||||||||||||||||||||||
| p0_raw = getattr(self.source, "p0", 0) | |||||||||||||||||||||||||
| if _is_enabled(p0_raw): | |||||||||||||||||||||||||
| p0 = _expand_to_grid(p0_raw, self.grid_shape, xp, "p0") | |||||||||||||||||||||||||
| p0 = _expand_to_grid(p0_raw, self.grid_shape, xp, "p0", dtype=self._dtype) | |||||||||||||||||||||||||
| if self.smooth_p0 and self.ndim >= 2: | |||||||||||||||||||||||||
| from kwave.utils.filters import smooth | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # smooth() is order-agnostic (uses FFT on shape) | |||||||||||||||||||||||||
| p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True)) | |||||||||||||||||||||||||
| p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True), dtype=self._dtype) | |||||||||||||||||||||||||
| self._p0_initial = p0 | |||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||
| self._p0_initial = None | |||||||||||||||||||||||||
|
|
@@ -657,16 +699,16 @@ def step(self): | ||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Momentum equation: du_i/dt = -grad_i(p)/rho, with PML | |||||||||||||||||||||||||
| # Share forward FFT of p across all gradient axes | |||||||||||||||||||||||||
| P = xp.fft.fftn(self.p) | |||||||||||||||||||||||||
| P = xp.fft.fftn(self.p).astype(self._complex_dtype, copy=False) | |||||||||||||||||||||||||
| for i in range(self.ndim): | |||||||||||||||||||||||||
| pml_sg = self.pml_sg_list[i] | |||||||||||||||||||||||||
| grad_p_i = xp.real(xp.fft.ifftn(self.op_grad_list[i] * P)) | |||||||||||||||||||||||||
| grad_p_i = xp.real(xp.fft.ifftn(self.op_grad_list[i] * P)).astype(self._dtype, copy=False) | |||||||||||||||||||||||||
| self.u[i] = pml_sg * (pml_sg * self.u[i] - self.dt_over_rho0[i] * grad_p_i) | |||||||||||||||||||||||||
| self.u[i] = self._source_u_ops[i](self.t, self.u[i]) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Mass conservation: drho_i/dt = -rho0 * div_i(u_i) * nl_factor, with PML | |||||||||||||||||||||||||
| nl_factor = self._nl_factor(self.rho_split) | |||||||||||||||||||||||||
| div_u_total = xp.zeros(self.grid_shape, dtype=float) | |||||||||||||||||||||||||
| div_u_total = xp.zeros(self.grid_shape, dtype=self._dtype) | |||||||||||||||||||||||||
| for i in range(self.ndim): | |||||||||||||||||||||||||
| pml = self.pml_list[i] | |||||||||||||||||||||||||
| div_u_i = self._diff(self.u[i], self.op_div_list[i]) | |||||||||||||||||||||||||
|
|
@@ -675,8 +717,10 @@ def step(self): | ||||||||||||||||||||||||
| self.rho_split[i] = self._source_p_ops[i](self.t, self.rho_split[i]) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # Equation of state: p = c0^2 * (rho + absorption - dispersion + nonlinearity) | |||||||||||||||||||||||||
| rho_total = sum(self.rho_split) | |||||||||||||||||||||||||
| self.p = self.c0_sq * (rho_total + self._absorption(div_u_total) - self._dispersion(rho_total) + self._nonlinearity(rho_total)) | |||||||||||||||||||||||||
| rho_total = _array_sum(self.rho_split) | |||||||||||||||||||||||||
| self.p = self.c0_sq * ( | |||||||||||||||||||||||||
| rho_total + self._absorption(div_u_total) - self._dispersion(rho_total) + self._nonlinearity(rho_total) | |||||||||||||||||||||||||
| ) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| # At t=0, override equation of state with p0; set u(dt/2) for leapfrog. | |||||||||||||||||||||||||
| # MATLAB convention (kspaceFirstOrder2D.m line 920): u(dt/2) = +dt/(2*rho) * grad(p0) | |||||||||||||||||||||||||
|
|
@@ -695,7 +739,7 @@ def step(self): | ||||||||||||||||||||||||
| self.sensor_data["p"][:, file_index] = self._extract(self.p) | |||||||||||||||||||||||||
| for i, a in enumerate("xyz"[: self.ndim]): | |||||||||||||||||||||||||
| if f"u{a}" in self.sensor_data: # non-staggered (collocated with pressure) | |||||||||||||||||||||||||
| shifted = xp.real(xp.fft.ifftn(self.unstagger_ops[i] * xp.fft.fftn(self.u[i]))) | |||||||||||||||||||||||||
| shifted = xp.real(xp.fft.ifftn(self.unstagger_ops[i] * xp.fft.fftn(self.u[i]))).astype(self._dtype, copy=False) | |||||||||||||||||||||||||
| self.sensor_data[f"u{a}"][:, file_index] = self._extract(shifted) | |||||||||||||||||||||||||
| if f"u{a}_staggered" in self.sensor_data: # raw staggered grid | |||||||||||||||||||||||||
| self.sensor_data[f"u{a}_staggered"][:, file_index] = self._extract(self.u[i]) | |||||||||||||||||||||||||
|
|
@@ -749,7 +793,7 @@ def _diff(self, f, op): | ||||||||||||||||||||||||
| if op is None: | |||||||||||||||||||||||||
| return f | |||||||||||||||||||||||||
| xp = self.xp | |||||||||||||||||||||||||
| return xp.real(xp.fft.ifftn(op * xp.fft.fftn(f))) | |||||||||||||||||||||||||
| return xp.real(xp.fft.ifftn(op * xp.fft.fftn(f))).astype(self._dtype, copy=False) | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
| def _stagger(self, arr, axis): | |||||||||||||||||||||||||
| """Compute staggered grid values (average neighbors along axis).""" | |||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int 2literal re-introduces dtype promotion on numpy < 2_array_sumwas added to preventsum(rho_split)starting from Pythonint 0, but2 * _array_sum(rho_split)still multiplies by a Pythonint— which numpy < 2 (pre-NEP 50) treats asnp.int64(a strong type).np.result_type(np.int64, np.float32)→np.float64, sonl_factorisfloat64when BonA is enabled on numpy 1.x, and propagates throughrho_split[i]intoself.p→p_final. The sameint 2divisor in_nonlinearity(2 * self.rho0) causes the same promotion for the equation-of-state nonlinear term. On numpy >= 2 (NEP 50 weak scalars) the tests pass, but on numpy 1.26 they will silently produce float64p_finalinstead of float32.