From f108522d2a5ba3133ca508217b1a5efb66c43590 Mon Sep 17 00:00:00 2001 From: tjc0726 Date: Tue, 24 Feb 2026 00:09:26 +0000 Subject: [PATCH] Optimize build_cartesian_ncc_matrix --- dedalus/core/arithmetic.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 64daa530..07539e8b 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -430,18 +430,31 @@ def build_cartesian_ncc_matrix(self, subproblem, ncc_cutoff, max_ncc_terms): Gamma = Gamma.transpose((2, 0, 1)) # Loop over NCC modes shape = (subproblem.field_size(out), subproblem.field_size(arg)) - matrix = sparse.csr_matrix(shape, dtype=self.dtype) subproblem_shape = subproblem.coeff_shape(out.domain) ncc_rank = len(ncc.tensorsig) select_all_comps = tuple(slice(None) for i in range(ncc_rank)) + # Optimization: batch accumulate matrices instead of sequential addition + all_rows = [] + all_cols = [] + all_data = [] if np.any(self._ncc_data): for ncc_mode in np.ndindex(self._ncc_data.shape[ncc_rank:]): ncc_coeffs = self._ncc_data[select_all_comps + ncc_mode] if np.max(np.abs(ncc_coeffs)) > ncc_cutoff: mode_matrix = self.cartesian_mode_matrix(subproblem_shape, ncc.domain, arg.domain, out.domain, ncc_mode) - mode_matrix = sparse.kron(np.dot(Gamma, ncc_coeffs.ravel()), mode_matrix, format='csr') - matrix = matrix + mode_matrix - return matrix + mode_matrix = sparse.kron(np.dot(Gamma, ncc_coeffs.ravel()), mode_matrix, format='coo') + all_rows.append(mode_matrix.row) + all_cols.append(mode_matrix.col) + all_data.append(mode_matrix.data) + # Batch merge all mode matrices + if all_rows: + combined_row = np.concatenate(all_rows) + combined_col = np.concatenate(all_cols) + combined_data = np.concatenate(all_data) + matrix = sparse.coo_matrix((combined_data, (combined_row, combined_col)), shape=shape) + return matrix.tocsr() + else: + return sparse.csr_matrix(shape, dtype=self.dtype) @classmethod def cartesian_mode_matrix(cls, subproblem_shape, ncc_domain, arg_domain, out_domain, ncc_mode):