"""
A collection of plotting functions.
"""
import os
import typing as tp
import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from .misc import convert_tensor_to_array
[docs]
def colorbar(mappable, cbar_ticks: tp.Union[str, tp.List, None] = 'auto', tick_size: float = _TICK_SIZE,
             cbar_labels: tp.List[str] = None):
    """
    Colorbar with the option to add or remove ticks.

    The colorbar is drawn in a new axes appended to the right of the mappable's axes
    (5% of its width). The axes that was current on entry is made current again before returning.

    Parameters
    ----------
    mappable :
        Matplotlib Mappable.
    cbar_ticks : None or str or sequence of ticks
        If None, no ticks visible. If 'auto': ticks are determined automatically.
        Otherwise, set the ticks as given by cbar_ticks (a list or an array).
    tick_size : float, optional
        Fontsize of the tick labels. Only applied when explicit ticks are given.
    cbar_labels : list[str], optional
        Cbar labels. Default is None, use the cbar ticks formatted with two decimals.
        Only used when explicit ticks are given.

    Returns
    -------
    cbar :
        The colorbar returned by ``fig.colorbar``.

    Raises
    ------
    ValueError
        If ``cbar_ticks`` is a string other than 'auto', or if ``cbar_labels`` and ``cbar_ticks``
        have different lengths.
    """
    # Only compare against 'auto' when cbar_ticks really is a string: comparing an array of
    # ticks with a string is elementwise and cannot be used as a truth value.
    ticks_is_string = isinstance(cbar_ticks, str)
    if ticks_is_string and cbar_ticks != 'auto':
        raise ValueError(f"cbar_ticks must be 'auto', None or a sequence of ticks, not {cbar_ticks!r}")
    explicit_ticks = cbar_ticks is not None and not ticks_is_string

    # Validate and build the labels before touching the figure, so that a bad argument
    # does not leave a half-configured colorbar axes behind.
    tick_labels = None
    if explicit_ticks:
        if cbar_labels is None:
            tick_labels = [f'{tick:.2f}' for tick in cbar_ticks]
        elif len(cbar_labels) != len(cbar_ticks):
            raise ValueError('The length of the cbar labels and ticks are different.')
        else:
            tick_labels = cbar_labels

    last_axes = plt.gca()
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax)

    if cbar_ticks is None:
        cbar.set_ticks([])
    elif explicit_ticks:
        cbar.set_ticks(cbar_ticks)
        cbar.set_ticklabels(tick_labels, fontsize=tick_size)
    # 'auto': keep the ticks chosen by matplotlib.

    # Restore the axes that was current when the function was called.
    plt.sca(last_axes)
    return cbar
[docs]
def apply_disk_mask(img):
    """
    Apply a disk mask to a square image.

    Returns a new array in which every pixel outside the disk is replaced by NaN;
    the input array is not modified. For a non-square image a warning is issued
    and the disk radius is derived from the longer side.
    """
    source = img.copy()
    n_rows, n_cols = source.shape

    if n_rows != n_cols:
        warnings.warn(f'Image is non-square, shape: {source.shape}. Applying an over-sized disk mask!')

    radius = max(n_rows, n_cols) // 2

    # Coordinates along each axis, shifted so that the disk centre is at zero.
    # NOTE(review): linspace(0, n, n) spans 0..n inclusive, so these are not integer
    # pixel indices -- presumably intentional, confirm before changing.
    row_offsets = np.linspace(0, n_rows, n_rows)[:, np.newaxis] - n_rows // 2
    col_offsets = np.linspace(0, n_cols, n_cols)[np.newaxis, :] - n_cols // 2

    # Broadcasting the column vector against the row vector yields the full 2D grid.
    inside_disk = row_offsets ** 2 + col_offsets ** 2 <= radius ** 2

    # Keep the pixels inside the disk, set everything else to nan.
    return np.where(inside_disk, source, np.nan)
[docs]
def _compute_psf_intensity(input_image: np.ndarray) -> np.ndarray:
    r"""
    Return the intensity of a scalar or vectorial complex field.

    The field is expected as a 4D array laid out as follows:

    - axis 0: z axis, or defocus slices.
    - axis 1: electric field components. Only one for scalar and three
      :math:`(\mathbf{e}_x, \mathbf{e}_y, \mathbf{e}_z)` for vectorial.
    - axes 2 and 3: :math:`(x, y)` axes.

    The squared moduli of the components are summed:

    .. math:: I = \sum_{i=1}^{N} |\mathbf{e}_i(x, y, z)|^2, \quad N = 1 \, \mathrm{or} \, 3.

    Parameters
    ----------
    input_image : np.ndarray
        Scalar or vectorial complex field. 4D array.

    Returns
    -------
    output : np.ndarray
        Intensity of the field. 4D array whose component axis has length one.

    Raises
    ------
    ValueError
        If the input is not 4D.
    """
    if input_image.ndim == 4:
        squared_modulus = np.abs(input_image) ** 2
        # keepdims retains the (now singleton) component axis so the result stays 4D.
        return np.sum(squared_modulus, axis=1, keepdims=True)
    raise ValueError(f'The input image must be 4D instead of {input_image.ndim}')
