From 8bd326291bcae6a33ee9f4fce83e11ee867a1164 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 14:56:16 +0100 Subject: [PATCH 1/8] Align scattering convolutions with FoCUS orientation handling --- STL_main/STL_2D_Kernel_Torch.py | 419 ++++++++++++++++++++++++++++---- 1 file changed, 366 insertions(+), 53 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 4e6dee8..7e9a333 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -20,6 +20,38 @@ import torch import torch.nn.functional as F +def _conv2d_same_symmetric(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + """ + 2D convolution with "same" output size and symmetric padding. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape [..., C, Nx, Ny]. + w : torch.Tensor + Kernel tensor of shape [O_c, C, wx, wy]. + + Returns + ------- + torch.Tensor + Convolved tensor with shape [..., O_c, Nx, Ny]. + """ + + *leading_dims, C, Nx, Ny = x.shape + O_c, _, wx, wy = w.shape + + B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1 + x4d = x.reshape(B, C, Nx, Ny) + + pad_x = wx // 2 + pad_y = wy // 2 + + x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="reflect") + y = F.conv2d(x_padded, w) + + return y.reshape(*leading_dims, O_c, Nx, Ny) + + def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """ Backend-style 2D convolution mirroring FoCUS/BkTorch strategy. @@ -27,9 +59,9 @@ def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: Parameters ---------- x : torch.Tensor - Input tensor of shape [..., Nx, Ny]. + Input tensor of shape [..., C, Nx, Ny]. w : torch.Tensor - Kernel tensor of shape [O_c, wx, wy]. + Kernel tensor of shape [O_c, C, wx, wy]. Returns ------- @@ -37,22 +69,38 @@ def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: Convolved tensor with shape [..., O_c, Nx, Ny]. """ - *leading_dims, Nx, Ny = x.shape - O_c, wx, wy = w.shape + *leading_dims, C, Nx, Ny = x.shape + O_c, _, wx, wy = w.shape B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1 - x4d = x.reshape(B, 1, Nx, Ny) + x4d = x.reshape(B, C, Nx, Ny) - weight = w[:, None, :, :] pad_x = wx // 2 pad_y = wy // 2 x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="circular") - y = F.conv2d(x_padded, weight) + y = F.conv2d(x_padded, w) return y.reshape(*leading_dims, O_c, Nx, Ny) +def _complex_conv2d_same_symmetric(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + """Complex-aware wrapper around ``_conv2d_same_symmetric``.""" + + xr = torch.real(x) if torch.is_complex(x) else x + xi = torch.imag(x) if torch.is_complex(x) else torch.zeros_like(xr) + + wr = torch.real(w) if torch.is_complex(w) else w + wi = torch.imag(w) if torch.is_complex(w) else torch.zeros_like(wr) + + real_part = _conv2d_same_symmetric(xr, wr) - _conv2d_same_symmetric(xi, wi) + imag_part = _conv2d_same_symmetric(xr, wi) + _conv2d_same_symmetric(xi, wr) + + if torch.is_complex(x) or torch.is_complex(w): + return torch.complex(real_part, imag_part) + return real_part + + def _complex_conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """Complex-aware wrapper around ``_conv2d_circular``.""" @@ -662,46 +710,306 @@ def __init__(self, kernel_size: int, L: int, J: int, self.J = J self.device = torch.device(device) self.dtype = dtype - - self.kernel = self._wavelet_kernel(kernel_size,L) + + self.kernel, real_kernel, imag_kernel, smooth_kernel = self._wavelet_kernel( + kernel_size, L + ) self.WType='simple' - def _wavelet_kernel(self,kernel_size: int,n_orientation: int,sigma=1): - """Create a 2D Wavelet kernel.""" - # coords = torch.arange(kernel_size, device=self.device, dtype=self.dtype) - (kernel_size - 1) / 2.0 - # yy, xx = torch.meshgrid(coords, coords, indexing="ij") - # mother_kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))[None,:,:] - # angles=torch.arange(n_orientation, device=self.device, dtype=self.dtype)/n_orientation*torch.pi - # angles_proj=torch.pi*(xx[None,...]*torch.cos(angles[:,None,None])+yy[None,...]*torch.sin(angles[:,None,None])) - # kernel = torch.complex(torch.cos(angles_proj)*mother_kernel,torch.sin(angles_proj)*mother_kernel) - # kernel = kernel - torch.mean(kernel,dim=(1,2))[:,None,None] - # kernel = kernel / torch.sqrt(torch.sum(kernel**2, dim=(1,2)))[:,None,None] - # return kernel.reshape(1,n_orientation,kernel_size,kernel_size) - - ###Morlay wavelet - coords = torch.arange(kernel_size, device=self.device, dtype=self.dtype) - (kernel_size - 1) / 2.0 - yy, xx = torch.meshgrid(coords, coords, indexing="ij") - - # Gaussian envelope - gaussian_envelope = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) - - # Orientations - angles = torch.arange(n_orientation, device=self.device, dtype=self.dtype) / n_orientation * torch.pi - - # Morlet wavelet: exp(i*k0*x_rot) * gaussian_envelope - # x_rot is the coordinate along the orientation direction - x_rot = xx[None, :, :] * torch.cos(angles[:, None, None]) + yy[None, :, :] * torch.sin(angles[:, None, None]) - - # Complex Morlet wavelet - kernel = torch.exp(1j * 0.75 * np.pi * x_rot ) * gaussian_envelope[None, :, :] - - # Remove DC component (admissibility condition) - kernel = kernel - torch.mean(kernel, dim=(1, 2))[:, None, None] - - # L2 normalization - kernel = kernel / torch.sqrt(torch.sum(torch.abs(kernel)**2, dim=(1, 2)))[:, None, None] - - return kernel.reshape(1, n_orientation, kernel_size, kernel_size) + def _wavelet_kernel(self, kernel_size: int, n_orientation: int): + """FoCUS CNNV1 planar wavelet construction (cos/sin over Gaussian).""" + + KERNELSZ = kernel_size + NORIENT = n_orientation + LAMBDA = 1.0 + + # Allocate real/imag components + wwc = np.zeros([NORIENT, KERNELSZ * KERNELSZ], dtype=np.float64) + wws = np.zeros_like(wwc) + + x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape( + KERNELSZ, KERNELSZ + ) + y = x.T + + if NORIENT == 1: + xx = (3.0 / float(KERNELSZ)) * LAMBDA * x + yy = (3.0 / float(KERNELSZ)) * LAMBDA * y + + if KERNELSZ == 5: + w_smooth = np.exp(-(xx**2 + yy**2)) + tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp( + -0.5 * (xx**2 + yy**2) + ) + else: + w_smooth = np.exp(-0.5 * (xx**2 + yy**2)) + tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp( + -0.5 * (xx**2 + yy**2) + ) + + wwc[0] = tmp.flatten() - tmp.mean() + wws[0] = np.zeros_like(wwc[0]) + sigma = np.sqrt((wwc[:, 0] ** 2).mean()) + wwc[0] /= sigma + wws[0] /= sigma + + w_smooth = w_smooth.flatten() + else: + for i in range(NORIENT): + a = (NORIENT - 1 - i) / float(NORIENT) * np.pi + if KERNELSZ < 5: + xx = (3.0 / float(KERNELSZ)) * LAMBDA * ( + x * np.cos(a) + y * np.sin(a) + ) + yy = (3.0 / float(KERNELSZ)) * LAMBDA * ( + x * np.sin(a) - y * np.cos(a) + ) + else: + xx = (3.0 / 5.0) * LAMBDA * (x * np.cos(a) + y * np.sin(a)) + yy = (3.0 / 5.0) * LAMBDA * (x * np.sin(a) - y * np.cos(a)) + + if KERNELSZ == 5: + w_smooth = np.exp(-2 * ((3.0 / float(KERNELSZ) * xx) ** 2 + (3.0 / float(KERNELSZ) * yy) ** 2)) + else: + w_smooth = np.exp(-0.5 * (xx**2 + yy**2)) + + tmp1 = np.cos(yy * np.pi) * w_smooth + tmp2 = np.sin(yy * np.pi) * w_smooth + + wwc[i] = tmp1.flatten() - tmp1.mean() + wws[i] = tmp2.flatten() - tmp2.mean() + sigma = np.mean(w_smooth) + wwc[i] /= sigma + wws[i] /= sigma + + w_smooth = w_smooth.flatten() + + w_smooth = w_smooth / w_smooth.sum() + + # Real/imaginary kernels for the primary convolution path (Cin=1) + real_kernel = torch.tensor( + wwc.reshape(NORIENT, 1, KERNELSZ, KERNELSZ), device=self.device, dtype=self.dtype + ) + imag_kernel = torch.tensor( + wws.reshape(NORIENT, 1, KERNELSZ, KERNELSZ), device=self.device, dtype=self.dtype + ) + + # Low-pass smoothing window (depthwise) + smooth_kernel = torch.tensor( + w_smooth.reshape(1, 1, KERNELSZ, KERNELSZ), device=self.device, dtype=self.dtype + ) + + # Orientation-expanded kernels for the second order (Cin=NORIENT, Cout=NORIENT*NORIENT) + def doorientw(x: np.ndarray) -> np.ndarray: + y = np.zeros( + [NORIENT * NORIENT, NORIENT, KERNELSZ, KERNELSZ], dtype=self.dtype + ) + for k in range(NORIENT): + start = k * NORIENT + y[start : start + NORIENT, k, :, :] = x.reshape(NORIENT, KERNELSZ, KERNELSZ) + return y + + orient_real = torch.tensor(doorientw(wwc), device=self.device, dtype=self.dtype) + orient_imag = torch.tensor(doorientw(wws), device=self.device, dtype=self.dtype) + + # Complex kernel packed for convenience + kernel = torch.complex(real_kernel, imag_kernel) + + # Keep both first-order and oriented kernels + self.ww_RealT = [None, real_kernel, orient_real] + self.ww_ImagT = [None, imag_kernel, orient_imag] + self.ww_SmoothT = [None, smooth_kernel] + + return kernel, real_kernel, imag_kernel, smooth_kernel + + def _bk_resize_image(self, im: torch.Tensor, noutx: int, nouty: int) -> torch.Tensor: + """Torch bilinear resize mirroring FoCUS.backend.bk_resize_image.""" + *leading, hx, hy = im.shape + flat = im.reshape(-1, 1, hx, hy) + resized = F.interpolate(flat, size=(noutx, nouty), mode="bilinear", align_corners=False) + return resized.reshape(*leading, noutx, nouty) + + def up_grade(self, im: torch.Tensor, nout: int, axis: int = -1, nouty: int = None) -> torch.Tensor: + if nouty is None: + nouty = nout + return self._bk_resize_image(im, nout, nouty) + + def convol(self, in_image: torch.Tensor, use_oriented: bool = False) -> torch.Tensor: + """FoCUS-like convolution with symmetric padding and complex kernels.""" + + image = in_image.to(dtype=self.kernel.dtype, device=self.kernel.device) + ishape = list(image.shape) + if len(ishape) < 2: + raise ValueError("Use of 2D scat with data that has less than 2D") + + # Ensure channel dimension is present + if image.dim() == 2: + image = image.unsqueeze(0).unsqueeze(0) + elif image.dim() == 3: + image = image.unsqueeze(1) + + *leading, C, npix, npiy = image.shape + ndata = int(np.prod(leading)) if leading else 1 + tim = image.reshape(ndata, C, npix, npiy) + + kernel_r = self.ww_RealT[2] if use_oriented else self.ww_RealT[1] + kernel_i = self.ww_ImagT[2] if use_oriented else self.ww_ImagT[1] + + if torch.is_complex(tim): + rr1 = _conv2d_same_symmetric(torch.real(tim), kernel_r) + ii1 = _conv2d_same_symmetric(torch.real(tim), kernel_i) + rr2 = _conv2d_same_symmetric(torch.imag(tim), kernel_r) + ii2 = _conv2d_same_symmetric(torch.imag(tim), kernel_i) + res = torch.complex(rr1 - ii2, ii1 + rr2) + else: + rr = _conv2d_same_symmetric(tim, kernel_r) + ii = _conv2d_same_symmetric(tim, kernel_i) + res = torch.complex(rr, ii) + + return res.reshape(*leading, kernel_r.shape[0], npix, npiy) + + def smooth(self, in_image: torch.Tensor) -> torch.Tensor: + image = in_image.to(dtype=self.kernel.dtype, device=self.kernel.device) + ishape = list(image.shape) + if len(ishape) < 2: + raise ValueError("Use of 2D scat with data that has less than 2D") + + npix = ishape[-2] + npiy = ishape[-1] + ndata = int(np.prod(ishape[:-2])) if len(ishape) > 2 else 1 + + tim = image.reshape(ndata, 1, npix, npiy) + + if torch.is_complex(tim): + rr = _conv2d_same_symmetric(torch.real(tim), self.ww_SmoothT[1]) + ii = _conv2d_same_symmetric(torch.imag(tim), self.ww_SmoothT[1]) + res = torch.complex(rr, ii) + else: + res = _conv2d_same_symmetric(tim, self.ww_SmoothT[1]) + + return res.reshape(*ishape[:-2], npix, npiy) + + def ud_grade_2(self, im: torch.Tensor) -> torch.Tensor: + ishape = list(im.shape) + if len(ishape) < 2: + raise ValueError("Use of 2D scat with data that has less than 2D") + npix = ishape[-2] + npiy = ishape[-1] + if npix % 2 != 0 or npiy % 2 != 0: + raise ValueError("Downsampling requires even spatial dimensions") + + ndata = 1 + for k in range(len(im.shape) - 2): + ndata *= ishape[k] + + tim = im.reshape(ndata, npix, npiy, 1).permute(0, 3, 1, 2) + res = F.avg_pool2d(tim, kernel_size=2, stride=2) + res = res.permute(0, 2, 3, 1) + return res.reshape(ishape[0:-2] + [npix // 2, npiy // 2]) + + def scattering(self, image1: torch.Tensor): + """ + Compute scattering coefficients (S0, S1, S2, S2L) following FoCUS CNNV1 + planar backend. Masking and normalization are intentionally omitted. + """ + + if image1.dim() == 2: + I1 = image1.unsqueeze(0) + else: + I1 = image1 + + # Add explicit channel dimension for convolutions + if I1.dim() == 3: + I1 = I1.unsqueeze(1) + + im_shape = I1.shape + nside = min(im_shape[-2], im_shape[-1]) + jmax = int(math.log(nside - self.KERNELSZ) / math.log(2)) + + if self.KERNELSZ > 3: + if self.KERNELSZ == 5: + l_image1 = self.up_grade(I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2) + else: + l_image1 = self.up_grade(I1, I1.shape[-2] * 4, nouty=I1.shape[-1] * 4) + else: + l_image1 = I1 + + s0 = l_image1.mean(dim=(-2, -1), keepdim=False) + p00 = None + s1 = None + s2 = None + s2l = None + l2_image = None + s2j1 = [] + s2j2 = [] + + for j1 in range(jmax): + c_image1 = self.convol(l_image1) + + conj = c_image1 * torch.conj(c_image1) + l_p00 = conj.mean(dim=(-2, -1)).unsqueeze(-2) + conj_mod = torch.abs(conj) + l_s1 = conj_mod.mean(dim=(-2, -1)).unsqueeze(-2) + + if s1 is None: + s1 = l_s1 + p00 = l_p00 + else: + s1 = torch.cat([s1, l_s1], dim=-2) + p00 = torch.cat([p00, l_p00], dim=-2) + + if l2_image is None: + l2_image = conj_mod.unsqueeze(1) + else: + l2_image = torch.cat([l2_image, conj_mod.unsqueeze(1)], dim=1) + + # Positive path + l2_pos = F.relu(l2_image) + pos_flat = l2_pos.reshape(-1, self.L, l2_pos.shape[-2], l2_pos.shape[-1]) + c2_image = self.convol(pos_flat, use_oriented=True).reshape( + *l2_pos.shape[:-3], self.L * self.L, l2_pos.shape[-2], l2_pos.shape[-1] + ) + conj2p = c2_image * torch.conj(c2_image) + conj2pl1 = torch.abs(conj2p) + + # Negative path + l2_neg = F.relu(-l2_image) + neg_flat = l2_neg.reshape(-1, self.L, l2_neg.shape[-2], l2_neg.shape[-1]) + c2_image_m = self.convol(neg_flat, use_oriented=True).reshape( + *l2_neg.shape[:-3], self.L * self.L, l2_neg.shape[-2], l2_neg.shape[-1] + ) + conj2m = c2_image_m * torch.conj(c2_image_m) + conj2ml1 = torch.abs(conj2m) + + l_s2 = (conj2p - conj2m).mean(dim=(-2, -1)) + l_s2l1 = (conj2pl1 - conj2ml1).mean(dim=(-2, -1)) + + if s2 is None: + s2l = l_s2 + s2 = l_s2l1 + s2j1 = list(range(l_s2.shape[-3])) + s2j2 = [j1] * l_s2.shape[-3] + else: + s2 = torch.cat([s2, l_s2l1], dim=-3) + s2l = torch.cat([s2l, l_s2], dim=-3) + s2j1.extend(list(range(l_s2.shape[-3]))) + s2j2.extend([j1] * l_s2.shape[-3]) + + if j1 != jmax - 1: + l2_image = self.smooth(l2_image) + l2_image = self.ud_grade_2(l2_image) + l_image1 = self.smooth(l_image1) + l_image1 = self.ud_grade_2(l_image1) + + return { + "s0": s0, + "p00": p00, + "s1": s1, + "s2": s2, + "s2l": s2l, + "s2j1": torch.as_tensor(s2j1, device=s0.device), + "s2j2": torch.as_tensor(s2j2, device=s0.device), + } def get_L(self): return self.L @@ -730,9 +1038,14 @@ def apply(self, data,j): # Ensure x is a torch tensor on the same device / dtype as the kernel x = torch.as_tensor(x, device=self.kernel.device, dtype=self.kernel.dtype) - weight = self.kernel.squeeze(0) # [L, K, K] + if x.dim() == 2: + x = x.unsqueeze(0).unsqueeze(0) + elif x.dim() == 3: + x = x.unsqueeze(1) + + weight = self.kernel # [L, 1, K, K] - convolved = _complex_conv2d_circular(x, weight) + convolved = _complex_conv2d_same_symmetric(x, weight) return STL_2D_Kernel_Torch(convolved,smooth_kernel=data.smooth_kernel,dg=data.dg,N0=data.N0) @@ -756,24 +1069,24 @@ def apply_smooth(self, data: STL_2D_Kernel_Torch, inplace: bool = False): Smoothed data object with same shape as input (no extra L dimension). """ x = data.array # [..., K]= - *leading, K1,K2 = x.shape + *leading, K1, K2 = x.shape # Flatten leading dims into batch dimension: (B, Ci=1, K) if leading: B = int(np.prod(leading)) else: B = 1 - x_bc = x.reshape(B, 1, K1,K2) + x_bc = x.reshape(B, 1, K1, K2) # Smooth kernel (Ci=1, Co=1, P) - w_smooth = self.kernel.abs()[0,0:1].to(device=data.device, dtype=data.dtype) + w_smooth = self.ww_SmoothT[1].to(device=data.device, dtype=data.dtype) + + y_bc = _conv2d_circular(x_bc, w_smooth) - y_bc = _conv2d_circular(x, w_smooth) - if not isinstance(y_bc, torch.Tensor): y_bc = torch.as_tensor(y_bc, device=data.device, dtype=data.dtype) - y = y_bc.reshape(*leading, K1,K2) # same shape as input x + y = y_bc.reshape(*leading, K1, K2) # same shape as input x # Copy or in-place update out = data.copy(empty=True) if not inplace else data From 6c9fc5a7b6e04c728868ea71f4205613134a4b92 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:07:51 +0100 Subject: [PATCH 2/8] Fix numpy dtype selection for wavelet kernel construction --- STL_main/STL_2D_Kernel_Torch.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 7e9a333..55310b3 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -723,8 +723,20 @@ def _wavelet_kernel(self, kernel_size: int, n_orientation: int): NORIENT = n_orientation LAMBDA = 1.0 - # Allocate real/imag components - wwc = np.zeros([NORIENT, KERNELSZ * KERNELSZ], dtype=np.float64) + # Allocate real/imag components using a numpy dtype compatible with the torch dtype + if self.dtype in (torch.float64, torch.complex128): + np_dtype = np.float64 + elif self.dtype in (torch.float32, torch.complex64): + np_dtype = np.float32 + elif self.dtype == torch.float16: + np_dtype = np.float16 + elif self.dtype == torch.bfloat16: + # numpy has limited bfloat16 support; fall back to float32 for kernel construction + np_dtype = np.float32 + else: + np_dtype = np.float32 + + wwc = np.zeros([NORIENT, KERNELSZ * KERNELSZ], dtype=np_dtype) wws = np.zeros_like(wwc) x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape( @@ -802,7 +814,7 @@ def _wavelet_kernel(self, kernel_size: int, n_orientation: int): # Orientation-expanded kernels for the second order (Cin=NORIENT, Cout=NORIENT*NORIENT) def doorientw(x: np.ndarray) -> np.ndarray: y = np.zeros( - [NORIENT * NORIENT, NORIENT, KERNELSZ, KERNELSZ], dtype=self.dtype + [NORIENT * NORIENT, NORIENT, KERNELSZ, KERNELSZ], dtype=np_dtype ) for k in range(NORIENT): start = k * NORIENT From a9aec5691e5fe9c27c51564d8697080c8ba04790 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:12:13 +0100 Subject: [PATCH 3/8] Handle multi-channel inputs in symmetric conv --- STL_main/STL_2D_Kernel_Torch.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 55310b3..6974265 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -38,7 +38,7 @@ def _conv2d_same_symmetric(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """ *leading_dims, C, Nx, Ny = x.shape - O_c, _, wx, wy = w.shape + O_c, Cw, wx, wy = w.shape B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1 x4d = x.reshape(B, C, Nx, Ny) @@ -46,6 +46,13 @@ def _conv2d_same_symmetric(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: pad_x = wx // 2 pad_y = wy // 2 + # If the kernel expects a single channel but the input has multiple + # channels (e.g. orientation stacks), broadcast the kernel across the + # channel dimension to mirror FoCUS behavior of applying the same filter + # to each input channel before summing. + if Cw == 1 and C > 1: + w = w.repeat(1, C, 1, 1) + x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="reflect") y = F.conv2d(x_padded, w) @@ -70,7 +77,7 @@ def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """ *leading_dims, C, Nx, Ny = x.shape - O_c, _, wx, wy = w.shape + O_c, Cw, wx, wy = w.shape B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1 x4d = x.reshape(B, C, Nx, Ny) @@ -78,6 +85,9 @@ def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: pad_x = wx // 2 pad_y = wy // 2 + if Cw == 1 and C > 1: + w = w.repeat(1, C, 1, 1) + x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="circular") y = F.conv2d(x_padded, w) From e206698b2de7c7f8adcb7cf7e511c46ae1affcc0 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:16:27 +0100 Subject: [PATCH 4/8] Use oriented kernels when convolving orientation stacks --- STL_main/STL_2D_Kernel_Torch.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 6974265..c0c0964 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -834,6 +834,8 @@ def doorientw(x: np.ndarray) -> np.ndarray: orient_real = torch.tensor(doorientw(wwc), device=self.device, dtype=self.dtype) orient_imag = torch.tensor(doorientw(wws), device=self.device, dtype=self.dtype) + oriented_kernel = torch.complex(orient_real, orient_imag) + # Complex kernel packed for convenience kernel = torch.complex(real_kernel, imag_kernel) @@ -842,6 +844,8 @@ def doorientw(x: np.ndarray) -> np.ndarray: self.ww_ImagT = [None, imag_kernel, orient_imag] self.ww_SmoothT = [None, smooth_kernel] + self.oriented_kernel = oriented_kernel + return kernel, real_kernel, imag_kernel, smooth_kernel def _bk_resize_image(self, im: torch.Tensor, noutx: int, nouty: int) -> torch.Tensor: @@ -1065,7 +1069,10 @@ def apply(self, data,j): elif x.dim() == 3: x = x.unsqueeze(1) - weight = self.kernel # [L, 1, K, K] + if x.shape[-3] == 1: + weight = self.kernel # [L, 1, K, K] + else: + weight = self.oriented_kernel # [L^2, L, K, K] to handle orientation stacks convolved = _complex_conv2d_same_symmetric(x, weight) From 2ebbcabed8405f3e334ea3407f1a3b0cea870da5 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:22:34 +0100 Subject: [PATCH 5/8] Handle depthwise planar convolutions --- STL_main/STL_2D_Kernel_Torch.py | 39 +++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index c0c0964..95eaf2e 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -46,15 +46,21 @@ def _conv2d_same_symmetric(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: pad_x = wx // 2 pad_y = wy // 2 - # If the kernel expects a single channel but the input has multiple - # channels (e.g. orientation stacks), broadcast the kernel across the - # channel dimension to mirror FoCUS behavior of applying the same filter - # to each input channel before summing. - if Cw == 1 and C > 1: - w = w.repeat(1, C, 1, 1) + # Determine grouping strategy: if the input channel count matches the + # number of output channels and the kernel is single-channel, use + # depthwise convolution to keep channels independent (orientation-wise + # filtering). Otherwise fall back to standard grouped convolution with a + # broadcasted kernel when needed. + if Cw == 1 and O_c == C: + groups = C + w = w.expand(C, 1, wx, wy).contiguous() + else: + groups = 1 + if Cw == 1 and C > 1: + w = w.repeat(1, C, 1, 1) x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="reflect") - y = F.conv2d(x_padded, w) + y = F.conv2d(x_padded, w, groups=groups) return y.reshape(*leading_dims, O_c, Nx, Ny) @@ -85,11 +91,16 @@ def _conv2d_circular(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: pad_x = wx // 2 pad_y = wy // 2 - if Cw == 1 and C > 1: - w = w.repeat(1, C, 1, 1) + if Cw == 1 and O_c == C: + groups = C + w = w.expand(C, 1, wx, wy).contiguous() + else: + groups = 1 + if Cw == 1 and C > 1: + w = w.repeat(1, C, 1, 1) x_padded = F.pad(x4d, (pad_y, pad_y, pad_x, pad_x), mode="circular") - y = F.conv2d(x_padded, w) + y = F.conv2d(x_padded, w, groups=groups) return y.reshape(*leading_dims, O_c, Nx, Ny) @@ -1069,10 +1080,10 @@ def apply(self, data,j): elif x.dim() == 3: x = x.unsqueeze(1) - if x.shape[-3] == 1: - weight = self.kernel # [L, 1, K, K] - else: - weight = self.oriented_kernel # [L^2, L, K, K] to handle orientation stacks + # For the operator pathway, always use the base kernel; depthwise + # grouping inside the convolution keeps per-orientation channels + # independent when more than one channel is present. + weight = self.kernel # [L, 1, K, K] convolved = _complex_conv2d_same_symmetric(x, weight) From 812588f239fd38bc55a6dae811ae7cb642ea2a60 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:43:22 +0100 Subject: [PATCH 6/8] Handle orientation-aware covariance shapes --- STL_main/STL_2D_Kernel_Torch.py | 59 ++++++++++++++++++++++----------- STL_main/ST_Operator.py | 14 ++++---- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 95eaf2e..5bcfe0f 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -665,8 +665,13 @@ def mean(self, square=False, mask_MR=None): ########################################################################### def cov(self, data2=None, mask_MR=None, remove_mean=False): """ - Compute the covariance between data1=self and data2 on the last two - dimensions (Nx, Ny). + Covariance on the spatial dimensions while preserving orientation axes. + + The input arrays are expected to have spatial dimensions as the last + two axes. If an orientation/channel axis exists, it is assumed to be at + ``-3``; if not present, a singleton axis is inserted so that the output + keeps explicit orientation indices. This mirrors the FoCUS behavior + where covariances are computed per-orientation pair. Only works when MR == False. """ @@ -685,31 +690,47 @@ def cov(self, data2=None, mask_MR=None, remove_mean=False): raise ValueError("data2 must have the same dg as self.") y = data2.array - dims = (-2, -1) + # Ensure an explicit orientation axis just before the spatial axes + def _ensure_orient(t: torch.Tensor) -> torch.Tensor: + if t.dim() < 2: + raise ValueError("Inputs to cov must have at least 2 spatial dims.") + if t.dim() == 2: # [Nx, Ny] + return t.unsqueeze(0) + if t.dim() == 3: # [..., Nx, Ny] with no orientation + return t.unsqueeze(-3) + return t + + x_o = _ensure_orient(x) + y_o = _ensure_orient(y) + + spatial_dims = (-2, -1) if mask_MR is not None: mask = self._get_mask_at_dg(mask_MR, self.dg) + mask = _ensure_orient(mask) if remove_mean: - mx = (x * mask).mean(dim=dims, keepdim=True) - my = (y * mask).mean(dim=dims, keepdim=True) - x_c = x - mx - y_c = y - my + mx = (x_o * mask).mean(dim=spatial_dims, keepdim=True) + my = (y_o * mask).mean(dim=spatial_dims, keepdim=True) + x_c = x_o - mx + y_c = y_o - my else: - x_c = x - y_c = y - cov = (x_c * y_c.conj() * mask).mean(dim=dims) + x_c = x_o + y_c = y_o + prod = x_c.unsqueeze(-3) * y_c.conj().unsqueeze(-4) * mask.unsqueeze(-3) else: if remove_mean: - mx = x.mean(dim=dims, keepdim=True) - my = y.mean(dim=dims, keepdim=True) - x_c = x - mx - y_c = y - my + mx = x_o.mean(dim=spatial_dims, keepdim=True) + my = y_o.mean(dim=spatial_dims, keepdim=True) + x_c = x_o - mx + y_c = y_o - my else: - x_c = x - y_c = y - cov = (x_c * y_c.conj()).mean(dim=dims) - - return cov + x_c = x_o + y_c = y_o + prod = x_c.unsqueeze(-3) * y_c.conj().unsqueeze(-4) + + cov = prod.mean(dim=spatial_dims) + + return cov def get_wavelet_op(self, J=None, L=None, kernel_size=None): if L is None: diff --git a/STL_main/ST_Operator.py b/STL_main/ST_Operator.py index afdb417..abc953d 100644 --- a/STL_main/ST_Operator.py +++ b/STL_main/ST_Operator.py @@ -376,18 +376,18 @@ def apply(self, data, data_l1m_l2m={} for j2 in range(j3+1): data_l1m_l2 = self.wavelet_op.apply(data_l1m[j2],j=j3) # (Nb,Nc,L2,L3,N3) - # S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3) - data_st.S3[:,:,j2,j3,:,:] = data_l1m_l2[:,:,None].cov( #(Nb,Nc, 1,L3,N3) - data_l1 #(Nb,Nc, 1,L3,N3) - ) + # S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3) + s3_cov = data_l1m_l2.cov(data_l1) #(Nb,Nc,L2,L3) + data_st.S3[:,:,j2,j3,:,:] = s3_cov data_l1m_l2m[j2] = data_l1m_l2.modulus(inplace=False) #(Nb,Nc,L,N3) for j1 in range(j2+1): # S4(j1,j2,j3) = Cov(|I*psi1|*psi3, |I*psi2|*psi3) - data_st.S4[:,:,j1,j2,j3,:,:,:] = data_l1m_l2m[j1][:,:,None,:].cov( - data_l1m_l2m[j2][:,:,:,None] - ) + s4_cov = data_l1m_l2m[j1].cov(data_l1m_l2m[j2]) #(Nb,Nc,L1,L2) + data_st.S4[:,:,j1,j2,j3,:,:,:] = s4_cov.unsqueeze(-3).expand( + -1, -1, self.wavelet_op.L, -1, -1 + ) if data_st.DT != '2D_FFT_Torch': # Downsample at Nj3 From 8c9b85d4562eaa44fd4eeae070bca53373ecec95 Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:49:02 +0100 Subject: [PATCH 7/8] Fix orientation pairing in planar covariance --- STL_main/STL_2D_Kernel_Torch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 5bcfe0f..1068ee3 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -702,6 +702,7 @@ def _ensure_orient(t: torch.Tensor) -> torch.Tensor: x_o = _ensure_orient(x) y_o = _ensure_orient(y) + orient_dim = x_o.dim() - 3 # index of the orientation axis spatial_dims = (-2, -1) @@ -716,7 +717,11 @@ def _ensure_orient(t: torch.Tensor) -> torch.Tensor: else: x_c = x_o y_c = y_o - prod = x_c.unsqueeze(-3) * y_c.conj().unsqueeze(-4) * mask.unsqueeze(-3) + prod = ( + x_c.unsqueeze(orient_dim) + * y_c.conj().unsqueeze(orient_dim + 1) + * mask.unsqueeze(orient_dim + 1) + ) else: if remove_mean: mx = x_o.mean(dim=spatial_dims, keepdim=True) @@ -726,7 +731,7 @@ def _ensure_orient(t: torch.Tensor) -> torch.Tensor: else: x_c = x_o y_c = y_o - prod = x_c.unsqueeze(-3) * y_c.conj().unsqueeze(-4) + prod = x_c.unsqueeze(orient_dim) * y_c.conj().unsqueeze(orient_dim + 1) cov = prod.mean(dim=spatial_dims) From f68087573c29948c542fe9a188e8b141ad93e68e Mon Sep 17 00:00:00 2001 From: Paolo Campeti <36236108+pcampeti@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:53:06 +0100 Subject: [PATCH 8/8] Pad covariance inputs with batch and channel dims --- STL_main/STL_2D_Kernel_Torch.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/STL_main/STL_2D_Kernel_Torch.py b/STL_main/STL_2D_Kernel_Torch.py index 1068ee3..fcadf59 100644 --- a/STL_main/STL_2D_Kernel_Torch.py +++ b/STL_main/STL_2D_Kernel_Torch.py @@ -702,6 +702,17 @@ def _ensure_orient(t: torch.Tensor) -> torch.Tensor: x_o = _ensure_orient(x) y_o = _ensure_orient(y) + + # Ensure there are at least two leading dims (batch, channel) + # ahead of the orientation axis so downstream code can rely on + # a consistent [Nb, Nc, L, Nx, Ny] layout even when the inputs + # were provided without explicit batch/channel axes. + while x_o.dim() < 5: + x_o = x_o.unsqueeze(0) + y_o = y_o.unsqueeze(0) + if mask_MR is not None: + mask_MR = [m.unsqueeze(0) for m in mask_MR] if isinstance(mask_MR, list) else mask_MR.unsqueeze(0) + orient_dim = x_o.dim() - 3 # index of the orientation axis spatial_dims = (-2, -1)