# Source code for psf_generator.propagators.propagator

# Copyright Biomedical Imaging Group, EPFL 2025

"""
The abstract propagator class.

"""
import json
import os
from abc import ABC, abstractmethod

import torch

from ..utils.misc import convert_tensor_to_array


class Propagator(ABC):
    r"""
    Base class propagator.

    Parameters
    ----------
    n_pix_pupil : int, optional
        Number of pixels (size) of the pupil (always a square image). Default value is `128`.
    n_pix_psf : int, optional
        Number of pixels (size) of the PSF (always a square image). Default value is `128`.
    device : str, optional
        Computational backend. Choose from 'cpu' and 'gpu'. Default value is `'cpu'`.
    zernike_coefficients : np.ndarray or torch.tensor, optional
        Zernike coefficients of length 'K' of the chosen first 'K' modes. Default is `None`,
        which is stored as the single coefficient ``torch.tensor([0])``.
    wavelength : float, optional
        Wavelength of light, in nanometer. Default value is `632`.
    na : float, optional
        Numerical aperture. Default value is `1.3`.
    pix_size : float, optional
        Camera pixel size, in nanometer. Default value is `20`.
    defocus_step : float, optional
        Step size of the defocus along the optical (z) axis on one side of the focal plane
        in nanometer. Default value is `0.0`.
    n_defocus : int, optional
        Number of planes in the z-stack. Default value is `1`.
    apod_factor : bool, optional
        Apply apodization factor or not. Default value is `False`.
    envelope : float, optional
        Size :math:`k_{\mathrm{env}}` of the Gaussian envelope
        :math:`A(\mathbf{s}) = \mathrm{e}^{-(k^2_x+k^2_y)/k_\mathrm{env}^2}`. Default is `None`.
    gibson_lanni : bool, optional
        Apply Gibson-Lanni aberration correction or not. Default value is `False`.
    z_p : float, optional
        Depth of the focal plane in the sample. It is usually obtained experimentally by
        focusing on a point source at this depth. Default value is `1e3`.
    n_s : float, optional
        Refractive index of the sample. Default value is `1.3`.
    n_g : float, optional
        Refractive index of the (glass) cover slip. Default value is `1.5`.
    n_g0 : float, optional
        Design condition of the refractive index of the cover slip. Default value is `1.5`.
    t_g : float, optional
        Thickness of the (glass) cover slip. Default value is `170e3`.
    t_g0 : float, optional
        Design condition of the thickness of the cover slip. Default value is `170e3`.
    n_i : float, optional
        Refractive index of the immersion medium. Default value is `1.5`.
    n_i0 : float, optional
        Design condition of the refractive index of the immersion medium. Default value is `1.5`.
    t_i0 : float, optional
        Design condition of the thickness of the immersion medium. Default value is `100e3`.

    Notes
    -----
    Internal parameters:

    1. t_i : float, thickness of the immersion medium. In the model it is
       :math:`t_i = z_p - z + n_i \left( -\frac{z_p}{n_s} - \frac{t_g}{n_g}
       + \frac{t_g^0}{n_g^0} + \frac{t_i^0}{n_i^0} \right)`.
       The attribute ``self.t_i`` stored by this class holds only the
       :math:`n_i \left( \dots \right)` term; the :math:`z_p - z` part is not added here.
    2. refractive_index : float, refractive index of the propagation medium.
       It is equal to :math:`n_s` if gibson_lanni=True, :math:`n_i`, otherwise.
    3. defocus_min, defocus_max : float, symmetric bounds
       :math:`\mp \mathrm{defocus\_step} \cdot \mathrm{n\_defocus} / 2` of the defocus range.
    4. `(z_p, n_s, n_g, n_g0, t_g, t_g0, n_i, t_i0, t_i)` are coefficients related to the
       aberrations due to refractive index mismatch between stratified layers of the
       microscope. This aberration is computed by method `self.compute_optical_path`.
    """

    def __init__(self, n_pix_pupil: int = 128, n_pix_psf: int = 128, device: str = 'cpu',
                 zernike_coefficients=None,
                 wavelength: float = 632, na: float = 1.3, pix_size: float = 20,
                 defocus_step: float = 0.0, n_defocus: int = 1,
                 apod_factor: bool = False, envelope=None,
                 gibson_lanni: bool = False, z_p: float = 1e3, n_s: float = 1.3,
                 n_g: float = 1.5, n_g0: float = 1.5, t_g: float = 170e3, t_g0: float = 170e3,
                 n_i: float = 1.5, n_i0: float = 1.5, t_i0: float = 100e3):
        self.n_pix_pupil = n_pix_pupil
        self.n_pix_psf = n_pix_psf
        self.device = device
        self.zernike_coefficients = self._as_coefficient_tensor(zernike_coefficients)
        self.wavelength = wavelength
        self.na = na
        self.pix_size = pix_size
        # Field of view of the PSF image, in the same unit as pix_size.
        self.fov = pix_size * n_pix_psf
        self.defocus_step = defocus_step
        self.n_defocus = n_defocus
        # True division keeps the defocus range symmetric about the focal plane.
        # Floor division (`// 2`) made it asymmetric and truncated for odd or
        # fractional products (e.g. step=1.5, n=3 gave -3.0 / 2.0).
        half_range = defocus_step * n_defocus / 2
        self.defocus_min = -half_range
        self.defocus_max = half_range
        self.apod_factor = apod_factor
        self.envelope = envelope
        self.gibson_lanni = gibson_lanni
        self.z_p = z_p
        self.n_s = n_s
        self.n_g = n_g
        self.n_g0 = n_g0
        self.t_g = t_g
        self.t_g0 = t_g0
        self.n_i = n_i
        self.n_i0 = n_i0
        self.t_i0 = t_i0
        # Immersion thickness term (without the z_p - z offset, see class Notes).
        self.t_i = n_i * (t_g0 / n_g0 + t_i0 / self.n_i0 - t_g / n_g - z_p / n_s)
        self.refractive_index = n_s if gibson_lanni else n_i

    @staticmethod
    def _as_coefficient_tensor(zernike_coefficients) -> torch.Tensor:
        """Convert Zernike coefficients to a tensor; ``None`` becomes ``tensor([0])``."""
        if zernike_coefficients is None:
            zernike_coefficients = [0]
        if not isinstance(zernike_coefficients, torch.Tensor):
            zernike_coefficients = torch.tensor(zernike_coefficients)
        return zernike_coefficients

    @classmethod
    @abstractmethod
    def get_name(cls) -> str:
        """Get name of the propagator in a certain format, e.g. 'scalar_cartesian'."""
        raise NotImplementedError

    @abstractmethod
    def initialize_input_field(self) -> torch.Tensor:
        """Initialize the input field of propagator."""
        raise NotImplementedError

    @abstractmethod
    def get_pupil(self) -> torch.Tensor:
        """Get the pupil function with all corrections applied."""
        raise NotImplementedError

    @abstractmethod
    def compute_focus_field(self) -> torch.Tensor:
        """Compute the output field of the propagator at focal plane."""
        raise NotImplementedError

    def update_zernike_coefficients(self, zernike_coefficients):
        """Update Zernike coefficients without reinitializing propagator.

        Parameters
        ----------
        zernike_coefficients : array-like, torch.Tensor or None
            New coefficients. They are converted exactly as in ``__init__``, so
            ``None`` becomes ``torch.tensor([0])``.
        """
        self.zernike_coefficients = self._as_coefficient_tensor(zernike_coefficients)
        # Subclasses that cache Zernike aberrations can refresh them via this hook.
        if hasattr(self, '_compute_zernike_aberrations'):
            self._compute_zernike_aberrations()

    def compute_optical_path(self, sin_t: torch.Tensor) -> torch.Tensor:
        r"""Compute the optical path following Eq. (3.45) in [1]_.

        .. math::
            W(\mathbf{s}) &= k \left( t_s \sqrt{n_s^2 - n_i^2 \sin^2 \theta}
            + t_i \sqrt{n_i^2 - n_i^2 \sin^2 \theta}
            -t_i^* \sqrt{\left.n_i^*\right.^2 - n_i^2 \sin^2 \theta} \right. \\
            & \quad \left. + t_g \sqrt{n_g^2 - n_i^2 \sin^2 \theta}
            - t_g^* \sqrt{\left.n_g^*\right.^2 - n_i^2 \sin^2 \theta}\right).

        Here :math:`t_s` is ``self.z_p``. The returned value does NOT include
        the wavenumber factor :math:`k`.

        Parameters
        ----------
        sin_t : torch.Tensor
            Sine of the angle :math:`\theta` in the immersion medium.

        Returns
        -------
        torch.Tensor
            Optical path difference, broadcast to the shape of ``sin_t``. For real
            inputs, ``torch.sqrt`` yields NaN wherever a square-root argument is
            negative.

        References
        ----------
        .. [1] https://bigwww.epfl.ch/publications/aguet0903.pdf
        """
        # Transverse term n_i^2 sin^2(theta) shared by every layer.
        transverse = self.n_i ** 2 * sin_t ** 2
        path = self.z_p * torch.sqrt(self.n_s ** 2 - transverse) \
            + self.t_i * torch.sqrt(self.n_i ** 2 - transverse) \
            - self.t_i0 * torch.sqrt(self.n_i0 ** 2 - transverse) \
            + self.t_g * torch.sqrt(self.n_g ** 2 - transverse) \
            - self.t_g0 * torch.sqrt(self.n_g0 ** 2 - transverse)
        return path

    def _get_args(self) -> dict:
        """Get the parameters of the propagator as a JSON-serializable dict.

        Includes every constructor argument (``n_i0`` was previously missing)
        plus the derived ``refractive_index`` and ``t_i``.
        """
        args = {
            'n_pix_pupil': self.n_pix_pupil,
            'n_pix_psf': self.n_pix_psf,
            'device': self.device,
            'zernike_coefficients': convert_tensor_to_array(self.zernike_coefficients).tolist(),
            'wavelength': self.wavelength,
            'na': self.na,
            'pix_size': self.pix_size,
            'refractive_index': self.refractive_index,
            'defocus_step': self.defocus_step,
            'n_defocus': self.n_defocus,
            'apod_factor': self.apod_factor,
            'envelope': self.envelope,
            'gibson_lanni': self.gibson_lanni,
            'z_p': self.z_p,
            'n_s': self.n_s,
            'n_g': self.n_g,
            'n_g0': self.n_g0,
            't_g': self.t_g,
            't_g0': self.t_g0,
            'n_i': self.n_i,
            'n_i0': self.n_i0,
            't_i0': self.t_i0,
            't_i': self.t_i,
        }
        return args

    def save_parameters(self, json_filepath: str):
        r"""
        Save the parameters of the propagator in a JSON file.

        Notes
        -----
        - Zernike coefficients are converted to a list
        - complex numbers, e.g. e0x or e0y, are converted to a string

        Parameters
        ----------
        json_filepath : str
            Path to save the attributes in a JSON file. Missing parent directories
            are created; a bare filename is written to the current directory.
        """
        args = self._get_args()
        directory = os.path.dirname(json_filepath)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the path actually has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(json_filepath, 'w', encoding='utf-8') as file:
            json.dump(args, file, indent=2)