import vtk
import pyshtools
import numpy as np
from typing import Tuple, List
from scipy import stats as scistats
from skimage import transform as sktrans
from skimage import filters as skfilters
from skimage import morphology as skmorpho
from scipy import interpolate as sciinterp
from vtk.util import numpy_support as vtknp
from sklearn import decomposition as skdecomp
EPS = 1e-12
def get_mesh_from_image(
    image: np.array,
    sigma: float = 0,
    lcc: bool = True,
    translate_to_origin: bool = True,
):
    """Converts a numpy array into a vtkImageData and then into a 3d mesh
    using vtkContourFilter. The input is assumed to be binary and the
    isosurface value is set to 0.5.

    Optionally the input can be pre-processed by i) extracting the largest
    connected component and ii) applying a gaussian smooth to it. In case
    smooth is used, the image is re-binarized using threshold 1/e.

    Points at the edge of the image are set to zero (background)
    to make sure the mesh forms a manifold.

    Parameters
    ----------
    image : np.array
        Input array where the mesh will be computed on.

    Returns
    -------
    mesh : vtkPolyData
        3d mesh in VTK format.
    img_output : np.array
        Input image after pre-processing.
    centroid : np.array
        x, y, z coordinates of the mesh centroid.

    Other parameters
    ----------------
    lcc : bool, optional
        Whether or not to compute the mesh only on the largest
        connected component found in the input image, default is True.
    sigma : float, optional
        The degree of smooth to be applied to the input image, default
        is 0 (no smooth).
    translate_to_origin : bool, optional
        Whether or not to translate the mesh to the origin (0,0,0),
        default is True.

    Raises
    ------
    ValueError
        If no foreground voxels remain after pre-processing.
    """
    img = image.copy()
    # Swap axes 0 and 2 for VTK (input assumed ZYX, VTK wants the
    # opposite axis order)
    img = np.swapaxes(img, 0, 2)
    # Keep only the largest connected component
    if lcc:
        img = skmorpho.label(img.astype(np.uint8))
        counts = np.bincount(img.flatten())
        # Label 0 is background, so search foreground labels only
        largest_label = 1 + np.argmax(counts[1:])
        img = (img == largest_label).astype(np.uint8)
    # Smooth the input image and re-binarize using threshold 1/e
    if sigma > 0:
        img = skfilters.gaussian(img.astype(np.float32), sigma=(sigma, sigma, sigma))
        img[img < 1.0 / np.exp(1.0)] = 0
        img[img > 0] = 1
        if img.sum() == 0:
            raise ValueError(
                "No foreground voxels found after pre-processing. Try using sigma=0."
            )
    # Set image border to 0 so that the mesh forms a manifold
    img[[0, -1], :, :] = 0
    img[:, [0, -1], :] = 0
    img[:, :, [0, -1]] = 0
    img = img.astype(np.float32)
    if img.sum() == 0:
        raise ValueError(
            "No foreground voxels found after pre-processing. "
            "Is the object of interest centered?"
        )
    # Create vtkImageData
    imgdata = vtk.vtkImageData()
    imgdata.SetDimensions(img.shape)
    img = img.transpose(2, 1, 0)
    img_output = img.copy()
    img = img.flatten()
    arr = vtknp.numpy_to_vtk(img, array_type=vtk.VTK_FLOAT)
    arr.SetName("Scalar")
    imgdata.GetPointData().SetScalars(arr)
    # Create 3d mesh via isosurface extraction at 0.5
    cf = vtk.vtkContourFilter()
    cf.SetInputData(imgdata)
    cf.SetValue(0, 0.5)
    cf.Update()
    mesh = cf.GetOutput()
    # Calculate the mesh centroid
    coords = vtknp.vtk_to_numpy(mesh.GetPoints().GetData())
    centroid = coords.mean(axis=0, keepdims=True)
    if translate_to_origin is True:
        # Translate the mesh so its centroid sits at the origin
        coords -= centroid
        mesh.GetPoints().SetData(vtknp.numpy_to_vtk(coords))
    return mesh, img_output, tuple(centroid.squeeze())
