diff --git a/TPTBox/core/bids_constants.py b/TPTBox/core/bids_constants.py index e083035..f726074 100755 --- a/TPTBox/core/bids_constants.py +++ b/TPTBox/core/bids_constants.py @@ -64,6 +64,7 @@ "FLASH", "VF", "defacemas", + "fluroscopy", "dw", "TB1TFL", "TB1RFM", diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index ffd1d6a..defce2b 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -632,16 +632,20 @@ def rename_files(self, path: Path | str, ending=".nii.gz"): p = Path(path + "." + key) value.rename(p) - def symlink_files(self, path: Path | str, ending=".nii.gz"): + def symlink_files(self, path: Path | str, ending=".nii.gz", exist_ok=False): ending = ending if ending[0] == "." else "." + ending path = str(path) assert path.endswith(ending), f"set 'ending' to the part after the '.'\n {path} does not end with {ending}" path = path.replace(ending, "") for key, value in self.file.items(): p = Path(path + "." + key) + if os.path.islink(p): assert Path(os.readlink(p)) == value, f"{p} exists" continue + if exist_ok and p.exists(): + continue + os.symlink(value, p) def get_path_decomposed(self, file_type=None) -> tuple[Path, str, str, str]: diff --git a/TPTBox/core/dicom/dicom2nii_utils.py b/TPTBox/core/dicom/dicom2nii_utils.py index 44e0f28..a6c724c 100755 --- a/TPTBox/core/dicom/dicom2nii_utils.py +++ b/TPTBox/core/dicom/dicom2nii_utils.py @@ -239,6 +239,8 @@ def test_name_conflict(json_ob, file): if Path(file).exists(): with open(file) as f: js = json.load(f) + if "grid" in js: + del js["grid"] return js != json_ob return False diff --git a/TPTBox/core/dicom/dicom_extract.py b/TPTBox/core/dicom/dicom_extract.py index acf7d8a..0d63226 100644 --- a/TPTBox/core/dicom/dicom_extract.py +++ b/TPTBox/core/dicom/dicom_extract.py @@ -88,8 +88,42 @@ def _generate_bids_path( return fname.file["json"], fname +def dicom_to_nifti_multiframe_2d(ds, nii_path, pixel_array): + if hasattr(ds, "PixelSpacing"): + dy, dx = map(float, 
ds.PixelSpacing) + affine = np.eye(4) + + if hasattr(ds, "ImageOrientationPatient"): + orientation = list(map(float, ds.ImageOrientationPatient)) + row_cosines = np.array(orientation[:3]) + col_cosines = np.array(orientation[3:]) + affine[:3, 0] = row_cosines * dx + affine[:3, 1] = col_cosines * dy + else: + affine[0, 0] = dx + affine[1, 1] = dy + + if hasattr(ds, "ImagePositionPatient"): + affine[:3, 3] = np.array(list(map(float, ds.ImagePositionPatient))) + + elif hasattr(ds, "ImagerPixelSpacing"): + dy, dx = map(float, ds.ImagerPixelSpacing) + affine = np.diag([-dx, -dy, 1, 1]) + + else: + affine = np.eye(4) + + nii = nib.Nifti1Image(pixel_array.T[:, :, None], affine) + logger.on_log("Save 2D", nii_path) + nib.save(nii, nii_path) + return nii_path + + def dicom_to_nifti_multiframe(ds, nii_path): pixel_array = ds.pixel_array + if len(pixel_array.shape) == 2: + return dicom_to_nifti_multiframe_2d(ds, nii_path, pixel_array) + if len(pixel_array.shape) != 3 and len(pixel_array.shape) != 4: raise ValueError(f"Expected a shape with 3 colums not {len(pixel_array.shape)}; {pixel_array.shape=}") n_frames = pixel_array.shape[0] @@ -265,7 +299,22 @@ def _from_dicom_to_nii( override_subject_name: Callable[[dict, Path], str] | None = None, chunk=None, skip_localizer=False, + parent="rawdata", + censor_list=None, ): + if censor_list is None: + censor_list = [ + "StudyDate", + "SeriesDate", + "AcquisitionDate", + "ContentDate", + "StudyTime", + "SeriesTime", + "AcquisitionTime", + "ContentTime", + "InstanceCreationDate", + "InstanceCreationTime", + ] if chunk is None: splitted_dcm_data_l = _classic_get_grouped_dicoms(dcm_data_l) if len(splitted_dcm_data_l) != 1: @@ -282,6 +331,7 @@ def _from_dicom_to_nii( override_subject_name=override_subject_name, chunk=i, skip_localizer=skip_localizer, + parent=parent, ) outs.append(o) return outs @@ -291,6 +341,9 @@ def _from_dicom_to_nii( return None simp_json = get_json_from_dicom(dcm_data_l) + for censor_key in censor_list: + if 
censor_key in simp_json: + del simp_json[censor_key] json_file_name, json_bids, nii_path = _get_paths( simp_json, dcm_data_l, @@ -301,11 +354,13 @@ def _from_dicom_to_nii( map_series_description_to_file_format, override_subject_name, chunk=chunk, + parent=parent, ) if skip_localizer and json_bids.bids_format == "localizer": return logger.print(json_file_name, Log_Type.NEUTRAL, verbose=verbose) - exist = save_json(simp_json, json_file_name) + exist = save_json(simp_json, json_file_name, override=False) + # logger.on_debug(exist, Path(nii_path).exists(), nii_path) if exist and Path(nii_path).exists(): logger.print("already exists:", json_file_name, ltype=Log_Type.STRANGE, verbose=verbose) return nii_path @@ -520,6 +575,8 @@ def extract_dicom_folder( n_cpu: int | None = 1, override_subject_name: Callable[[dict, Path], str] | None = None, skip_localizer=True, + parent="rawdata", + censor_list: list | None = None, ): """ Extract DICOM files from a directory or list of directories, convert them to NIfTI format, and store the output. @@ -537,6 +594,8 @@ def extract_dicom_folder( Returns: dict: A dictionary with keys representing DICOM series and values as paths to the generated NIfTI files. 
""" + if censor_list is None: + censor_list = [] if not validate_slicecount: convert_dicom.settings.disable_validate_slicecount() if not validate_orientation: @@ -576,6 +635,8 @@ def process_series(key, files, parts): map_series_description_to_file_format=map_series_description_to_file_format, override_subject_name=override_subject_name, skip_localizer=skip_localizer, + parent=parent, + censor_list=censor_list, ) # Process in parallel or sequentially based on n_cpu @@ -606,8 +667,10 @@ def process_series(key, files, parts): if __name__ == "__main__": - for p in Path("/DATA/NAS/datasets_source/brain/dsa").iterdir(): - extract_dicom_folder(p, Path("/DATA/NAS/datasets_source/brain/", "dataset-DSA"), False, False, validate_slice_increment=False) + for p in Path("/media/robert/STORE N GO/DSA_Daten/").iterdir(): + extract_dicom_folder( + p, Path("/media/data/robert/datasets", "dataset-Durchleuchtung222"), False, False, validate_slice_increment=False + ) sys.exit() # s = "/home/robert/Downloads/bein/dataset-oberschenkel/rawdata/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497/mr/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497_sequ-406_mr.nii.gz" diff --git a/TPTBox/core/dicom/dicom_header_to_keys.py b/TPTBox/core/dicom/dicom_header_to_keys.py index b2ce879..c9aabff 100644 --- a/TPTBox/core/dicom/dicom_header_to_keys.py +++ b/TPTBox/core/dicom/dicom_header_to_keys.py @@ -46,6 +46,8 @@ "t2w?_fse.*": "T2w", ".*t1w?_tse.*": "T1w", ".*t1w?_vibe_tra.*": "vibe", + ".*Durchleuchtung.*": "fluroscopy", + ".*fluroscopy.*": "fluroscopy", ".*scout": "localizer", "localizer": "localizer", ".*pilot.*": "localizer", @@ -267,21 +269,33 @@ def _get(key, default=None): if modality == "ct": mri_format = "ct" elif modality == "xa": # Angiography + biplane = False if "BIPLANE A" in image_type or "SINGLE A" in image_type: keys["acq"] = "A" + biplane = True elif "BIPLANE B" in image_type or "SINGLE B" in image_type: keys["acq"] = "B" + biplane = True + derived = 
"DERIVED" in image_type + series_description = _get("SeriesDescription", " ").lower() # "SeriesDescription": "Durchleuchtung - gespeichert", monitor = _get("PositionerMotion", " ").lower() # ftv = _get("FrameTimeVector", None).lower() monitor = _get("PositionerMotion", " ").lower() tag = _get("DerivationDescription", " ").lower() + # "ImagerPixelSpacing" + # FrameTimeVector = _get("DerivationDescription", []) # ftv is not None - if tag == "subtraction": + if "durchleuchtung" in series_description or "fluroscopy" in series_description: + mri_format = "fluroscopy" + elif tag == "subtraction": mri_format = "DSA" if monitor == "static" and "VOLUME" not in image_type and "RECON" not in image_type else "subtraction" elif "3DRA_PROP" in image_type: mri_format = "3DRA" elif monitor == "dynamic" or "VOLUME" in image_type or "RECON" in image_type or "3DRA_PROP" in image_type: mri_format = "DSA3D" + elif biplane and derived and "VOLUME" not in image_type and "RECON" not in image_type: + ##len(FrameTimeVector) >= 1 and (monitor == "static" and "VOLUME" not in image_type and "RECON" not in image_type) + mri_format = "DSA" else: mri_format = "XA" elif modality == "mr": diff --git a/TPTBox/core/internal/elastic_deform.py b/TPTBox/core/internal/elastic_deform.py new file mode 100644 index 0000000..82c8274 --- /dev/null +++ b/TPTBox/core/internal/elastic_deform.py @@ -0,0 +1,139 @@ +import time + +import elasticdeform +import numpy as np +from numpy.typing import NDArray + +from TPTBox import NII + + +def deformed_nii( + nii_dic: dict[str, NII], + sigma: float | None = None, + points=None, + deform_factor=1.0, + deform_padding=10, + normalize=True, + joint_normalize=False, +) -> dict[str, NII]: + """ + Deform a dictionary of NII objects using random grid deformation. Requires elasticdeform. 'pip install elasticdeform' + + IMPORTANT: Normalize your image data to 0,1. The .seg property of NII shows if this is a segmentation. 
(NII is form our TPTBox and is a wrapper for nibable) + + This function takes a dictionary of NII objects and applies random grid deformation to each object + using specified deformation parameters or, if not provided, random parameters generated based on + the `deform_factor`. The deformed objects are returned as a dictionary. + + Args: + arr_dic (dict[str, NII]): A dictionary containing NII objects to be deformed. + sigma (float, optional): The standard deviation of the deformation field. If not provided, + it will be generated based on the `deform_factor`. + points (int, optional): The number of control points for the deformation grid. If not provided, + it will be generated based on the `deform_factor`. + deform_factor (float, optional): A factor used to determine the deformation parameters if + `sigma` and `points` are not specified. Larger values result in stronger deformations. + deform_padding (int, optional): The padding added to the deformed objects to avoid edge artifacts. + verbose (bool, optional): If True, enable verbose logging. Default is True. + + Returns: + dict[str, NII]: A dictionary where keys correspond to the input dictionary keys, and values + correspond to the deformed NII objects. 
+ + Example: + # Deform a dictionary of NII objects using default deformation parameters + deformed_data = deformed_NII(arr_dic) + + # Deform a dictionary of NII objects with specific deformation parameters + sigma = 1.0 + points = 20 + deformed_data = deformed_NII(arr_dic, sigma=sigma, points=points) + """ + if sigma is None or points is None: + sigma, points = get_random_deform_parameter(deform_factor=deform_factor) + + print("deformation parameter sigma = ", round(sigma, 4), "; n_points = ", points) + t = time.time() + values = list(nii_dic.values()) + # Deform + if joint_normalize: + max_v = max([img.max() for img in nii_dic.values() if not img.seg]) + nii_dic = {k: img if img.seg else img.set_dtype(np.float32) / max_v for k, img in nii_dic.items()} + elif normalize: + nii_dic = {k: img if img.seg else img.set_dtype(np.float32).normalize() for k, img in nii_dic.items()} + else: + nii_dic = {k: img if img.seg else img.set_dtype(np.float32) for k, img in nii_dic.items()} + assert sigma is not None + p = deform_padding + out: list[NDArray] = elasticdeform.deform_random_grid( + [pad(v.get_array(), p=p) for v in values], + sigma=sigma, # type: ignore + points=points, + order=[0 if v.seg else 3 for v in values], # type: ignore + ) + out2: dict[str, NII] = {} + for (k, nii), arr in zip(nii_dic.items(), out, strict=True): + out2[k] = nii.set_array(arr[p:-p, p:-p, p:-p]) + print("Deformation took", round(time.time() - t, 1), "Seconds") + return out2 + + +def pad(arr, p=10): + return np.pad(arr, p, mode="reflect") + + +def get_random_deform_parameter(deform_factor: float = 1): + """ + Generate random deformation parameters for use in 3D deformation. + + This function generates random values for the deformation parameters, including 'sigma' and 'points', + based on the specified deformation factor. These parameters are used for 3D deformation operations. + + Args: + deform_factor (float, optional): A factor to control the strength of deformation. Default is 1. 
+ + Returns: + tuple[float, int]: A tuple containing the generated 'sigma' (float) and 'points' (int) parameters. + + Example: + # Generate random deformation parameters with a deformation factor of 1 + sigma, points = get_random_deform_parameter() + + # Generate random deformation parameters with a deformation factor of 2 + sigma, points = get_random_deform_parameter(deform_factor=2) + """ + sigma = 2 + np.random.uniform() * 2.5 # 1,5 - 4.5 + min_points = 3 + max_points = 17 + if sigma < 2: + max_points = 17 + elif sigma < 1.7: + max_points = 16 + elif sigma < 2.1: + max_points = 15 + elif sigma < 2.3: + max_points = 14 + elif sigma < 2.5: + max_points = 13 + elif sigma < 2.6: + max_points = 12 + elif sigma < 2.7: + max_points = 11 + elif sigma < 2.8: + max_points = 10 + elif sigma < 3: + max_points = 9 + elif sigma < 3.5: + max_points = 8 + elif sigma < 4.0: + max_points = 7 + elif sigma < 4.3: + max_points = 6 + else: + max_points = 5 + points = np.random.randint(max_points - min_points + 1) + min_points + # Stronger + sigma *= deform_factor + # points *= deform_factor + points = max(round(points), 1) + return (sigma, points) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 5b014d2..9024057 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -4,6 +4,7 @@ import traceback import warnings import zlib +from collections import deque from collections.abc import Sequence from enum import Enum from math import ceil, floor @@ -14,6 +15,7 @@ import nibabel.orientations as nio import numpy as np from nibabel import Nifti1Header, Nifti1Image # type: ignore +from skimage.measure import marching_cubes from typing_extensions import Self from TPTBox.core import bids_files @@ -32,7 +34,9 @@ np_connected_components, np_connected_components_per_label, np_dilate_msk, + np_dilate_msk_euclid, np_erode_msk, + np_erode_msk_euclid, np_extract_label, np_fill_holes, np_fill_holes_global_with_majority_voting, @@ -66,6 +70,7 @@ from 
TPTBox.logger.log_file import Log_Type if TYPE_CHECKING: + from stl.mesh import Mesh from torch import device MODES = Literal["constant", "nearest", "reflect", "wrap"] _unpacked_nii = tuple[np.ndarray, AFFINE, nib.nifti1.Nifti1Header] @@ -787,30 +792,52 @@ def pad_to(self,target_shape:list[int]|tuple[int,int,int] | Self, mode:MODES="co s = s.apply_crop(tuple(crop),inplace=inplace) return s.apply_pad(padding,inplace=inplace,mode=mode) - def apply_pad(self,padd:Sequence[tuple[int|None,int]]|None,mode:MODES="constant",inplace = False,verbose:logging=True): + + def apply_pad( + self, + padd: Sequence[tuple[int | None, int | None]] | int | None, + mode: MODES = "constant", + inplace=False, + verbose: logging = True + ): #TODO add other modes #TODO add testcases and options for modes - if padd is None: + if padd is None or padd == 0: return self if inplace else self.copy() - transform = np.eye(self.dims+1, dtype=int) + + if isinstance(padd, (int, float)): + padd = int(padd) + padd = ((padd, padd),) * self.dims + assert len(padd) == self.dims - for i, (before,_) in enumerate(padd): - #transform[i, i] = pad_slice.step if pad_slice.step is not None else 1 - transform[i, 3] = -before if before is not None else 0 - while len(padd) < len(self.shape): - padd = (*tuple(padd), (0, 0)) - affine = self.affine.dot(transform) + # Replace None with 0 + padd = tuple((b or 0, a or 0) for b, a in padd) + + # Extend for non-spatial dims + padd = padd + ((0, 0),) * (len(self.shape) - len(padd)) + + # Build affine transform + transform = np.eye(self.dims + 1, dtype=float) + for i, (before, _) in enumerate(padd[:self.dims]): + transform[i, -1] = -before + + affine = self.affine @ transform + args = {} if mode == "constant": - args["constant_values"]=self.get_c_val() - log.print(f"Padd {padd}; {mode=}, {args}",verbose=verbose) - arr = np.pad(self.get_array(),padd,mode=mode,**args) # type: ignore + args["constant_values"] = self.get_c_val() + + log.print(f"Padd {padd}; {mode=}, {args}", 
verbose=verbose) + + arr = np.pad(self.get_array(), padd, mode=mode, **args) + + nii = (arr, affine, self.header) - nii:_unpacked_nii = (arr,affine,self.header) if inplace: self.nii = nii return self + return self.copy(nii) def rescale_and_reorient(self, axcodes_to=None, voxel_spacing=(-1, -1, -1), verbose:logging=True, inplace=False,c_val:float|None=None,mode:MODES='nearest'): @@ -901,6 +928,8 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN Returns: NII: """ '''''' + if to_vox_map is None: + return self if inplace else self.copy() c_val = self.get_c_val(c_val) if isinstance(to_vox_map,Has_Grid): mapping = to_vox_map.to_gird() @@ -909,11 +938,41 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN if isinstance(mapping,Has_Grid) and mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0): log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose) return self if inplace else self.copy() + + #m1 = mapping.make_empty_POI().reorient(self.orientation) + #if m1.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0): + # log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose) + # return self.reorient(mapping.orientation,inplace=inplace) + #if self.orientation == mapping.orientation and self.zoom == mapping.zoom: + # shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom) + # if np.allclose(shift, np.round(shift), atol=1e-6): + # self = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642 + # shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom) + # shift = np.round(shift).astype(int) + # src_shape = np.array(mapping.shape) + # dst_shape = np.array(self.shape) + # # padding before = how much dst starts before src + # pad_before = np.maximum(-shift, 0) + # + # # where src ends 
inside dst + # src_end_in_dst = shift + src_shape + # # padding after = remaining dst size after src + # pad_after = np.maximum(dst_shape - src_end_in_dst, 0) + # pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after)) + # ret = self.apply_pad(pad, mode=mode) + # + # log.print(f"resample_from_to only needs padding/cropping {pad}, ",verbose=verbose,) + # ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0) + # return ret + + assert mapping is not None log.print(f"resample_from_to: {self} to {mapping}",verbose=verbose) if order is None: order = 0 if self.seg else 3 nii = _resample_from_to(self, mapping,order=order, mode=mode,align_corners=align_corners) + + if inplace: self.nii = nii return self @@ -1158,6 +1217,20 @@ def to_ants(self): def to_simpleITK(self): from TPTBox.core.sitk_utils import nii_to_sitk return nii_to_sitk(self) + @classmethod + def from_deepali(cls, img,seg=False): + try: + from deepali.data import Image as deepaliImage + except Exception: + log.print_error() + log.on_fail("run 'pip install hf-deepali' to install deepali") + raise + img_ : deepaliImage =img + grid = cls.from_deepali_grid(img_.grid()) + + arr = img_.data.squeeze().cpu().detach().numpy() + arr = np.transpose(arr, axes=tuple(reversed(range(arr.ndim)))) + return NII((nib.Nifti1Image(arr,grid.affine)),seg=seg) def to_deepali(self,align_corners: bool = True,dtype=None,device:device|str = "cpu"): import torch @@ -1236,7 +1309,82 @@ def erode_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectivi def erode_msk_(self, n_pixel:int = 5, labels: LABEL_REFERENCE = None, connectivity: int=3, verbose:logging=True,border_value=0,use_crop=True,ignore_direction:DIRECTIONS|int|None=None): return self.erode_msk(n_pixel=n_pixel, labels=labels, connectivity=connectivity, inplace=True, verbose=verbose,border_value=border_value,use_crop=use_crop,ignore_direction=ignore_direction) + def erode_msk_euclid( + self, + 
n_pixel: int = 5, + labels: LABEL_REFERENCE = None, + mask: Self | None = None, + inplace=False, + verbose: logging = True, + use_crop=True + ): + """ + Euclidean erodes (in voxel space) a segmentation mask by the specified number. + + Args: + n_pixel (int, optional): The number of voxels to erode the mask by. Defaults to 5. + labels (list[int], optional): Labels that should be eroded. If None, will erode all labels (not including zero!). + mask (NII, optional): If set, after operation, will zero out everything based on this mask. + inplace (bool, optional): Whether to modify the mask in place or return a new object. Defaults to False. + verbose (bool, optional): Whether to print a message. Defaults to True. + use_crop: Speed up computation by cropping/un-cropping. + + Returns: + NII: The eroded mask. + + Notes: + Uses Euclidean distance transform inside the foreground. + Runtime is independent of n_pixel and len(labels). + For n_pixel=1 this is similar to connectivity=1 erosion. + """ + assert self.seg + log.print("erode mask", end="\r", verbose=verbose) + + msk_i_data = self.get_seg_array() + mask_ = mask.get_seg_array() if mask is not None else None + + out = np_erode_msk_euclid( + arr=msk_i_data, + n_pixel=n_pixel, + labels=labels, + use_crop=use_crop, + mask=mask_ + ) + + out = out.astype(self.dtype) + + log.print("Mask euclidean eroded by", n_pixel, "voxels", verbose=verbose) + + return self.set_array(out, inplace=inplace) + def dilate_msk_euclid(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None,mask: Self | None = None, inplace=False, verbose:logging=True,use_crop=True): + """ + euclidean Dilates (in voxel space) a segmentation mask by the specified number. + + Args: + n_pixel (int, optional): The number of voxels to dilate the mask by. Defaults to 5. + labels (list[int], optional): Labels that should be dilated. If None, will dilate all labels (not including zero!) 
+ mask (NII, optional): If set, after each iteration, will zero out everything based on this mask + inplace (bool, optional): Whether to modify the mask in place or return a new object. Defaults to False. + verbose (bool, optional): Whether to print a message indicating that the mask was dilated. Defaults to True. + use_crop: speed up computation by cropping and un-cropping the segmentation. Minor overhead if the segmentation fills most of the image + Returns: + NII: The dilated mask. + + Notes: + The method uses euclidean dilation to dilate the mask by the specified number of voxels. + For n_pixel=1 dilate_msk_euclid and dilate_msk/connectivity=1 are equivalent. + This will algorithm runtime is independent of n_pixel and len(labels) unlike dilate_msk + + """ + assert self.seg + log.print("dilate mask",end='\r',verbose=verbose) + msk_i_data = self.get_seg_array() + mask_ = mask.get_seg_array() if mask is not None else None + out = np_dilate_msk_euclid(arr=msk_i_data, n_pixel=n_pixel,labels=labels,use_crop=use_crop,mask=mask_) + out = out.astype(self.dtype) + log.print("Mask euclidean dilated by", n_pixel, "voxels",verbose=verbose) + return self.set_array(out,inplace=inplace) def dilate_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectivity: int = 3, mask: Self | None = None, inplace=False, verbose:logging=True,use_crop=True, ignore_direction:DIRECTIONS|int|None=None): """ Dilates the binary segmentation mask by the specified number of voxels. @@ -1610,7 +1758,7 @@ def truncate_labels_beyond_reference( ): return self.truncate_labels_beyond_reference_(idx,not_beyond,fill,axis,inclusion,inplace=inplace) - def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None,max_depth=None): + def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None,max_depth=None, _do_crop=True): """ Expands labels from self_mask into regions of reference_mask == 1 via breadth-first diffusion. 
@@ -1623,8 +1771,14 @@ def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|s ndarray: Updated label mask. """ self.assert_affine(reference_mask) - self_mask = self.compute_surface_mask().get_seg_array().copy() - self_mask_org = self.get_seg_array().copy() + if _do_crop: + crop = reference_mask.compute_crop(0,5) + s = self.apply_crop(crop) + reference_mask = reference_mask.apply_crop(crop) + else: + s = self + self_mask = s.compute_surface_mask().get_seg_array().copy() + self_mask_org = s.get_seg_array().copy() ref_mask = np.clip(reference_mask.get_seg_array(), 0, 1) ref_mask[self_mask_org != 0] = 0 searched = np.clip(self_mask,0,1).astype(np.uint8) @@ -1644,13 +1798,14 @@ def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|s else: raise NotImplementedError(axis) - search = [] + search = deque() coords = np.where(self_mask != 0) def _add_idx(x,y,z,v,d): for x1,y1,z1 in kernel: a = x+x1 b = y+y1 c = z+z1 + if a < 0 or b < 0 or c < 0: continue if a >= self_mask.shape[0] or b >= self_mask.shape[1] or c >= self_mask.shape[2]: @@ -1663,28 +1818,37 @@ def _add_idx(x,y,z,v,d): def _infect(a,b,c,v,d): if d-1 == max_depth: return - if searched[a,b,c] != 0: + if searched[x,y,z] != 0: return - if ref_mask[a,b,c] == 0: + if ref_mask[x,y,z] == 0: return #print(a,b,c) searched[a,b,c] = 1 self_mask[a,b,c] = v - _add_idx(x,y,z,v,d) + _add_idx(a,b,c,v,d) from tqdm import tqdm for x,y,z in tqdm(zip(coords[0],coords[1],coords[2]),total=len(coords[0]),disable=not verbose,desc="Collecting Surface"): _add_idx(x,y,z,self_mask[x,y,z],0) while len(search) != 0: search2 = search - search = [] - for x,y,z,v,d in tqdm(search2,disable=not verbose,desc="infect"): + search = deque() + for _ in tqdm(range(len(search2)),disable=not verbose,desc="infect"): + x,y,z,v,d = search2.popleft() _infect(x,y,z,v,d+1) self_mask[self_mask == 0] = self_mask_org[self_mask == 0] + if _do_crop: + if inplace: + self[crop] = self_mask + return self + else: + arr 
= self.get_array() + arr[crop] = self_mask + self_mask = arr return self.set_array(self_mask,inplace=inplace) - def infect_(self: NII, reference_mask: NII,verbose=True,axis:int|str|None=None): - return self.infect(reference_mask, inplace=True,verbose=verbose,axis=axis) + def infect_(self: NII, reference_mask: NII,verbose=True,axis:int|str|None=None,_do_crop=True): + return self.infect(reference_mask, inplace=True,verbose=verbose,axis=axis,_do_crop=_do_crop) def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False): """ @@ -1749,6 +1913,9 @@ def clone(self): def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None): if make_parents: Path(file).parent.mkdir(0o777,exist_ok=True,parents=True) + if str(file).endswith(".nrrd"): + return self.save_nrrd(file,verbose=verbose) + arr = self.get_array() if not self.seg else self.get_seg_array() if isinstance(arr,np.floating) and self.seg: self.set_dtype_("smallest_uint") @@ -1762,7 +1929,6 @@ def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None # 1 means Scanner coordinate system # 2 means align (to something) coordinate system out.header["qform_code"] = 2 if self.seg else 1 - nib.save(out, file) #type: ignore log.print(f"Save {file} as {out.get_data_dtype()}",verbose=verbose,ltype=Log_Type.SAVE) @@ -1785,11 +1951,178 @@ def save_nrrd(self:Self, file: str | Path|bids_files.BIDS_FILE,make_parents=True raise ImportError("The `pynrrd` package is required but not installed. Install it with `pip install pynrrd`." 
) from None if isinstance(file, bids_files.BIDS_FILE): file = file.file['nrrd'] - from TPTBox.core.internal.slicer_nrrd import save_slicer_nrrd save_slicer_nrrd(self,file,make_parents=make_parents,verbose=verbose,**args) + def to_stls( + self: NII, + out_path: Path | dict[int, Path] | None = None, + bb: tuple | None = None, + to_world: bool = True, + include_normals: bool = False, + number_path: bool | None = None, + ) -> dict[int, Mesh]: + """ + Convert all labels in a segmentation into STL meshes. + + This function iterates over all unique labels in the segmentation and + applies `to_stl_single` to each label independently. + + Args: + seg (NII): + Segmentation object containing one or more labels. + out_path (Path | dict[int, Path] | None, optional): + Output specification: + - Path → save all meshes into the same directory or file pattern + - dict[label, Path] → per-label output paths + - None → do not save meshes + bb (tuple | None, optional): + Optional bounding box (e.g., slices). If provided and `to_world=False`, + vertex coordinates are shifted by the bounding box start indices. + to_world (bool, optional): + If True, transform vertices from voxel coordinates into world + coordinates using `seg.affine`. Defaults to True. + include_normals (bool, optional): + If True, compute per-face normals for each mesh using + `mesh.Mesh.update_normals()`. Defaults to False. + number_path (bool | None, optional): + Controls filename numbering when saving: + - If None (default): + Automatically set to True if `out_path` is not a dict, + and False otherwise. + - If True: + Append the label to the output filename. + - If False: + Do not modify the filename. + + Returns: + dict[int, mesh.Mesh]: + Dictionary mapping each label to its corresponding STL mesh. + + Notes: + - Each label is processed independently via `to_stl_single`. + - Padding is applied internally to ensure closed surfaces. 
+ - STL format stores only triangle geometry and per-face normals; + it does not support per-vertex attributes such as scalar values. + - If `to_world=True`, all meshes are returned in physical space + (e.g., millimeters). + """ + ret = {} + + # Resolve default numbering behavior + if number_path is None: + number_path = not isinstance(out_path, dict) + + for i in self.unique(): + ret[i] = self.to_stl( + label=i, out_path=out_path, bb=bb, to_world=to_world, include_normals=include_normals, number_path=number_path + ) + + return ret + + + def to_stl( + self: NII, + label: int, + out_path: Path | dict[int, Path] | None = None, + bb: tuple | None = None, + to_world: bool = True, + include_normals: bool = False, + number_path=False, + ) -> Mesh: + """ + Convert a binary segmentation label into an STL surface mesh using marching cubes. + + The function extracts a single label from a segmentation, runs marching cubes + to generate a triangular surface mesh, and optionally transforms the vertices + into world (physical) coordinates using the NIfTI affine. + + Args: + seg (NII): + Segmentation object containing a 3D mask. + label (int, optional): + Label value to extract from the segmentation. Defaults to 1. + out_path (Path | dict[int, Path] | None, optional): + Output specification: + - Path → save mesh to this file + - dict[label, Path] → per-label output path + - None → do not save mesh + bb (tuple | None, optional): + Optional bounding box (e.g., slices). If provided and `to_world=False`, + vertex coordinates are shifted by the bounding box start indices. + to_world (bool, optional): + If True, transform vertices from voxel coordinates into world + coordinates using `seg.affine`. Defaults to True. + include_normals (bool, optional): + If True, compute and include per-face normals in the returned mesh + using `mesh.Mesh.update_normals()`. Note that STL supports only + one normal per face. Defaults to False. 
+ number_path (bool, optional): + If True, append the label to the output filename when saving. + Defaults to False. + + Returns: + mesh.Mesh: + The generated STL mesh. If `include_normals=True`, normals are stored + in `mesh.normals` (per face). + + Notes: + - Marching cubes is applied to a padded volume to ensure closed surfaces + at the segmentation boundaries. + - Vertex coordinates are initially in voxel space and shifted to account + for padding. + - If `to_world=True`, vertices are transformed to physical space (e.g. mm) + using the affine matrix of the input segmentation. + - STL format stores only triangle geometry and per-face normals; it does + not support per-vertex attributes such as scalar values from marching cubes. + """ + + from stl import mesh + + seg = self.extract_label(label) + # Prepare binary mask + seg_arr = np.pad(seg.clamp(0, 1).get_array(), 1) + # Marching cubes (voxel coordinates) + try: + verts, faces, normals, values = marching_cubes(seg_arr, gradient_direction="ascent", step_size=1) + except RuntimeError as e: + raise RuntimeError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None + # Remove padding offset (since we padded by 1 voxel) + verts -= 1 + # Apply bounding box offset (still voxel space) + if bb is not None and not to_world: + verts += np.array([b.start for b in bb]) + # Convert to world coordinates using affine + if to_world: + affine = self.affine # (4, 4) + verts_h = np.c_[verts, np.ones(len(verts))] # homogeneous coords + verts = (affine @ verts_h.T).T[:, :3] + + # Build STL mesh + cube: mesh.Mesh = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype)) + for i, f in enumerate(faces): + cube.vectors[i] = verts[f] + + # Save if requested + if out_path is not None: + out_path = out_path.get(label) if isinstance(out_path, dict) else out_path + + if out_path is not None: + out_path = Path(out_path) + if out_path.is_dir(): + out_path = out_path / f"mask_{label}.stl" + elif number_path: + 
out_path = out_path.with_name(f"{out_path.stem}_{label}.stl") + log.on_save(f"Saving STL to {out_path}") + out_path.parent.mkdir(exist_ok=True) + cube.save(str(out_path)) + + if include_normals: + cube.update_normals() + + return cube + def __str__(self) -> str: return f"{super().__str__()}, seg={self.seg}" # type: ignore def __repr__(self)-> str: diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 900b054..d3507dd 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -20,6 +20,7 @@ from scipy.ndimage import ( binary_erosion, center_of_mass, + distance_transform_edt, gaussian_filter, generate_binary_structure, ) @@ -314,6 +315,97 @@ def np_dice(seg: np.ndarray, gt: np.ndarray, binary_compare: bool = False, label return dice + +def np_erode_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, labels=None, mask=None): + """ + Fast approximate erosion: + - shrinks segmentation by k voxels + - removes voxels close to background + """ + if use_crop: + arr_bin = arr.copy() + if labels is not None: + arr_bin[np.isin(arr_bin, labels, invert=True)] = 0 + crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False) + arrc = arr[crop] + else: + arrc = arr + if labels is not None: + arrc = arrc.copy() + arrc[np.isin(arrc, labels, invert=True)] = 0 + + if mask is not None: + mask = mask.copy() + mask[mask != 0] = 1 + if use_crop: + mask = mask[crop] + + foreground = arrc > 0 + + # distance inside foreground to nearest background + dist = distance_transform_edt(foreground) + + # copy original + out = arrc.copy() + + # remove voxels within erosion distance + erode_mask = (dist <= n_pixel) & foreground + out[erode_mask] = 0 + + if mask is not None: + out[mask == 0] = 0 + + if use_crop: + arr[crop][arrc != 0] = out[arrc != 0] + return arr + + arr[arrc != 0] = out[arrc != 0] + return arr + + +def np_dilate_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, labels=None, mask=None): + """ + Fast approximate dilation: + - expands 
segmentation by k voxels + - assigns new voxels to nearest label + """ + if use_crop: + arr_bin = arr.copy() + if labels is not None: + arr_bin[np.isin(arr_bin, labels, invert=True)] = 0 + crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False) + arrc = arr[crop] + else: + arrc = arr + if labels is not None: + arrc = arrc.copy() + arrc[np.isin(arrc, labels, invert=True)] = 0 + if mask is not None: + mask[mask != 0] = 1 + if use_crop: + mask = mask[crop] + foreground = arrc > 0 + + # distance + nearest label indices + dist, indices = distance_transform_edt(~foreground, return_indices=True) + + # copy original + out = arrc.copy() + + # mask of voxels within dilation range + dist_mask = (dist <= n_pixel) & (~foreground) + + # assign nearest label + nearest_labels = arrc[tuple(indices)] + out[dist_mask] = nearest_labels[dist_mask] + if mask is not None: + out[mask == 0] = 0 + if use_crop: + arr[crop][out != 0] = out[out != 0] + return arr + arr[out != 0] = out[out != 0] + return arr + + def np_dilate_msk( arr: np.ndarray, label_ref: LABEL_REFERENCE = None, @@ -506,16 +598,8 @@ def np_calc_crop_around_centerpoint( cutout_coords_slices = tuple([slice(cutout_coords[i], cutout_coords[i + 1]) for i in range(0, n_dim * 2, 2)]) arr_cut = arr[cutout_coords_slices] - arr_cut = np.pad( - arr_cut, - tuple(padding), - ) - return ( - arr_cut, - cutout_coords_slices, - tuple(padding), - # tuple([slice(padding[i][0], padding[i][1]) for i in range(n_dim)]), - ) + arr_cut = np.pad(arr_cut, tuple(padding)) + return (arr_cut, cutout_coords_slices, tuple(padding)) def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray = 0, raise_error=True) -> tuple[slice, ...]: diff --git a/TPTBox/core/poi.py b/TPTBox/core/poi.py index 9233296..2146462 100755 --- a/TPTBox/core/poi.py +++ b/TPTBox/core/poi.py @@ -420,17 +420,9 @@ def reorient(self, axcodes_to: AX_CODES = ("P", "I", "R"), decimals=ROUNDING_LVL """ ctd_arr = 
np.transpose(np.asarray(list(self.centroids.values()))) v_list = list(self.centroids.keys()) - if ctd_arr.shape[0] == 0: - log.print( - "No pois present", - verbose=verbose if not isinstance(verbose, bool) else True, - ltype=Log_Type.WARNING, - ) - return self if inplace else self.copy() ornt_fr = nio.axcodes2ornt(self.orientation) # original poi orientation ornt_to = nio.axcodes2ornt(axcodes_to) - if (ornt_fr == ornt_to).all(): log.print("ctd is already rotated to image with ", axcodes_to, verbose=verbose) return self if inplace else self.copy() @@ -446,17 +438,21 @@ def reorient(self, axcodes_to: AX_CODES = ("P", "I", "R"), decimals=ROUNDING_LVL shape = _shape assert shape is not None, "Require shape information for flipping dimensions. Set self.shape or use reorient_to" shp = np.asarray(shape) - ctd_arr[perm] = ctd_arr.copy() - for ax in trans: - if ax[1] == -1: - size = shp[ax[0]] - ctd_arr[ax[0]] = np.around(size - ctd_arr[ax[0]], decimals) - 1 - points = POI_Descriptor() - ctd_arr = np.transpose(ctd_arr).tolist() - for v, point in zip_strict(v_list, ctd_arr): - points[v] = tuple(point) - - log.print("[*] Centroids reoriented from", nio.ornt2axcodes(ornt_fr), "to", axcodes_to, verbose=verbose) + if ctd_arr.shape[0] == 0: + log.print("No pois present", verbose=verbose, ltype=Log_Type.WARNING) + points = self.centroids if inplace else self.centroids.copy() + else: + ctd_arr[perm] = ctd_arr.copy() + for ax in trans: + if ax[1] == -1: + size = shp[ax[0]] + ctd_arr[ax[0]] = np.around(size - ctd_arr[ax[0]], decimals) - 1 + points = POI_Descriptor() + ctd_arr = np.transpose(ctd_arr).tolist() + for v, point in zip_strict(v_list, ctd_arr): + points[v] = tuple(point) + + log.print("[*] Centroids reoriented from", nio.ornt2axcodes(ornt_fr), "to", axcodes_to, verbose=verbose) if self.zoom is not None: zoom_i = np.array(self.zoom) zoom_i[perm] = zoom_i.copy() @@ -643,16 +639,47 @@ def make_point_cloud_nii(self, affine=None, s=8, sphere=False): from math import ceil, floor 
if sphere: - from tqdm import tqdm + zoom = np.asarray(self.zoom) - for region, subregion, (x, y, z) in tqdm(self.items(), total=len(self)): - coords = np.ogrid[: self.shape[0], : self.shape[1], : self.shape[2]] - zoom = self.zoom - distance = np.sqrt( - ((coords[0] - int(x)) * zoom[0]) ** 2 + ((coords[1] - int(y)) * zoom[1]) ** 2 + ((coords[2] - int(z)) * zoom[2]) ** 2 - ) - arr += np.asarray(region * (distance <= s / 2), dtype=np.uint16) - arr2 += np.asarray(subregion * (distance <= s / 2), dtype=np.uint16) + # sphere radius in mm + radius = s / 2 + + # kernel size in voxels + rx = int(np.ceil(radius / zoom[0])) + ry = int(np.ceil(radius / zoom[1])) + rz = int(np.ceil(radius / zoom[2])) + + # create local sphere kernel ONCE + gx, gy, gz = np.ogrid[-rx : rx + 1, -ry : ry + 1, -rz : rz + 1] + sphere_mask = ((gx * zoom[0]) ** 2 + (gy * zoom[1]) ** 2 + (gz * zoom[2]) ** 2) <= radius**2 + + for region, subregion, (x, y, z) in self.items(): + x, y, z = round(x), round(y), round(z) # noqa: PLW2901 + + # image bounds + x0 = max(x - rx, 0) + x1 = min(x + rx + 1, self.shape[0]) + + y0 = max(y - ry, 0) + y1 = min(y + ry + 1, self.shape[1]) + + z0 = max(z - rz, 0) + z1 = min(z + rz + 1, self.shape[2]) + + # kernel bounds + kx0 = x0 - (x - rx) + kx1 = kx0 + (x1 - x0) + + ky0 = y0 - (y - ry) + ky1 = ky0 + (y1 - y0) + + kz0 = z0 - (z - rz) + kz1 = kz0 + (z1 - z0) + + local_mask = sphere_mask[kx0:kx1, ky0:ky1, kz0:kz1] + + arr[x0:x1, y0:y1, z0:z1][local_mask] = region + arr2[x0:x1, y0:y1, z0:z1][local_mask] = subregion else: for region, subregion, (x, y, z) in self.items(): arr[ @@ -966,7 +993,10 @@ def calc_poi_from_subreg_vert( if _vert_ids is None: _vert_ids = vert_msk.unique() - from TPTBox.core.poi_fun.vertebra_pois_non_centroids import add_prerequisites, compute_non_centroid_pois # noqa: PLC0415 + from TPTBox.core.poi_fun.vertebra_pois_non_centroids import ( # noqa: PLC0415 + add_prerequisites, + compute_non_centroid_pois, + ) subreg_id = 
add_prerequisites(_int2loc(subreg_id if isinstance(subreg_id, Sequence) else [subreg_id])) # type: ignore diff --git a/TPTBox/core/poi_fun/poi_abstract.py b/TPTBox/core/poi_fun/poi_abstract.py index 48e6ba2..2466793 100755 --- a/TPTBox/core/poi_fun/poi_abstract.py +++ b/TPTBox/core/poi_fun/poi_abstract.py @@ -524,7 +524,7 @@ def __contains__(self, key: POI_ID) -> bool: def __getitem__(self, key: POI_ID) -> COORDINATE: return tuple(self.centroids[key]) - def __setitem__(self, key: POI_ID, value: tuple[float, float, float] | Sequence[float]): + def __setitem__(self, key: POI_ID, value: tuple[float, float, float] | Sequence[float] | np.ndarray): if len(value) != DIMENSIONS: raise ValueError(value) self.centroids[key] = tuple(value) diff --git a/TPTBox/core/poi_fun/poi_global.py b/TPTBox/core/poi_fun/poi_global.py index 91fc1f6..2613e9b 100755 --- a/TPTBox/core/poi_fun/poi_global.py +++ b/TPTBox/core/poi_fun/poi_global.py @@ -115,8 +115,8 @@ def to_other_poi(self, ref: poi.POI | Self): elif isinstance(ref, Self): return self.to_cord_system(ref.itk_coords) - def to_global(self): - return self + def to_global(self, itk_coords: bool | None = None): + return self.to_cord_system(itk_coords) if itk_coords is not None else self.copy() def to_local(self, msk: Has_Grid): return self.resample_from_to(msk) @@ -170,11 +170,8 @@ def copy(self, centroids: POI_Descriptor | None = None) -> Self: def load(cls, poi: poi.POI_Reference, itk_coords: bool | None = None) -> Self: poi_obj = load_poi(poi) - if not poi_obj.is_global: + if not poi_obj.is_global or itk_coords is not None: poi_obj = poi_obj.to_global(itk_coords if itk_coords is not None else False) # type: ignore - if itk_coords is not None: - assert itk_coords == poi_obj.itk_coords, "not implemented swichting to/from itk_coords to nii " - return poi_obj # type: ignore def save( diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index 3015f18..54e87d3 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ 
b/TPTBox/core/poi_fun/save_load.py @@ -130,6 +130,8 @@ def convert(o): return float(o) if isinstance(o, np.ndarray): return o.tolist() + if isinstance(o, Path): + return str(o.absolute()) raise TypeError(type(o)) with open(out_path, "w") as f: @@ -253,6 +255,8 @@ def _open_file(ctd_path: Union[Path, str, bids_files.BIDS_FILE]) -> dict | list: pass # not JSON → continue except OSError as e: raise OSError(f"Could not open file: {path}") from e + except UnicodeDecodeError as e: + raise OSError(f"Could not open file: {path}") from e # --- 2) try landmark TXT --- try: diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index 0f186d6..688fba5 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -433,6 +433,21 @@ def bone(cls): Full_Body_Instance.phalanges_left, ] + @classmethod + def feet(cls): + return [ + Full_Body_Instance.talus_right, + Full_Body_Instance.talus_left, + Full_Body_Instance.calcaneus_right, + Full_Body_Instance.calcaneus_left, + Full_Body_Instance.tarsals_right, + Full_Body_Instance.tarsals_left, + Full_Body_Instance.metatarsals_right, + Full_Body_Instance.metatarsals_left, + Full_Body_Instance.phalanges_right, + Full_Body_Instance.phalanges_left, + ] + @classmethod def lung_system(cls): return [ @@ -647,8 +662,8 @@ def get_to_VIBESeg(cls): Full_Body_Instance.phalanges_right.value: 72, Full_Body_Instance.phalanges_left.value: 72, Full_Body_Instance.trachea.value: 16, - Full_Body_Instance.lung_right.value: 910, - Full_Body_Instance.lung_left.value: 910, + Full_Body_Instance.lung_right.value: 73, + Full_Body_Instance.lung_left.value: 73, Full_Body_Instance.heart.value: 24, Full_Body_Instance.spleen.value: 1, Full_Body_Instance.kidney_right.value: 2, @@ -1017,7 +1032,7 @@ def save_as_name(cls) -> bool: Inferior_Articular_Right = 48 Vertebra_Corpus_border = 49 # actual corpus body Vertebra_Corpus = 50 - Dens_axis = 51 # TODO Unused. 
Should be in C2 + Dens_axis = 51 # only in C2 and CT but not MRI Vertebral_Body_Endplate_Superior = 52 Vertebral_Body_Endplate_Inferior = 53 # Articulate_Process_Facet_Joint (Used anywhere?) @@ -1196,3 +1211,12 @@ def vert_subreg_labels(with_border: bool = True) -> list[Location]: ] conversion_poi2text = {k: v for v, k in conversion_poi.items()} + + +list_of_all_enums = [ + Location, + Vertebra_Instance, + Lower_Body, + Full_Body_Instance, + Full_Body_Instance_Vibe, +] diff --git a/TPTBox/logger/log_file.py b/TPTBox/logger/log_file.py index ad8592d..6398e89 100755 --- a/TPTBox/logger/log_file.py +++ b/TPTBox/logger/log_file.py @@ -211,8 +211,15 @@ def on_warning(self, *text, end="\n", verbose: bool | None = None, **qargs): def on_text(self, *text, end="\n", verbose: bool | None = None, **qargs): self.print(*text, end=end, ltype=Log_Type.TEXT, verbose=verbose, **qargs) + # same logging as the python loger for drop in replacement + def warning(self, *text, end="\n", verbose: bool | None = None, **qargs): + return self.on_warning(*text, end=end, verbose=verbose, **qargs) + + def error(self, *text, end="\n", verbose: bool | None = None, **qargs): + return self.on_fail(*text, end=end, verbose=verbose, **qargs) + def info(self, *text, end="\n", verbose: bool | None = None, **qargs): - self.print(*text, end=end, ltype=Log_Type.TEXT, verbose=verbose, **qargs) + return self.on_text(*text, end=end, verbose=verbose, **qargs) class Logger(Logger_Interface): diff --git a/TPTBox/registration/_deepali/deepali_model.py b/TPTBox/registration/_deepali/deepali_model.py index 7a2d113..b535aa3 100644 --- a/TPTBox/registration/_deepali/deepali_model.py +++ b/TPTBox/registration/_deepali/deepali_model.py @@ -234,7 +234,7 @@ def __init__( source_landmarks=source_landmarks, target_landmarks=target_landmarks, device=device, - target_mask=to_nii(fixed_mask, True).to_deepali() if fixed_mask is not None else None, + target_mask=to_nii(fixed_mask, True).resample_from_to(fix, 
verbose=False).to_deepali() if fixed_mask is not None else None, source_mask=to_nii(moving_mask, True).to_deepali() if moving_mask is not None else None, normalize_strategy=normalize_strategy, pyramid_levels=pyramid_levels, diff --git a/TPTBox/registration/_deepali/deepali_trainer.py b/TPTBox/registration/_deepali/deepali_trainer.py index 4d31028..8a44f05 100644 --- a/TPTBox/registration/_deepali/deepali_trainer.py +++ b/TPTBox/registration/_deepali/deepali_trainer.py @@ -466,7 +466,8 @@ def _weighted_sum(self, losses: dict[str, Tensor], level) -> Tensor: w = weights.get(name, 1.0) if isinstance(w, (list, tuple)): w = w[level] - + if torch.isnan(value): + value = torch.zeros_like(value) # noqa: PLW2901 if not isinstance(w, str): value = w * value # noqa: PLW2901 loss += value.sum() @@ -488,6 +489,7 @@ def on_loss( # noqa: C901 # Transform target grid points x = target.grid().coords(device=target_data.device).unsqueeze(0) # grid_points y: Tensor = grid_transform(x) + del x assert len(self.loss_terms) != 0, "No losses defined" ### Sum of pairwise image dissimilarity terms ### if self.loss_pairwise_image_terms: @@ -496,13 +498,15 @@ def on_loss( # noqa: C901 # TODO this is from the reference implantation but is need way to much GPU... 
moved_mask = self._sample_image(y, self.source_mask) mask = overlap_mask(moved_mask, self.target_mask) + elif self.target_mask is not None: + mask = self.target_mask.float() else: mask = None for name, term in self.loss_pairwise_image_terms.items(): losses[name] = term(moved_data, target_data, mask=mask) - result["source"] = moved_data - result["target"] = target_data - result["mask"] = mask + del moved_data + del target_data + del mask if self.loss_pairwise_image_terms2: assert source_image_seg is not None, "Source image segmentation is required" moved_data: torch.Tensor = self._sample_image(y, source_image_seg.tensor()) @@ -510,14 +514,17 @@ def on_loss( # noqa: C901 if self.source_mask is not None and self.target_mask is not None: # TODO this is from the reference implantation but is need way to much GPU... moved_mask = self._sample_image(y, self.source_mask) - mask = overlap_mask(moved_mask, self.target_mask) + mask = overlap_mask(moved_mask, self.target_mask).unsqueeze(0) + elif self.target_mask is not None: + mask = self.target_mask.unsqueeze(0).float() # Masked MSE needs float32, or it becomes 0 or NaN else: mask = None for name, term in self.loss_pairwise_image_terms2.items(): losses[name] = term(moved_data.unsqueeze(0), target_data_seg.unsqueeze(0), mask=mask) # DICE - result["source"] = moved_data - result["target"] = target_data - result["mask"] = mask + del moved_data + del target_data_seg + del mask + del y ## Sum of pairwise point set distance terms if self.loss_dist_terms: if self.source_pset is None: @@ -546,7 +553,7 @@ def on_loss( # noqa: C901 if buf.requires_grad: var = name.rsplit(".", 1)[-1] variables[var].append(buf) - variables["w"] = [U.move_dim(y - x, -1, 1)] + variables["w"] = [U.move_dim(y - x, -1, 1)] # noqa: F821 for name, term in self.disp_terms.items(): match = RE_TERM_VAR.match(name) if match: @@ -564,6 +571,7 @@ def on_loss( # noqa: C901 for buf in bufs: value += term(buf) losses[name] = value + del variables ### Sum of free-form 
deformation loss terms ### for name, term in self.loss_bspline_terms.items(): value = torch.tensor(0, dtype=torch.float, device=self.device) @@ -614,10 +622,11 @@ def on_step( with torch.no_grad(): for hook in self._eval_hooks.values(): hook(self, num_steps, 1, result) + with torch.no_grad(): for hook in self._step_hooks.values(): hook(self, num_steps, 1, loss.item()) - return loss + return loss.detach().cpu() def _run_level( self, @@ -926,6 +935,6 @@ def run(self): if self.verbose > 3: print(f"Registered images in {timer() - start_reg:.3f}s") if self.verbose > 0: - print() + print(" !") self.target_mask = target_mask return grid_transform diff --git a/TPTBox/registration/_deformable/deformable_reg.py b/TPTBox/registration/_deformable/deformable_reg.py index 377ef06..f21dbdd 100644 --- a/TPTBox/registration/_deformable/deformable_reg.py +++ b/TPTBox/registration/_deformable/deformable_reg.py @@ -79,6 +79,14 @@ def __init__( ): if transform_args is None: transform_args = {"stride": [stride, stride, stride], "transpose": False} + if "transpose" in transform_args and transform_name in [ + "StationaryVelocityFieldTransform", + "SVF", + "SVField", + "DenseVectorFieldTransform", + ]: + transform_args.pop("transpose") + if loss_terms is None: loss_terms = { "be": BSplineBending(stride=1), diff --git a/TPTBox/registration/_deformable/multilabel_segmentation.py b/TPTBox/registration/_deformable/multilabel_segmentation.py index a92d199..878b0d0 100644 --- a/TPTBox/registration/_deformable/multilabel_segmentation.py +++ b/TPTBox/registration/_deformable/multilabel_segmentation.py @@ -157,7 +157,7 @@ def __init__( # noqa: C901 # --- try crop first --- t_crop = target_seg.compute_crop(0, crop_pad_size) - cropped = target_seg.apply_crop(t_crop) + cropped = target_seg.apply_crop(t_crop).apply_pad(crop_pad_size - 50 // 4) if any(c < o for c, o in zip(cropped.shape, target_seg.shape)): resize_mode = "crop" @@ -165,7 +165,7 @@ def __init__( # noqa: C901 target_tmp = cropped else: # 
--- fallback to padding --- - crop_pad_size = 0 + crop_pad_size = crop_pad_size // 2 target_tmp = target_seg resize_mode = "pad" else: @@ -308,9 +308,10 @@ def load_(cls, w): self.target_grid, self.crop, ) = x + return self - def transform_nii(self, nii_atlas: NII, allow_only_same_grid_as_moving=True): + def transform_nii(self, nii_atlas: NII, allow_only_same_grid_as_moving=True, only_rigid=False): """ Apply both rigid and deformable registration to a new NII object. @@ -322,6 +323,9 @@ def transform_nii(self, nii_atlas: NII, allow_only_same_grid_as_moving=True): """ nii_atlas = self.reg_point.transform_nii(nii_atlas, allow_only_same_grid_as_moving=allow_only_same_grid_as_moving) + if only_rigid: + return nii_atlas + nii_atlas = nii_atlas.apply_crop(self.crop) nii_reg = self.reg_deform.transform_nii(nii_atlas) if nii_reg.seg: @@ -377,3 +381,52 @@ def transform_poi(self, poi_atlas: POI_Global | POI): else: raise ValueError(axis) return poi_reg_flip + + def transform_poi_inverse(self, poi_target: POI_Global | POI): + """ + Transform POIs from target space back into atlas space. + + Args: + poi_target (POI_Global | POI): POIs defined in target space. + + Returns: + POI: POIs mapped back into atlas space. 
+ """ + + poi = poi_target.copy() + + # --- undo left/right flip if needed --- + if not self.same_side: + poi_flip = poi.make_empty_POI() + axis = poi.get_axis("R") + + for k1, k2, (x, y, z) in poi.copy().items(): + if axis == 0: + poi_flip[k1, k2] = (poi.shape[0] - 1 - x, y, z) + elif axis == 1: + poi_flip[k1, k2] = (x, poi.shape[1] - 1 - y, z) + elif axis == 2: + poi_flip[k1, k2] = (x, y, poi.shape[2] - 1 - z) + else: + raise ValueError(axis) + + poi = poi_flip + + # --- resample into deformable registration grid --- + poi = poi.resample_from_to(self.target_grid) + + # --- inverse deformable registration --- + reg_deform_inv = self.reg_deform.inverse() + poi = reg_deform_inv.transform_poi(poi) + + # --- undo crop --- + # if self.crop is not None: + # poi = poi.apply_crop_inverse(self.crop) + + # --- inverse rigid point registration --- + poi = self.reg_point.transform_poi_inverse(poi, allow_only_same_grid_as_moving=False) + + # --- back to atlas grid --- + poi = poi.resample_from_to(self.atlas_org) + + return poi diff --git a/TPTBox/registration/_ridged_points/point_registration.py b/TPTBox/registration/_ridged_points/point_registration.py index c5b65c2..102234e 100644 --- a/TPTBox/registration/_ridged_points/point_registration.py +++ b/TPTBox/registration/_ridged_points/point_registration.py @@ -167,6 +167,23 @@ def transform_poi(self, poi_moving: POI, allow_only_same_grid_as_moving=True, ou poi = poi.resample_from_to(output_space) return poi + def transform_poi_inverse(self, poi_moving: POI, allow_only_same_grid_as_moving=True, output_space=None): + # output_space: POI | NII | None = None, + if allow_only_same_grid_as_moving: + text = "input image must be in the same space as moving. 
If you sure that this input is in same space as the moving image you can turn of 'only_allow_grid_as_moving'" + poi_moving.assert_affine(self.out_poi, text=text) + move_l = [] + keys = [] + out = dict(zip_strict(keys, move_l)) + + for key, key2, (x, y, z) in poi_moving.items(): + out[key, key2] = self.transform_cord_inverse((x, y, z)) + + poi = self.out_poi.make_empty_POI(out) + if output_space is not None: + poi = poi.resample_from_to(output_space) + return poi + def transform_cord(self, cord: tuple[float, ...], out: sitk.Image | None = None): if out is None: out = self._img_fixed diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 9040c5a..bfa8736 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from pathlib import Path from typing import Literal -from warnings import warn import numpy as np import torch @@ -18,7 +17,7 @@ _model_path_ = out_base / "nnUNet_results" -def get_ds_info(idx, _model_path: str | Path | None = None, exit_one_fail=True) -> dict: +def get_ds_info(idx, _model_path: str | Path | None = None, exit_one_fail=True, logger=logger) -> dict: if _model_path is not None: _model_path = Path(_model_path) model_path = _model_path / "nnUNet_results" @@ -32,7 +31,7 @@ def get_ds_info(idx, _model_path: str | Path | None = None, exit_one_fail=True) nnunet_path = next(next(iter(model_path.glob(f"*{idx}*"))).glob("*__nnUNet*ResEnc*")) except StopIteration: if exit_one_fail: - Print_Logger().print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL) + logger.print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL) model_path.mkdir(exist_ok=True, parents=True) sys.exit() else: @@ -72,7 +71,12 @@ def run_inference_on_file( memory_max=160000, # in MB, default is 160GB wait_till_gpu_percent_is_free=0.1, verbose=True, + auto_download=False, + 
_key_ResEnc="__nnUNet*ResEnc", + logger=logger, ) -> tuple[Image_Reference, np.ndarray | None]: + if model_path is None: + auto_download = True if model_path is not None: model_path = Path(model_path) if model_path.name != "nnUNet_results": @@ -89,11 +93,16 @@ def run_inference_on_file( ) if isinstance(idx, int): - download_weights(idx, model_path) + if auto_download: + download_weights(idx, model_path) try: - nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNet*ResEnc*")) - except StopIteration: - nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*")) + nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob(f"*{_key_ResEnc}*")) + except StopIteration as e: + try: + nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*")) + except StopIteration: + logger.on_fail(model_path, (f"*{idx:03}*")) + raise e from None else: nnunet_path = Path(idx) assert nnunet_path.exists(), nnunet_path @@ -104,7 +113,7 @@ def run_inference_on_file( # if idx in _unets: # nnunet = _unets[idx] # else: - print("load model", nnunet_path, "; folds", folds) if verbose else None + logger.print("load model", nnunet_path, "; folds", folds) if verbose else None with open(Path(nnunet_path, "plans.json")) as f: plans_info = json.load(f) with open(Path(nnunet_path, "dataset.json")) as f: @@ -117,6 +126,8 @@ def run_inference_on_file( ds_info["orientation"] = ds_info2["model_expected_orientation"] if "resolution_range" in ds_info2: ds_info["resolution_range"] = ds_info2["resolution_range"] + if "labels" in ds_info2: + ds_info["labels_mapping"] = ds_info2["labels"] nnunet = load_inf_model( nnunet_path, @@ -170,14 +181,14 @@ def run_inference_on_file( nnunet_path, ) if orientation is not None: - print("orientation", orientation, f"from {input_nii[0].orientation}") if verbose else None + logger.print("orientation", orientation, f"from {input_nii[0].orientation}") if verbose else None input_nii = 
[i.reorient(orientation) for i in input_nii] if zoom is not None: - print("rescale", f"{zoom=} from {input_nii[0].zoom}") if verbose else None + logger.print("rescale", f"{zoom=} from {input_nii[0].zoom}") if verbose else None input_nii = [i.rescale_(zoom, mode=mode, verbose=True) for i in input_nii] - print(input_nii) - print("squash to float16") if verbose else None + logger.print(input_nii) + logger.print("squash to float16") if verbose else None input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] if crop: @@ -197,6 +208,44 @@ def run_inference_on_file( seg_nii.resample_from_to_(og_nii, mode=mode) if fill_holes: seg_nii.fill_holes_() + if "labels_mapping" in ds_info: + from TPTBox.core.vert_constants import list_of_all_enums + + mapping_ = ds_info["labels_mapping"] + unknown_strings: dict[str, int] = {"max": seg_nii.max() + 1, "Intervertebral_Disc": 100} + mapping = {} + + def to_int(a: str, k: None | int = None): + if a in unknown_strings: + return unknown_strings[a] + try: + return int(a) + except Exception: + pass + + for enum_ in list_of_all_enums: + try: + return enum_[a].value + except Exception: + print("no ", enum_) + if k is not None and k not in unknown_strings.values(): + return k + unknown_strings[a] = unknown_strings["max"] + unknown_strings["max"] += 1 + if unknown_strings["max"] == 100: + unknown_strings["max"] += 1 + + return unknown_strings[a] + + for k, v in mapping_.items(): + key = to_int(k) + value = to_int(v, key) + if key != value: + mapping[key] = value + unknown_strings[v] = value + logger.print(f"{unknown_strings}") + logger.print(f"{mapping=}") + seg_nii.map_labels_(mapping) if out_file is not None and (not Path(out_file).exists() or override): seg_nii.save(out_file) del nnunet @@ -222,6 +271,7 @@ def run_VibeSeg( max_folds: int | None = None, _model_path=None, step_size=0.5, + logger: Print_Logger = logger, **_kargs, ): if isinstance(out_path, str): @@ -248,12 +298,12 @@ def run_VibeSeg( return else: weights_dir = 
download_weights(dataset_id) - print("to", weights_dir) + logger.print("to", weights_dir) selected_gpu = gpu if gpu is None: gpu = "auto" # type: ignore logger.print("run", f"{dataset_id=}, {gpu=}", Log_Type.STAGE) - ds_info = get_ds_info(dataset_id) + ds_info = get_ds_info(dataset_id, logger=logger) orientation = ds_info.get("orientation", ("R", "A", "S")) if not isinstance(img, Sequence) or isinstance(img, str): img = [img] @@ -264,7 +314,7 @@ def run_VibeSeg( in_niis = [to_nii(i) for i in img] # type: ignore in_niis = [i.resample_from_to_(in_niis[0]) if i.shape != in_niis[0].shape else i for i in in_niis] if (in_niis[0].affine == np.eye(4)).all(): - warn( + logger.on_warning( "Your affine matrix is the identity. Make sure that the spacing and orientation is correct. For NAKO VIBE it should be 1.40625 mm for R/L and A/P and 3 mm S/I. For UKBB R/L and A/P should be around 2.2 mm", stacklevel=3, ) @@ -281,5 +331,6 @@ def run_VibeSeg( crop=crop, max_folds=max_folds, step_size=step_size, + logger=logger, **_kargs, )[0] diff --git a/TPTBox/segmentation/VibeSeg/vibeseg.py b/TPTBox/segmentation/VibeSeg/vibeseg.py index 9010385..7e05924 100644 --- a/TPTBox/segmentation/VibeSeg/vibeseg.py +++ b/TPTBox/segmentation/VibeSeg/vibeseg.py @@ -89,7 +89,7 @@ def run_vibeseg( gpu=0, ddevice: Literal["cpu", "cuda", "mps"] = "cuda", dataset_id=100, - padd=0, + padd=5, keep_size=False, # Keep size of the model Segmentation **args, ): @@ -109,10 +109,22 @@ def run_vibeseg( def run_nnunet( i: list[Image_Reference], out_seg: str | Path, + *, override=False, gpu=0, ddevice: Literal["cpu", "cuda", "mps"] = "cuda", dataset_id=80, + model_path: str | Path | None = None, + auto_download=False, # set to True if model_path is None + keep_size=False, + fill_holes=False, + logits=False, + mapping=None, + crop=False, + max_folds=None, + mode="nearest", + padd: int = 0, + key_ResEnc="__nnUNet*ResEnc", **args, ): run_inference_on_file( @@ -122,6 +134,17 @@ def run_nnunet( override=override, 
gpu=gpu, ddevice=ddevice, + model_path=model_path, + auto_download=auto_download, + keep_size=keep_size, + fill_holes=fill_holes, + logits=logits, + mapping=mapping, + crop=crop, + max_folds=max_folds, + mode=mode, + padd=padd, + _key_ResEnc=key_ResEnc, **args, ) diff --git a/TPTBox/segmentation/_deface.py b/TPTBox/segmentation/_deface.py index f97d75d..79ea14a 100644 --- a/TPTBox/segmentation/_deface.py +++ b/TPTBox/segmentation/_deface.py @@ -93,13 +93,13 @@ def compute_deface_mask_cta( outpath = ct_img.get_changed_path("nii.gz", "msk", parent="derivatives-defacing", info={"seg": "defacting-2", "mod": ct_img.format}) if outpath is not None and not override and outpath.exists(): return outpath - tight_mask = compute_deface_mask_cta(ct_img, None, override=override, gpu=gpu, **args) + tight_mask = _compute_deface_mask_cta(ct_img, None, override=override, gpu=gpu, **args) ct = to_nii(ct_img, False) face = to_nii(tight_mask, True).resample_from_to(ct) if not partially_defaced: face = face.filter_connected_components(max_count_component=1, keep_label=True) face_org = face.copy() - f2 = face.dilate_msk(6).smooth_gaussian_labelwise(1, 5) + f2 = face.dilate_msk(3).smooth_gaussian_labelwise(1, 5) f2[ct > -600] = 0 f2 = f2.filter_connected_components(max_count_component=1, keep_label=True) @@ -108,7 +108,7 @@ def compute_deface_mask_cta( m2 = mask.extract_label(1) mask = mask.clamp(0, 1).fill_holes(1, "S") - mask *= mask.clamp(0, 1).erode_msk(4) + mask *= mask.clamp(0, 1).erode_msk(2) mask[m2 * mask] = 1 if outpath is not None: mask.save(outpath) diff --git a/TPTBox/spine/spinestats/body_quadrants.py b/TPTBox/spine/spinestats/body_quadrants.py index 60b9596..2ae8e4e 100644 --- a/TPTBox/spine/spinestats/body_quadrants.py +++ b/TPTBox/spine/spinestats/body_quadrants.py @@ -13,6 +13,7 @@ def make_quadrants( poi_buffer: str | Path | None = None, vert_ids: list[int] | None = None, mask_ids=(49, 50, 52), + erode=0, ): """ Subdivide vertebral body masks into anatomically oriented 
3×3×3 regions. @@ -103,8 +104,12 @@ def make_quadrants( buffer_file=poi_buffer, ) out_nii = vert_nii * 0 + mask_nii = spine_nii.extract_label(mask_ids) + if erode != 0: + mask_nii = mask_nii.erode_msk(erode) + for v_id in vert_nii.unique() if vert_ids is None else vert_ids: - v21 = vert_nii.extract_label(v_id) * spine_nii.extract_label(mask_ids) + v21 = vert_nii.extract_label(v_id) * mask_nii try: # POIs center = np.array(poi[v_id, Location.Vertebra_Corpus])