Source code for aicssegmentation.core.utils

from typing import List

import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.measure import label, regionprops
from skimage.morphology import ball, disk, dilation, erosion, medial_axis, remove_small_objects


[docs]def hole_filling(bw: np.ndarray, hole_min: int, hole_max: int, fill_2d: bool = True) -> np.ndarray: """Fill holes in 2D/3D segmentation Parameters: ------------- bw: np.ndarray a binary 2D/3D image. hole_min: int the minimum size of the holes to be filled hole_max: int the maximum size of the holes to be filled fill_2d: bool if fill_2d=True, a 3D image will be filled slice by slice. If you think of a hollow tube alone z direction, the inside is not a hole under 3D topology, but the inside on each slice is indeed a hole under 2D topology. Return: a binary image after hole filling """ bw = bw > 0 if len(bw.shape) == 2: background_lab = label(~bw, connectivity=1) fill_out = np.copy(background_lab) component_sizes = np.bincount(background_lab.ravel()) too_big = component_sizes > hole_max too_big_mask = too_big[background_lab] fill_out[too_big_mask] = 0 too_small = component_sizes < hole_min too_small_mask = too_small[background_lab] fill_out[too_small_mask] = 0 elif len(bw.shape) == 3: if fill_2d: fill_out = np.zeros_like(bw) for zz in range(bw.shape[0]): background_lab = label(~bw[zz, :, :], connectivity=1) out = np.copy(background_lab) component_sizes = np.bincount(background_lab.ravel()) too_big = component_sizes > hole_max too_big_mask = too_big[background_lab] out[too_big_mask] = 0 too_small = component_sizes < hole_min too_small_mask = too_small[background_lab] out[too_small_mask] = 0 fill_out[zz, :, :] = out else: background_lab = label(~bw, connectivity=1) fill_out = np.copy(background_lab) component_sizes = np.bincount(background_lab.ravel()) too_big = component_sizes > hole_max too_big_mask = too_big[background_lab] fill_out[too_big_mask] = 0 too_small = component_sizes < hole_min too_small_mask = too_small[background_lab] fill_out[too_small_mask] = 0 else: print("error in image shape") return return np.logical_or(bw, fill_out)
[docs]def size_filter(img: np.ndarray, min_size: int, method: str = "3D", connectivity: int = 1): """size filter Parameters: ------------ img: np.ndarray the image to filter on min_size: int the minimum size to keep method: str either "3D" or "slice_by_slice", default is "3D" connnectivity: int the connectivity to use when computing object size """ assert len(img.shape) == 3, "image has to be 3D" if method == "3D": return remove_small_objects(img > 0, min_size=min_size, connectivity=connectivity) elif method == "slice_by_slice": seg = np.zeros(img.shape, dtype=bool) for zz in range(img.shape[0]): seg[zz, :, :] = remove_small_objects( img[zz, :, :] > 0, min_size=min_size, connectivity=connectivity, ) return seg else: raise NotImplementedError(f"unsupported method {method}")
[docs]def topology_preserving_thinning(bw: np.ndarray, min_thickness: int = 1, thin: int = 1) -> np.ndarray: """perform thinning on segmentation without breaking topology Parameters: -------------- bw: np.ndarray the 3D binary image to be thinned min_thickness: int Half of the minimum width you want to keep from being thinned. For example, when the object width is smaller than 4, you don't want to make this part even thinner (may break the thin object and alter the topology), you can set this value as 2. thin: int the amount to thin (has to be an positive integer). The number of pixels to be removed from outter boundary towards center. Return: ------------- A binary image after thinning """ bw = bw > 0 safe_zone = np.zeros_like(bw) for zz in range(bw.shape[0]): if np.any(bw[zz, :, :]): ctl = medial_axis(bw[zz, :, :] > 0) dist = distance_transform_edt(ctl == 0) safe_zone[zz, :, :] = dist > min_thickness + 1e-5 rm_candidate = np.logical_xor(bw > 0, erosion(bw > 0, ball(thin))) bw[np.logical_and(safe_zone, rm_candidate)] = 0 return bw
[docs]def divide_nonzero(array1, array2): """ Divides two arrays. Returns zero when dividing by zero. """ denominator = np.copy(array2) denominator[denominator == 0] = 1e-10 return np.divide(array1, denominator)
[docs]def histogram_otsu(hist): """Apply Otsu thresholding method on 1D histogram""" # modify the elements in hist to avoid completely zero value in cumsum hist = hist + 1e-5 bin_size = 1 / (len(hist) - 1) bin_centers = np.arange(0, 1 + 0.5 * bin_size, bin_size) hist = hist.astype(float) # class probabilities for all possible thresholds weight1 = np.cumsum(hist) weight2 = np.cumsum(hist[::-1])[::-1] # class means for all possible thresholds mean1 = np.cumsum(hist * bin_centers) / weight1 mean2 = (np.cumsum((hist * bin_centers)[::-1]) / weight2[::-1])[::-1] # Clip ends to align class 1 and class 2 variables: # The last value of `weight1`/`mean1` should pair with zero values in # `weight2`/`mean2`, which do not exist. variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2 idx = np.argmax(variance12) threshold = bin_centers[:-1][idx] return threshold
[docs]def absolute_eigenvaluesh(nd_array): """Computes the eigenvalues sorted by absolute value from the symmetrical matrix. Parameters: ------------- nd_array: nd.ndarray array from which the eigenvalues will be calculated. Return: ------------- A list with the eigenvalues sorted in absolute ascending order (e.g. [eigenvalue1, eigenvalue2, ...]) """ eigenvalues = np.linalg.eigvalsh(nd_array) sorted_eigenvalues = sortbyabs(eigenvalues, axis=-1) return [ np.squeeze(eigenvalue, axis=-1) for eigenvalue in np.split(sorted_eigenvalues, sorted_eigenvalues.shape[-1], axis=-1) ]
[docs]def sortbyabs(a: np.ndarray, axis=0): """Sort array along a given axis by the absolute value modified from: http://stackoverflow.com/a/11253931/4067734 """ index = list(np.ix_(*[np.arange(i) for i in a.shape])) index[axis] = np.abs(a).argsort(axis) return a[tuple(index)]
[docs]def get_middle_frame(struct_img: np.ndarray, method: str = "z") -> int: """find the middle z frame of an image stack Parameters: ------------ struct_img: np.ndarray the 3D image to process method: str which method to use to determine the middle frame. Options are "z" or "intensity". "z" is solely based on the number of z frames. "intensity" method uses Otsu threshod to estimate the volume of foreground signals in the stack, then estimated volume of each z frame forms a z-profile, and finally another Otsu method is apply on the z profile to find the best z frame (with an assumption of two peaks along z profile, one near the bottom of the cells and one near the bottom of the cells, so the optimal separation is the middle of the stack). Return: ----------- mid_frame: int the z index of the middle z frame """ from skimage.filters import threshold_otsu if method == "intensity": bw = struct_img > threshold_otsu(struct_img) z_profile = np.zeros((bw.shape[0],), dtype=int) for zz in range(bw.shape[0]): z_profile[zz] = np.count_nonzero(bw[zz, :, :]) mid_frame = None if isinstance(round(histogram_otsu(z_profile) * bw.shape[0]), int): mid_frame = round(histogram_otsu(z_profile) * bw.shape[0]) else: mid_frame = round(histogram_otsu(z_profile) * bw.shape[0]).astype(int) elif method == "z": mid_frame = struct_img.shape[0] // 2 else: print("unsupported method") quit() return mid_frame
[docs]def get_3dseed_from_mid_frame( bw: np.ndarray, stack_shape: List = None, mid_frame: int = -1, hole_min: int = 1, bg_seed: bool = True, ) -> np.ndarray: """build a 3D seed image from the binary segmentation of a single slice Parameters: ------------ bw: np.ndarray the 2d segmentation of a single frame, or a 3D array with only one slice containing segmentation stack_shape: List (only used when bw is 2d) the shape of original 3d image, e.g. shape_3d = img.shape frame_index: int (only used when bw is 2d) the index of where bw is from the whole z-stack hole_min: int any connected component in bw2d with size smaller than area_min will be excluded from seed image generation bg_seed: bool bg_seed=True will add a background seed at the first frame (z=0). """ from skimage.morphology import remove_small_objects out = remove_small_objects(bw > 0, hole_min) out1 = label(out) stat = regionprops(out1) # build the seed seed = np.zeros(stack_shape) seed_count = 0 if bg_seed: seed[0, :, :] = 1 seed_count += 1 for idx in range(len(stat)): py, px = np.round(stat[idx].centroid) seed_count += 1 seed[mid_frame, int(py), int(px)] = seed_count return seed
[docs]def remove_hot_pixel(seg: np.ndarray) -> np.ndarray: """ remove hot pixel from segmentation """ assert len(seg.shape) == 3, "input segmentation must be 3D" # make sure the segmentation is 0/1 seg = seg.astype(np.uint8) seg[seg > 0] = 1 # get sum projection along z seg_proj = np.sum(seg, axis=0) # find hot pixels hot_pixel = seg_proj >= seg.shape[0] - 2 # dilate the area to cover the surrounding pixels hot_pixel = dilation(hot_pixel, disk(2)) # clean up every z for z in range(seg.shape[0]): seg_z = seg[z, :, :] seg_z[hot_pixel] = 0 seg[z, :, :] = seg_z return seg
[docs]def get_seed_for_objects( raw: np.ndarray, bw: np.ndarray, area_min: int = 1, area_max: int = 10000, bg_seed: bool = True, ) -> np.ndarray: """ build a seed image for an image of 3D objects (assuming roughly convex shape in 3D) using the information in the middle slice Parameters: ------------ raw: np.ndarray orignal image used to determine middle slice bw: np.ndarray a round 3D segmentation, expecting the segmentation in the middle slice having relatively good quality area_min: int estimated minimal size on one single slice (major body chunk, e.g. the center XY plane of a 3D ball) of an object area_max: int estimated maximal size on one single slice (major body chunk, e.g. the center XY plane of a 3D ball) of an object. It is recommended to be conservertive (setting this value a little larger) bg_seed: bool bg_seed=True will add a background seed at the first frame (z=0). """ from skimage.morphology import remove_small_objects # determine middle slice mid_z = get_middle_frame(raw, method="intensity") # take seg of middle slice bw2d = bw[mid_z, :, :] # fillin holes to form solid objects bw2d_fill = hole_filling(bw2d, area_min, area_max) # prune the objects in middle slice out = remove_small_objects(bw2d_fill > 0, area_min) # extract object and calculate centroid out1 = label(out) stat = regionprops(out1) # use each centroid as one seed seed = np.zeros(raw.shape) seed_count = 0 if bg_seed: seed[0, :, :] = 1 seed_count += 1 for idx in range(len(stat)): py, px = np.round(stat[idx].centroid) seed_count += 1 seed[mid_z, int(py), int(px)] = seed_count return seed.astype(int)
[docs]def segmentation_union(seg: List) -> np.ndarray: """merge multiple segmentations into a single result Parameters ------------ seg: List a list of segmentations, should all have the same shape """ return np.logical_or.reduce(seg)
[docs]def segmentation_intersection(seg: List) -> np.ndarray: """get the intersection of multiple segmentations into a single result Parameters ------------ seg: List a list of segmentations, should all have the same shape """ return np.logical_and.reduce(seg)
[docs]def segmentation_xor(seg: List) -> np.ndarray: """get the XOR of multiple segmentations into a single result Parameters ------------ seg: List a list of segmentations, should all have the same shape """ return np.logical_xor.reduce(seg)
[docs]def remove_index_object(label: np.ndarray, id_to_remove: List[int] = [1], in_place: bool = False) -> np.ndarray: if in_place: img = label else: img = label.copy() for id in id_to_remove: img[img == id] = 0 return img
[docs]def peak_local_max_wrapper(struct_img_for_peak: np.ndarray, bw: np.ndarray) -> np.ndarray: from skimage.feature import peak_local_max local_maxi = peak_local_max(struct_img_for_peak, labels=label(bw), min_distance=2) local_maxi_image = np.zeros_like(struct_img_for_peak) local_maxi_image[tuple(local_maxi.T)] = True return local_maxi_image
[docs]def watershed_wrapper(bw: np.ndarray, local_maxi: np.ndarray) -> np.ndarray: from scipy.ndimage import distance_transform_edt from skimage.measure import label from skimage.morphology import dilation, ball from skimage.segmentation import watershed distance = distance_transform_edt(bw) im_watershed = watershed( -distance, label(dilation(local_maxi, footprint=ball(1))), mask=bw, watershed_line=True, ) return im_watershed
[docs]def prune_z_slices(bw: np.ndarray): """ prune the segmentation by only keep a certain range of z-slices with the assumption of all signals living only in a few consecutive z-slices. This function will first determine the key z-slice where most of the signals living on and then include a few slices up/down along z to make the segmentation completed. This is useful when you have prior knowledge about your segmentation target and can effectively exclude small segmented objects due to noise/artifacts in those z-slices we are sure the signal should not live on. Parameters: ----------- bw: np.ndarray the segmentation before pruning """ bw_z = np.zeros(bw.shape[0], dtype=np.uint16) for zz in range(bw.shape[0]): bw_z[zz] = np.count_nonzero(bw[zz, :, :] > 0) mid_z = np.argmax(bw_z) low_z = 0 high_z = bw.shape[0] - 2 for ii in np.arange(mid_z - 1, 0, -1): if bw_z[ii] < 100: low_z = ii break for ii in range(mid_z + 1, bw.shape[0] - 1, 1): if bw_z[ii] < 100: high_z = ii break seg = bw.copy() seg[:low_z, :, :] = 0 seg[high_z + 1 :, :, :] = 0 return seg
[docs]def cell_local_adaptive_threshold(structure_img_smooth: np.ndarray, cell_wise_min_area: int): from skimage.filters import threshold_triangle, threshold_otsu from skimage.morphology import dilation # cell-wise local adaptive thresholding th_low_level = threshold_triangle(structure_img_smooth) bw_low_level = structure_img_smooth > th_low_level bw_low_level = remove_small_objects(bw_low_level, min_size=cell_wise_min_area, connectivity=1, out=bw_low_level) bw_low_level = dilation(bw_low_level, footprint=ball(2)) bw_high_level = np.zeros_like(bw_low_level) lab_low, num_obj = label(bw_low_level, return_num=True, connectivity=1) for idx in range(num_obj): single_obj = lab_low == (idx + 1) local_otsu = threshold_otsu(structure_img_smooth[single_obj > 0]) bw_high_level[np.logical_and(structure_img_smooth > local_otsu * 0.98, single_obj)] = 1 return bw_high_level
[docs]def invert_mask(img): return 1 - img
[docs]def mask_image(image, mask, value: int = 0): image[mask] = value return image