def rotate_image_2d(image: np.array, angle: float, interpolation_order: int = 0):
    """Rotate a multichannel image in 2D by a given angle.

    The expected shape of image is (C,Z,Y,X). The rotation is done
    clock-wise around the center of each YX plane.

    Parameters
    ----------
    image : np.array
        Input array of shape (C,Z,Y,X).
    angle : float
        Angle in degrees.
    interpolation_order : int, optional
        Order of interpolation used during the image rotation,
        default is 0 (nearest neighbor).

    Returns
    -------
    img_rot : np.array
        Rotated image.

    Raises
    ------
    ValueError
        If the input is not 4-dimensional or the interpolation order
        is not an integer.
    """
    if image.ndim != 4:
        raise ValueError(
            f"Invalid shape {image.shape} of input image."
            "Expected 4 dimensional images as input."
        )
    if not isinstance(interpolation_order, int):
        raise ValueError("Only integer values are accepted for interpolation order.")
    # Make z the last axis: skimage rotates the leading two axes
    image = np.swapaxes(image, 1, 3)
    # skimage rotates counter-clockwise, hence the negated angle
    img_aligned = np.array(
        [
            sktrans.rotate(
                image=stack,
                angle=-angle,
                resize=True,
                order=interpolation_order,
                preserve_range=True,
            )
            for stack in image
        ]
    )
    # Swap axes back to (C,Z,Y,X)
    img_aligned = np.swapaxes(img_aligned, 1, 3)
    img_aligned = img_aligned.astype(image.dtype)
    return img_aligned
def align_image_2d(
    image: np.array,
    alignment_channel: int = None,
    make_unique: bool = False,
    compute_aligned_image: bool = True,
):
    """Align a multichannel 3D image based on the channel
    specified by alignment_channel. The expected shape of
    image is (C,Z,Y,X) or (Z,Y,X).

    Parameters
    ----------
    image : np.array
        Input array of shape (C,Z,Y,X) or (Z,Y,X).
    alignment_channel : int
        Number of the channel to be used as reference for alignment.
        The alignment is propagated to all other channels. Required
        for 4D inputs; ignored (set to 0) for 3D inputs.
    make_unique : bool
        Set True to make sure the alignment rotation is unique.
    compute_aligned_image : bool
        Set False to only compute and return the alignment angle.

    Returns
    -------
    img_aligned : np.array
        Aligned image. Only returned when compute_aligned_image is True.
    angle : float
        Angle (degrees) used to align the shape. When
        compute_aligned_image is False, the angle is the only
        return value.

    Raises
    ------
    ValueError
        If the input is not 3D/4D or the alignment channel is missing
        or not an integer for 4D inputs.
    """
    if image.ndim not in [3, 4]:
        raise ValueError(f"Invalid shape {image.shape} of input image.")
    if image.ndim == 4:
        if alignment_channel is None:
            raise ValueError(
                "An alignment channel must be provided with multichannel images."
            )
        if not isinstance(alignment_channel, int):
            raise ValueError("Number of alignment channel must be an integer")
    if image.ndim == 3:
        alignment_channel = 0
        image = image.reshape(1, *image.shape)
    # Use the xy coordinates of all foreground voxels of the reference
    # channel to find the principal axes of the projected shape
    z, y, x = np.where(image[alignment_channel])
    xy = np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)])
    pca = skdecomp.PCA(n_components=2)
    pca = pca.fit(xy)
    eigenvecs = pca.components_
    if make_unique is True:
        # Angle of the first principal axis via arctan2
        angle = 180.0 * np.arctan2(eigenvecs[0][1], eigenvecs[0][0]) / np.pi
        # Rotate x coordinates by that angle
        x_rot = (x - x.mean()) * np.cos(np.pi * angle / 180) + (y - y.mean()) * np.sin(
            np.pi * angle / 180
        )
        # Use the skewness of the rotated x coordinate to remove the
        # 180-degree ambiguity of the principal axis direction
        xsk = scistats.skew(x_rot)
        if xsk < 0.0:
            angle += 180
        # Map all angles to anti-clockwise
        angle = angle % 360
    else:
        # Smallest angle between the principal axis and the x axis
        angle = 0.0
        if np.abs(eigenvecs[0][0]) > EPS:
            angle = 180.0 * np.arctan(eigenvecs[0][1] / eigenvecs[0][0]) / np.pi
    if compute_aligned_image is True:
        # Apply skimage rotation clock-wise
        img_aligned = rotate_image_2d(image=image, angle=angle)
        return img_aligned, angle
    return angle
