# Copyright Biomedical Imaging Group, EPFL 2025
"""
A collection of functions related to Zernike polynomials.
"""
import warnings
import numpy as np
import torch
from scipy.special import binom
from zernikepy import zernike_polynomials
[docs]
def create_pupil_mesh(n_pixels: int) -> tuple[torch.Tensor, ...]:
    """
    Build the square Cartesian sampling grid of the pupil plane.

    Both axes are sampled uniformly on [-1, 1] (end points included) and
    combined with 'xy' indexing, so the first returned tensor varies along
    columns and the second one along rows.

    Parameters
    ----------
    n_pixels : int
        Number of samples along each side of the grid.

    Returns
    -------
    (kx, ky): Tuple[torch.Tensor, ...]
        Two `(n_pixels, n_pixels)` tensors holding the horizontal and the
        vertical coordinate of every grid point.
    """
    # One independent axis per dimension, both spanning [-1, 1].
    axes = [torch.linspace(-1, 1, n_pixels) for _ in range(2)]
    grid_x, grid_y = torch.meshgrid(*axes, indexing='xy')
    return grid_x, grid_y
[docs]
def zernike_nl(n: int, l: int, rho: torch.Tensor, phi: float, radius: float = 1) -> torch.Tensor:
    """
    Compute the Zernike polynomial with indices `n` and `l` in polar coordinates.

    The radial part is

    .. math::
        R_n^m(r) = \\sum_{k=0}^{(n-m)/2} (-1)^k \\binom{n-k}{k}
                   \\binom{n-2k}{(n-m)/2-k} r^{n-2k}, \\quad r = \\rho / radius,

    with :math:`m = |l|`. It is multiplied by :math:`\\cos(m\\phi)` when
    `l >= 0` and by :math:`\\sin(m\\phi)` when `l < 0`. Points with
    `rho > radius` evaluate to zero.

    Parameters
    ----------
    n : int
        Index `n` in the definition on wikipedia, non-negative integer.
    l : int
        :math:`|l| = m`, `m` is the index m in the definition on wikipedia.
        `l` can be positive or negative, with :math:`|l| \\le n`.
    rho : torch.Tensor
        Radial distance.
    phi : float
        Azimuthal angle.
    radius : float
        Radius of the disk on which the Zernike polynomial is defined, default is 1.

    Returns
    -------
    Z: torch.Tensor
        Zernike polynomial Z(rho, phi) evaluated at `rho` and `phi` given indices `n` and `l`.
        It is identically zero when `n - |l|` is odd.

    Raises
    ------
    ValueError
        If `|l|` is larger than `n`.
    """
    m = abs(l)
    if m > n:
        raise ValueError(f"Invalid Zernike indices: |l| = {m} must not exceed n = {n}.")

    scaled_rho = rho / radius
    if (n - m) % 2:
        # By definition the radial polynomial vanishes when n - m is odd; the
        # summation formula below is only meaningful for even n - m.
        return torch.zeros_like(scaled_rho)

    half = int((n - m) // 2)

    # radial part
    R = torch.zeros_like(scaled_rho)
    for k in range(half + 1):
        coefficient = (-1) ** k * binom(n - k, k) * binom(n - 2 * k, half - k)
        # Convert to a plain float so the product is always a torch operation.
        R = R + float(coefficient) * scaled_rho ** (n - 2 * k)

    # restrict the support to the disk of radius `radius`
    Z = torch.where(rho <= radius, R, torch.zeros_like(R))

    # angular part
    Z *= np.cos(m * phi) if l >= 0 else np.sin(m * phi)
    return Z
[docs]
def index_to_nl(index: int) -> tuple[int, int]:
    """
    Convert an OSA index j into the matching (n, l) pair of Zernike indices.

    The OSA index 'j' is defined as :math:`j = (n(n + 2) + l) / 2`, i.e. the
    polynomials of order `n` occupy the `n + 1` consecutive indices that start
    at :math:`n(n + 1) / 2`, with `l` running over `-n, -n + 2, ..., n`.

    Parameters
    ----------
    index : int
        OSA index j.

    Returns
    -------
    (n, l) : Tuple[int, int]
        Corresponding (n, l)-pair.

    Raises
    ------
    ValueError
        If the search passes `index` without hitting it exactly.
    """
    order = 0
    while True:
        # OSA index of the first polynomial of this order, i.e. (order, -order).
        row_start = order * (order + 1) / 2
        for offset in range(order + 1):
            candidate = row_start + offset
            if candidate == index:
                return order, 2 * offset - order
            if candidate > index:
                raise ValueError('Index out of bounds.')
        order += 1
[docs]
def create_zernike_aberrations(zernike_coefficients: torch.Tensor, n_pix_pupil: int, mesh_type: str) -> torch.Tensor:
    """
    Turn a set of Zernike coefficients into a complex aberration factor for the pupil.

    The coefficients (a 1D Tensor of length `n_zernike`) weight the first
    `n_zernike` Zernike polynomials; the weighted sum is the phase mask, and
    the returned field is `exp(1j * phase)`, to be multiplied with an existing
    pupil function.

    - 'cartesian': the polynomial stack comes from
      `zernikepy.zernike_polynomials` and every coefficient contributes.
    - 'spherical': the phase is sampled on `n_pix_pupil` radii between 0 and 1
      at azimuth 0. Only axis-symmetric terms (`l == 0`, e.g. defocus or
      primary spherical) are accumulated; a non-zero coefficient on any other
      term triggers a warning and is ignored.

    Parameters
    ----------
    zernike_coefficients : torch.Tensor
        1D Tensor of Zernike coefficients.
    n_pix_pupil : int
        Number of pixels of the pupil function.
    mesh_type : str
        Choose 'spherical' or 'cartesian'.

    Returns
    -------
    Zernike_aberrations: torch.Tensor
        Of type torch.complex64.

    Raises
    ------
    ValueError
        If `mesh_type` is neither 'spherical' nor 'cartesian'.
    """
    n_zernike = len(zernike_coefficients)

    if mesh_type == 'cartesian':
        basis = torch.from_numpy(
            zernike_polynomials(mode=n_zernike-1, size=n_pix_pupil, select='all')
        )
        # Broadcast the weights along the last (polynomial) axis of the stack.
        weights = zernike_coefficients.reshape(1, 1, n_zernike)
        phase = (weights * basis).sum(dim=2)
    elif mesh_type == 'spherical':
        radial_axis = torch.linspace(0, 1, n_pix_pupil)
        phase = torch.zeros(n_pix_pupil)
        for osa_index in range(n_zernike):
            order, azimuthal = index_to_nl(index=osa_index)
            weight = zernike_coefficients[osa_index]
            if azimuthal == 0:
                phase += weight * zernike_nl(n=order, l=azimuthal, rho=radial_axis, phi=0)
            elif weight != 0:
                # Angle-dependent terms cannot be represented here; skip them.
                warnings.warn("Warning: Zernike polynomials that are not axis-symmetric "
                              "are not supported in spherical coordinates!")
    else:
        raise ValueError(f"Invalid mesh type {mesh_type}, choose 'spherical' or 'cartesian'.")

    aberration = torch.exp(1j * phase)
    return aberration.to(torch.complex64)
[docs]
def create_special_pupil(n_pix_pupil: int, mask=None, tophat_radius: float = 0.5) -> torch.Tensor:
    """
    Special phase masks not included in the space spanned by the Zernike polynomials.

    The supported special phase masks are:

    - None <-> flat phase, Gaussian beam
    - `vortex` <-> donut beam
    - `halfmoon-h` <-> horizontal halfmoon beam
    - `halfmoon-v` <-> vertical halfmoon beam
    - `tophat` <-> tophat beam
    - a `torch.Tensor` <-> custom phase mask, used as is

    Notes
    -----
    These special masks only apply in the Cartesian case.

    Parameters
    ----------
    n_pix_pupil : int
        Number of pixels on the pupil plane.
    mask : None, str or torch.Tensor
        Name of the special phase mask (None, 'vortex', 'halfmoon-h',
        'halfmoon-v' or 'tophat'), or a custom phase mask given as a Tensor
        of shape `(n_pix_pupil, n_pix_pupil)`.
    tophat_radius : float
        Radius of the tophat mask. Default is 0.5. TODO: relate to cutoff frequency of the system.

    Returns
    -------
    pupil : torch.Tensor
        Pupil function `exp(1j * phase_mask)` of the special phase mask, of
        type torch.complex64.

    Raises
    ------
    ValueError
        If a custom mask has the wrong shape, or if `mask` is not one of the
        supported values.
    """
    pupil_shape = (n_pix_pupil, n_pix_pupil)

    if mask is None:
        phase_mask = torch.zeros(pupil_shape)
    elif isinstance(mask, torch.Tensor):
        # Tensors are handled before any string comparison.
        if mask.shape != pupil_shape:
            raise ValueError(f"Custom phase mask must be a 2D Tensor of shape ({n_pix_pupil}, {n_pix_pupil}).")
        phase_mask = mask
    elif mask == 'vortex':
        # The coordinate mesh is only built for the masks that need it.
        kx, ky = create_pupil_mesh(n_pixels=n_pix_pupil)
        phase_mask = torch.atan2(kx, ky)
    elif mask == 'halfmoon-h':
        # Phase pi on the first half of the rows, zero elsewhere.
        phase_mask = torch.zeros(pupil_shape)
        phase_mask[0: n_pix_pupil // 2, :] = torch.pi
    elif mask == 'halfmoon-v':
        # Phase pi on the first half of the columns, zero elsewhere.
        phase_mask = torch.zeros(pupil_shape)
        phase_mask[:, 0: n_pix_pupil // 2] = torch.pi
    elif mask == 'tophat':
        kx, ky = create_pupil_mesh(n_pixels=n_pix_pupil)
        # Positive outside the disk of radius `tophat_radius`, where the phase is pi.
        inner_disk = kx ** 2 + ky ** 2 - tophat_radius ** 2
        phase_mask = torch.where(inner_disk > 0, torch.pi, 0)
    else:
        raise ValueError(f"Invalid mask value {mask}. Must be None, a valid string, or a custom tensor")

    pupil = torch.exp(1j * phase_mask).to(torch.complex64)
    return pupil