[docs]
def plot_pupil(
        pupil: tp.Union[torch.Tensor, np.ndarray],
        name_of_propagator: str,
        filepath: str = None,
        show_cbar_ticks: bool = False,
        show_image_ticks: bool = False,
        show_titles: bool = True,
):
    """
    Plot the modulus and phase of a scalar or vectorial pupil for the Cartesian propagator.

    The figure has two columns (modulus, phase) and one row per pupil component.
    Each column shares one colour scale across its rows, and every image is displayed
    through ``apply_disk_mask``.

    Parameters
    ----------
    pupil : torch.Tensor or np.ndarray
        Pupil image to plot. After squeezing it must be 2D (scalar) or 3D (vectorial).
    name_of_propagator : str
        Name of the propagator.
    filepath : str, optional
        Path to save the plot. Default is None, no file is saved.
        Missing parent directories are created.
    show_cbar_ticks : bool, optional
        Whether to show the ticks for the colorbar. Default is False.
    show_image_ticks : bool, optional
        Whether to show ticks. Default is False.
    show_titles : bool, optional
        Whether to show the titles on the first row. Default is True.

    Raises
    ------
    NotImplementedError
        If ``name_of_propagator`` contains 'spherical'.
    ValueError
        If the squeezed pupil is neither 2D nor 3D.
    """
    if 'spherical' in name_of_propagator:
        raise NotImplementedError('For spherical propagators, the pupil is represented by two 1D intervals, '
                                  'no 2D image is thus available. '
                                  'Please check the pupil of the equivalent Cartesian propagator instead.')

    # convert to numpy array
    pupil_array = convert_tensor_to_array(pupil).squeeze()

    if pupil_array.ndim == 2:
        # scalar pupil: add a leading component axis so both cases are handled alike
        pupil_array = pupil_array[np.newaxis, :, :]
        row_titles = ['']
    elif pupil_array.ndim == 3:
        row_titles = [r'$\mathbf{e}_x$', r'$\mathbf{e}_y$', r'$\mathbf{e}_z$']
    else:
        raise ValueError(f'Pupil should be either 2D or 3D, not {pupil_array.ndim}')
    nrows = pupil_array.shape[0]

    # compute modulus and phase
    quantities = [np.abs(pupil_array), np.angle(pupil_array)]
    cmaps = ['inferno', 'twilight']
    col_titles = ['modulus', 'phase']
    ncols = len(quantities)

    # squeeze=False always returns a 2D array of axes, also for a single row
    figure, axes = plt.subplots(nrows, ncols, figsize=(ncols * _FIG_SIZE, nrows * _FIG_SIZE), squeeze=False)

    columns = zip(axes.T, quantities, cmaps, col_titles)
    for col_index, (column_axes, component_images, cmap, col_title) in enumerate(columns):
        # one colour scale per column, shared by all components
        cbar_min = np.min(component_images)
        cbar_max = np.max(component_images)
        norm = plt.Normalize(cbar_min, cbar_max)
        cbar_ticks = [cbar_min, cbar_max] if show_cbar_ticks else None

        # zip stops at the shortest input, so at most len(row_titles) rows are drawn
        rows = zip(column_axes, component_images, row_titles)
        for row_index, (ax, image, row_title) in enumerate(rows):
            im = ax.imshow(apply_disk_mask(image), norm=norm, cmap=cmap)
            colorbar(im, cbar_ticks=cbar_ticks)

            if show_image_ticks:
                x_ticks = [0, image.shape[1]]
                ax.set_xticks(x_ticks)
                ax.set_xticklabels(x_ticks, fontsize=_TICK_SIZE)
                y_ticks = [0, image.shape[0]]
                ax.set_yticks(y_ticks)
                ax.set_yticklabels(y_ticks, fontsize=_TICK_SIZE)
            else:
                ax.set_xticks([])
                ax.set_yticks([])

            if show_titles and row_index == 0:
                ax.set_title(col_title, fontsize=_TITLE_SIZE)
            if nrows > 1 and col_index == 0:
                # component label, rotated and placed to the left of the first column
                ax.text(-0.1, 0.5, row_title, fontsize=_TITLE_SIZE, verticalalignment='center',
                        rotation=90, transform=ax.transAxes)
                figure.subplots_adjust(left=0.05)

    figure.suptitle(f'Pupil properties ({name_of_propagator})', fontsize=_SUP_TITLE_SIZE)

    if filepath is not None:
        figure.tight_layout()
        # dirname is '' for a bare file name, and os.makedirs('') raises FileNotFoundError
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        figure.savefig(filepath)
    plt.show()