def apply_image_alignment_2d(image: np.array, angle: float):
    """Apply a previously computed 2D alignment angle to a
    multichannel 3D image. The expected shape of image is
    (C,Z,Y,X) or (Z,Y,X).

    Parameters
    ----------
    image : np.array
        Input array of shape (C,Z,Y,X) or (Z,Y,X).
    angle : float
        2D rotation angle in degrees.

    Returns
    -------
    img_aligned : np.array
        Aligned image.

    Raises
    ------
    ValueError
        If the input is not 3- or 4-dimensional.
    """
    if image.ndim not in [3, 4]:
        raise ValueError(f"Invalid shape {image.shape} of input image.")
    # Promote single-channel (Z,Y,X) input to (1,Z,Y,X)
    if image.ndim == 3:
        image = image.reshape(1, *image.shape)
    return rotate_image_2d(image=image, angle=angle)
def update_mesh_points(
    mesh: vtk.vtkPolyData, x_new: np.array, y_new: np.array, z_new: np.array
):
    """Replace the xyz coordinates of the points of the input mesh
    with the coordinates provided.

    Parameters
    ----------
    mesh : vtkPolyData
        Mesh in VTK format to be updated.
    x_new, y_new and z_new : np.array
        Arrays containing the new coordinates.

    Returns
    -------
    mesh_updated : vtkPolyData
        Mesh with updated coordinates.

    Notes
    -----
    The normal vectors of the mesh are recomputed so that they stay
    consistent with the new geometry.
    """
    stacked = np.c_[x_new, y_new, z_new]
    mesh.GetPoints().SetData(vtknp.numpy_to_vtk(stacked))
    mesh.Modified()
    # Recompute normals for the updated geometry
    normals_filter = vtk.vtkPolyDataNormals()
    normals_filter.SetInputData(mesh)
    normals_filter.Update()
    return normals_filter.GetOutput()
def get_even_reconstruction_from_grid(
    grid: np.array, npoints: int = 512, centroid: Tuple = (0, 0, 0)
):
    """Converts a parametric 2D grid of radial coordinates into a 3d
    mesh with an approximately even distribution of points. The even
    distribution is obtained via the Fibonacci grid rule.

    Parameters
    ----------
    grid : np.array
        2D array of radial coordinates over the sphere, lat in [0,pi]
        and lon in [0,2pi]. The code below reads grid.shape[0] as the
        latitude resolution and grid.shape[1] as the longitude
        resolution.
        NOTE(review): the original docstring described the axes the
        other way around (shape[0] as longitude); the interpolator
        construction below implies (lat, lon) ordering, consistent
        with a pyshtools MakeGridDH output — confirm with callers.

    Returns
    -------
    mesh : vtkPolyData
        Mesh that represents the input parametric grid.

    Other parameters
    ----------------
    npoints : int
        Number of points in the initial spherical mesh.
    centroid : tuple of floats, optional
        x, y and z coordinates of the centroid where the mesh
        will be translated to, default is (0,0,0).
    """
    res_lat = grid.shape[0]
    res_lon = grid.shape[1]
    # Build a spline interpolator of radius as a function of (lon, lat)
    lon = np.linspace(start=0, stop=2 * np.pi, num=res_lon, endpoint=True)
    lat = np.linspace(start=0, stop=1 * np.pi, num=res_lat, endpoint=True)
    fgrid = sciinterp.RectBivariateSpline(lon, lat, grid.T)
    # Create x,y,z coordinates based on the Fibonacci Lattice
    # http://extremelearning.com.au/evenly-distributing-points-on-a-sphere/
    golden_ratio = 0.5 * (1 + np.sqrt(5))
    idxs = np.arange(0, npoints, dtype=np.float32)
    # Polar angle in [0,pi]; azimuth wrapped to [-pi,pi)
    fib_theta = np.arccos(2 * ((idxs + 0.5) / npoints) - 1)
    fib_phi = (2 * np.pi * (idxs / golden_ratio)) % (2 * np.pi) - np.pi
    # Shift azimuth back to [0,2pi) to sample the interpolator
    fib_lat = fib_theta
    fib_lon = fib_phi + np.pi
    fib_grid = fgrid.ev(fib_lon, fib_lat)
    # Spherical -> Cartesian coordinates, translated to the centroid
    fib_x = centroid[0] + fib_grid * np.sin(fib_theta) * np.cos(fib_phi)
    fib_y = centroid[1] + fib_grid * np.sin(fib_theta) * np.sin(fib_phi)
    fib_z = centroid[2] + fib_grid * np.cos(fib_theta)
    # Add points (x,y,z) to a polydata
    points = vtk.vtkPoints()
    for x, y, z in zip(fib_x, fib_y, fib_z):
        points.InsertNextPoint(x, y, z)
    rec = vtk.vtkPolyData()
    rec.SetPoints(points)
    # Compute the connectivity between points via 3D Delaunay
    # triangulation, then extract the outer surface
    delaunay = vtk.vtkDelaunay3D()
    delaunay.SetInputData(rec)
    delaunay.Update()
    surface_filter = vtk.vtkDataSetSurfaceFilter()
    surface_filter.SetInputData(delaunay.GetOutput())
    surface_filter.Update()
    # Smooth the mesh to get a more even distribution of points
    NITER_SMOOTH = 128
    smooth = vtk.vtkSmoothPolyDataFilter()
    smooth.SetInputData(surface_filter.GetOutput())
    smooth.SetNumberOfIterations(NITER_SMOOTH)
    smooth.FeatureEdgeSmoothingOff()
    smooth.BoundarySmoothingOn()
    smooth.Update()
    rec.DeepCopy(smooth.GetOutput())
    # Compute normal vectors
    normals = vtk.vtkPolyDataNormals()
    normals.SetInputData(rec)
    normals.Update()
    mesh = normals.GetOutput()
    return mesh
def get_even_reconstruction_from_coeffs(
    coeffs: np.array, lrec: int = 0, npoints: int = 512, centroid: Tuple = (0, 0, 0)
):
    """Converts a set of spherical harmonic coefficients into
    a 3d mesh using the Fibonacci grid for generating a mesh
    with a more even distribution of points.

    Parameters
    ----------
    coeffs : np.array
        Input array of spherical harmonic coefficients. The
        array has dimensions 2xLxM, where the first dimension
        is 0 for cosine-associated coefficients and 1 for
        sine-associated coefficients. Second and third dimensions
        represent the expansion parameters (l,m).

    Returns
    -------
    mesh : vtkPolyData
        Mesh that represents the input parametric grid.
    grid : np.array
        Parametric grid expanded from the coefficients.

    Other parameters
    ----------------
    lrec : int, optional
        Only coefficients l<lrec will be used for creating the
        mesh, default is 0 meaning all coefficients available
        in the matrix coefficients will be used.
    npoints : int, optional
        Number of points in the initial spherical mesh.
    centroid : tuple of floats, optional
        x, y and z coordinates where the reconstructed mesh will
        be centered, default is (0,0,0).

    Notes
    -----
    The mesh resolution is set by the size of the coefficients
    matrix and therefore not affected by lrec: higher degrees are
    zeroed out rather than truncated.
    """
    coeffs_ = coeffs.copy()
    # Zero-out degrees >= lrec instead of truncating so the expanded
    # grid resolution stays independent of lrec
    if (lrec > 0) and (lrec < coeffs_.shape[1]):
        coeffs_[:, lrec:, :] = 0
    grid = pyshtools.expand.MakeGridDH(coeffs_, sampling=2)
    # Forward the centroid so the reconstruction can be placed anywhere
    mesh = get_even_reconstruction_from_grid(grid, npoints, centroid)
    return mesh, grid