[docs]
def _set_psf_image_ticks(ax, image, show_image_ticks):
    """Put ticks at 0 and at the image size along each axis of ``ax``, or hide all ticks."""
    if not show_image_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
        return
    x_ticks = [0, image.shape[1]]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks, fontsize=_TICK_SIZE)
    y_ticks = [0, image.shape[0]]
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_ticks, fontsize=_TICK_SIZE)


def plot_psf(
        psf: tp.Union[torch.Tensor, np.ndarray],
        name_of_propagator: str,
        quantity: str = 'modulus',
        z_slice_number: int = None,
        x_slice_number: int = None,
        y_slice_number: int = None,
        filepath: str = None,
        show_cbar_ticks: bool = False,
        show_image_ticks: bool = False,
        show_titles: bool = False,
        propagator=None,
):
    """
    Plot the intensity or modulus or phase of a PSF, applicable to all four propagators.

    The PSF is indexed as ``(z, channel, x, y)``. When it has a single z slice, one XY image
    per channel is plotted. Otherwise three orthogonal slices are plotted per channel,
    all sharing one colour scale.

    Parameters
    ----------
    psf : torch.Tensor or np.ndarray
        PSF image to plot. Must be 4D with one or three channels.
    name_of_propagator : str
        Name of the propagator.
    quantity : str, optional
        Quantity of the PSF to plot. Default is 'modulus'.
        Valid choices are 'modulus', 'phase', 'stationary_phase', 'intensity', 'amplitude'.
    z_slice_number : int, optional
        Z slice number for the x-y plane. Default is None, use the middle slice.
    x_slice_number : int, optional
        X slice number for the y-z plane. Default is None, use the middle slice.
    y_slice_number : int, optional
        Y slice number for the x-z plane. Default is None, use the middle slice.
    filepath : str, optional
        Path to save the plot. Default is None, no file is saved.
        Missing parent directories are created.
    show_cbar_ticks : bool, optional
        Whether to show the ticks for the colorbar. Default is False.
    show_image_ticks : bool, optional
        Whether to show ticks. Default is False.
    show_titles : bool, optional
        Whether to show the titles on the first row. Default is False.
    propagator : Propagator, optional
        Propagator object. Required for 'stationary_phase' quantity. Default is None.

    Raises
    ------
    ValueError
        If ``quantity`` is not supported, if the PSF is not 4D, if ``propagator`` is missing
        for 'stationary_phase', if the defocus grid of the propagator does not match the number
        of z slices, or if the number of channels is neither 1 nor 3.
    """
    # convert to numpy array
    psf_array = convert_tensor_to_array(psf)

    # check the arguments
    valid_choices = ['modulus', 'phase', 'stationary_phase', 'intensity', 'amplitude']
    if quantity not in valid_choices:
        raise ValueError(f'quantity {quantity} is not supported, choose from {valid_choices}')
    if psf_array.ndim != 4:
        raise ValueError(f'The PSF must be 4D (z, channel, x, y) instead of {psf_array.ndim}D')

    # compute quantity
    if quantity == 'modulus':
        psf_quantity = np.abs(psf_array)
        cmap = 'inferno'
    elif quantity == 'phase':
        psf_quantity = np.angle(psf_array)
        cmap = 'twilight'
    elif quantity == 'stationary_phase':
        # Remove the plane wave e^(ikz) contribution by multiplying by e^(-ikz)
        if propagator is None:
            raise ValueError('propagator parameter is required for stationary_phase quantity')
        zz = torch.linspace(propagator.defocus_min, propagator.defocus_max, propagator.n_defocus)
        correction = torch.exp(- 1j * propagator.k * zz * propagator.refractive_index)
        correction = convert_tensor_to_array(correction)
        if correction.shape != (psf_array.shape[0],):
            raise ValueError(f'The propagator defines {correction.shape} defocus factors '
                             f'but the PSF has {psf_array.shape[0]} z slices')
        # Broadcast one factor per z slice over channels and pixels. Building a new array
        # (instead of writing into a copy of the input) keeps the result complex even if
        # the input array is real-valued.
        psf_stationary = psf_array * correction[:, np.newaxis, np.newaxis, np.newaxis]
        # Now compute the phase of the stationary field
        psf_quantity = np.angle(psf_stationary)
        cmap = 'twilight'
    elif quantity == 'intensity':
        psf_quantity = _compute_psf_intensity(psf_array)
        cmap = 'inferno'
    else:  # 'amplitude'
        psf_quantity = np.sqrt(_compute_psf_intensity(psf_array))
        cmap = 'inferno'

    number_of_pixel_z, dim, number_of_pixel_x, number_of_pixel_y = psf_quantity.shape
    if dim == 1:
        row_titles = ['']
    elif dim == 3:
        row_titles = [r'$\mathbf{e}_x$', r'$\mathbf{e}_y$', r'$\mathbf{e}_z$']
    else:
        raise ValueError(f'Number of channels of the PSF should be 1 or 3, not {dim}')

    # default to the central slice along each axis
    if z_slice_number is None:
        z_slice_number = number_of_pixel_z // 2
    if x_slice_number is None:
        x_slice_number = number_of_pixel_x // 2
    if y_slice_number is None:
        y_slice_number = number_of_pixel_y // 2

    # channel first, so that iterating over a slice yields one image per channel
    psf_quantity = psf_quantity.swapaxes(0, 1)

    if number_of_pixel_z == 1:  # 2D PSF
        psf_slice = psf_quantity[:, 0, :, :]
        cbar_min = np.min(psf_slice)
        cbar_max = np.max(psf_slice)
        norm = plt.Normalize(cbar_min, cbar_max)
        cbar_ticks = [cbar_min, cbar_max] if show_cbar_ticks else None

        # squeeze=False always returns a 2D array of axes, also for a single channel
        figure, axes = plt.subplots(dim, 1, figsize=(1 * _FIG_SIZE, dim * _FIG_SIZE), squeeze=False)
        for ax, image, row_title in zip(axes[:, 0], psf_slice, row_titles):
            im = ax.imshow(image, norm=norm, cmap=cmap)
            colorbar(im, cbar_ticks=cbar_ticks)
            if show_titles:
                ax.set_title('XY-plane (2D PSF)', fontsize=_TITLE_SIZE)
            if dim > 1:
                ax.set_ylabel(row_title, fontsize=_LABEL_SIZE)
                figure.subplots_adjust(left=0.05)
            _set_psf_image_ticks(ax, image, show_image_ticks)
        figure.suptitle(f'{quantity.capitalize()} ({name_of_propagator.capitalize()})',
                        fontsize=_SUP_TITLE_SIZE)
    else:  # 3D PSF
        planes = [
            psf_quantity[:, z_slice_number, :, :],
            psf_quantity[:, :, x_slice_number, :],
            psf_quantity[:, :, :, y_slice_number],
        ]
        col_titles = [
            f'XY-plane (z={z_slice_number+1}/{number_of_pixel_z} slice)',
            f'ZY plane (x={x_slice_number+1}/{number_of_pixel_x} slice)',
            f'ZX plane (y={y_slice_number+1}/{number_of_pixel_y} slice)',
        ]
        nrows = dim
        ncols = len(planes)

        # one colour scale shared by the three planes
        cbar_min = min(np.min(plane) for plane in planes)
        cbar_max = max(np.max(plane) for plane in planes)
        norm = plt.Normalize(cbar_min, cbar_max)
        cbar_ticks = [cbar_min, cbar_max] if show_cbar_ticks else None

        figure, axes = plt.subplots(nrows, ncols, figsize=(ncols * _FIG_SIZE, nrows * _FIG_SIZE),
                                    squeeze=False)
        for col_index, (column_axes, plane, col_title) in enumerate(zip(axes.T, planes, col_titles)):
            for row_index, (ax, image, row_title) in enumerate(zip(column_axes, plane, row_titles)):
                im = ax.imshow(image, norm=norm, cmap=cmap)
                colorbar(im, cbar_ticks=cbar_ticks)
                if show_titles and row_index == 0:
                    ax.set_title(col_title, fontsize=_TITLE_SIZE)
                if dim > 1 and col_index == 0:
                    ax.set_ylabel(row_title, fontsize=_LABEL_SIZE)
                    figure.subplots_adjust(left=0.05)
                _set_psf_image_ticks(ax, image, show_image_ticks)
        figure.suptitle(
            f'{quantity.capitalize()} at three orthogonal planes ({name_of_propagator.capitalize()})',
            fontsize=_SUP_TITLE_SIZE,
        )

    if filepath is not None:
        figure.tight_layout()
        # dirname is '' for a bare file name, and os.makedirs('') raises FileNotFoundError
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        figure.savefig(filepath)
    plt.show()