def get_reconstruction_from_grid(grid: np.array, centroid: Tuple = (0, 0, 0)):
    """Converts a parametric 2D grid of radial coordinates into a
    3d mesh, lat in [0,pi] and lon in [0,2pi].

    Parameters
    ----------
    grid : np.array
        2D array of radial coordinates. The code below reads
        grid.shape[0] as the latitude resolution and grid.shape[1]
        as the longitude resolution.
        NOTE(review): the original docstring described grid[i,j]
        with i as the longitude index; the linspace construction
        below implies the opposite — confirm against the grid
        producer (e.g. pyshtools MakeGridDH returns (nlat, nlon)).

    Returns
    -------
    mesh : vtkPolyData
        Mesh that represents the input parametric grid.

    Other parameters
    ----------------
    centroid : tuple of floats, optional
        x, y and z coordinates of the centroid where the mesh
        will be translated to, default is (0,0,0).
    """
    res_lat = grid.shape[0]
    res_lon = grid.shape[1]
    # Creates an initial spherical mesh with right dimensions.
    rec = vtk.vtkSphereSource()
    rec.SetPhiResolution(res_lat + 2)
    rec.SetThetaResolution(res_lon)
    rec.Update()
    rec = rec.GetOutput()
    # Flatten in longitude-major order: grid_[j * res_lat + i] is the
    # radius at longitude index j and latitude index i
    grid_ = grid.T.flatten()
    # Update the points coordinates of the spherical mesh according to the input grid
    for j, lon in enumerate(np.linspace(0, 2 * np.pi, num=res_lon, endpoint=False)):
        for i, lat in enumerate(
            np.linspace(np.pi / (res_lat + 1), np.pi, num=res_lat, endpoint=False)
        ):
            theta = lat
            phi = lon - np.pi
            k = j * res_lat + i
            x = centroid[0] + grid_[k] * np.sin(theta) * np.cos(phi)
            y = centroid[1] + grid_[k] * np.sin(theta) * np.sin(phi)
            z = centroid[2] + grid_[k] * np.cos(theta)
            # Points 0 and 1 are the two poles (layout assumed from
            # vtkSphereSource output — confirm), so ring points start
            # at index 2
            rec.GetPoints().SetPoint(k + 2, x, y, z)
    # Pole radii: mean radius of the first latitude ring (north) and
    # of the last latitude ring (south)
    north = grid_[::res_lat].mean()
    south = grid_[(res_lat - 1) :: res_lat].mean()
    rec.GetPoints().SetPoint(0, centroid[0] + 0, centroid[1] + 0, centroid[2] + north)
    rec.GetPoints().SetPoint(1, centroid[0] + 0, centroid[1] + 0, centroid[2] - south)
    # Compute normal vectors
    normals = vtk.vtkPolyDataNormals()
    normals.SetInputData(rec)
    # Set splitting off to avoid output mesh from having different number of
    # points compared to input
    normals.SplittingOff()
    normals.Update()
    mesh = normals.GetOutput()
    return mesh
def get_reconstruction_from_coeffs(coeffs: np.array, lrec: int = 0):
    """Converts a set of spherical harmonic coefficients into
    a 3d mesh.

    Parameters
    ----------
    coeffs : np.array
        Input array of spherical harmonic coefficients. The
        array has dimensions 2xLxM, where the first dimension
        is 0 for cosine-associated coefficients and 1 for
        sine-associated coefficients. Second and third dimensions
        represent the expansion parameters (l,m).

    Returns
    -------
    mesh : vtkPolyData
        Mesh that represents the input parametric grid.
    grid : np.array
        Parametric grid expanded from the coefficients.

    Other parameters
    ----------------
    lrec : int, optional
        Degree of the reconstruction. If lrec<l, then only
        coefficients l<lrec will be used for creating the mesh.
        If lrec>l, then the mesh will be oversampled.
        Default is 0 meaning all coefficients
        available in the matrix coefficients will be used.

    Notes
    -----
    NOTE(review): unlike get_even_reconstruction_from_coeffs, the
    working array here is resized to lrec, so the expanded grid
    resolution (and therefore the mesh resolution) DOES depend on
    lrec. The original note claiming otherwise appears to have been
    copied from the other function — confirm before relying on it.
    """
    # Degree of the expansion
    lmax = coeffs.shape[1]
    if lrec == 0:
        lrec = lmax
    # Create working array (oversampled if lrec > lmax)
    coeffs_ = np.zeros((2, lrec, lrec), dtype=np.float32)
    # Adjust lrec to the expansion degree
    if lrec > lmax:
        lrec = lmax
    # Copy coefficients
    coeffs_[:, :lrec, :lrec] = coeffs[:, :lrec, :lrec]
    # Expand into a grid
    grid = pyshtools.expand.MakeGridDH(coeffs_, sampling=2)
    # Get mesh
    mesh = get_reconstruction_from_grid(grid)
    return mesh, grid
def get_reconstruction_error(grid_input: np.array, grid_rec: np.array):
    """Compute the mean squared error between two parametric grids.

    When applied to the input parametric grid and its corresponding
    reconstructed version, this gives an idea of the quality of the
    parametrization, with low values indicating a good
    parametrization.

    Parameters
    ----------
    grid_input : np.array
        Parametric grid.
    grid_rec : np.array
        Reconstructed parametric grid.

    Returns
    -------
    mse : float
        Mean squared error.
    """
    residual = grid_input - grid_rec
    mse = np.square(residual).mean()
    return mse
def save_polydata(mesh: vtk.vtkPolyData, filename: str):
"""Saves a mesh as a vtkPolyData file.
Parameters
----------
mesh : vtkPolyData
Input mesh
filename : str
File path where the mesh will be saved
output_type : vtk or ply
Format of output polydata file
"""
# Output file format
output_type = filename.split(".")[-1]
if output_type not in ["vtk", "ply"]:
raise ValueError(
f"Output format {output_type} not supported. Please use vtk or ply."
)
if output_type == "vtk":
writer = vtk.vtkPolyDataWriter()
else:
writer = vtk.vtkPLYWriter()
writer.SetInputData(mesh)
writer.SetFileName(filename)
writer.Write()
def convert_coeffs_dict_to_matrix(coeffs_dict, lmax=32):
    """
    Convert a dictionary of SH coefficients to a matrix of SH
    coefficients.

    The dictionary keys must follow the format "shcoeffs_L{L}M{M}{C}",
    where L and M are the degree and order of the coefficient and C is
    either "C" or "S" for cosine or sine coefficients, respectively.

    Parameters
    ----------
    coeffs_dict : dict
        Dictionary of SH coefficients.
    lmax : int
        Maximum degree of the SH coefficients.

    Returns
    -------
    coeffs : np.array
        Matrix of SH coefficients of shape (2, lmax + 1, lmax + 1).

    Notes
    -----
    NOTE(review): only degrees 0..lmax-1 are read from the dictionary,
    so row lmax of the (lmax+1)-sized matrix always stays zero —
    confirm whether this off-by-one is intentional before relying on
    the last degree.
    """
    coeffs = np.zeros((2, lmax + 1, lmax + 1), dtype=np.float32)
    for part_index, part in enumerate(("C", "S")):
        for degree in range(lmax):
            for order in range(degree + 1):
                key = f"shcoeffs_L{degree}M{order}{part}"
                coeffs[part_index, degree, order] = coeffs_dict[key]
    return coeffs
def voxelize_mesh(
    imagedata: vtk.vtkImageData, shape: Tuple, mesh: vtk.vtkPolyData, origin: List
):
    """
    Voxelize a triangle mesh into a binary image.

    Parameters
    --------------------
    imagedata: vtkImageData
        Imagedata used as support for the voxelization.
    shape: tuple
        Shape that the imagedata scalars will take after
        voxelization.
    mesh: vtkPolyData
        Mesh to be voxelized.
    origin: List
        xyz specifying the lower left corner of the mesh.

    Returns
    -------
    img: np.array
        Binary array.
    """
    # Build a stencil that marks the voxels interior to the mesh
    stencil_source = vtk.vtkPolyDataToImageStencil()
    stencil_source.SetInputData(mesh)
    stencil_source.SetOutputOrigin(origin)
    stencil_source.SetOutputWholeExtent(imagedata.GetExtent())
    stencil_source.Update()
    # Apply the stencil: voxels outside the mesh get background value 0
    stencil = vtk.vtkImageStencil()
    stencil.SetInputData(imagedata)
    stencil.SetStencilConnection(stencil_source.GetOutputPort())
    stencil.ReverseStencilOff()
    stencil.SetBackgroundValue(0)
    stencil.Update()
    # Convert the scalars back to a numpy array with the requested shape
    voxels = stencil.GetOutput().GetPointData().GetScalars()
    return vtknp.vtk_to_numpy(voxels).reshape(shape)
def voxelize_meshes(meshes: List):
    """
    Voxelize a list of meshes into a single labeled image. Usually
    the input corresponds to the cell membrane and nuclear
    shell meshes.

    Parameters
    --------------------
    meshes: List
        List of vtkPolydatas representing the meshes to
        be voxelized into an image. The first mesh is used as
        reference (e.g. cell) and is expected to contain the
        others (e.g. nucleus).

    Returns
    -------
    img: np.array
        3D image where voxels with value i are those found in the
        interior of the i-th mesh of the input list. If a voxel is
        interior to two or more meshes of the input list, it takes
        the value of the right-most mesh in the list.
    origin: np.array
        Origin of the meshes in the voxelized image, shape (1, 3).
    """
    # 1st mesh is used as reference (cell) and it should be
    # larger than the 2nd one (nucleus).
    mesh = meshes[0]
    # Find mesh coordinates
    coords = vtknp.vtk_to_numpy(mesh.GetPoints().GetData())
    # Bounds of the mesh. Use floor/ceil instead of int() truncation so
    # that negative coordinates round outward and never clip the mesh.
    rmin = np.floor(coords.min(axis=0) - 0.5).astype(int)
    rmax = np.ceil(coords.max(axis=0) + 0.5).astype(int)
    # Width, height and depth, with a one-voxel margin on each side
    w = int(2 + (rmax[0] - rmin[0]))
    h = int(2 + (rmax[1] - rmin[1]))
    d = int(2 + (rmax[2] - rmin[2]))
    # Create support image data for the voxelization
    imagedata = vtk.vtkImageData()
    imagedata.SetDimensions([w, h, d])
    imagedata.SetExtent(0, w - 1, 0, h - 1, 0, d - 1)
    imagedata.SetOrigin(rmin)
    imagedata.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 1)
    # Set all values to 1
    imagedata.GetPointData().GetScalars().FillComponent(0, 1)
    # Empty 3D array to accumulate the voxelization of all meshes;
    # later meshes overwrite earlier ones where they overlap
    img = np.zeros((d, h, w), dtype=np.uint8)
    # Voxelize one mesh at a time
    for mid, mesh in enumerate(meshes):
        seg = voxelize_mesh(
            imagedata=imagedata, shape=(d, h, w), mesh=mesh, origin=rmin
        )
        img[seg > 0] = mid + 1
    # Origin of the reference system in the image
    origin = rmin.reshape(1, 3)
    return img, origin