Source code for psf_generator.utils.czt

# Copyright Biomedical Imaging Group, EPFL 2025

"""
A collection of custom 2D FFT functions.

- `custom_fft2`: custom 2D FFT
- `custom_ifft2`: custom 2D inverse FFT
- `czt1d`: 1D chirp Z-transform
- `czt2d`: 2D chirp Z-transform

"""

__all__ = ['custom_fft2', 'custom_ifft2']

import typing as tp

import torch
from torch.fft import fft, fft2, ifft, ifft2


def _validate_shape(shape_in: tp.Tuple, shape_out: tp.Optional[tp.Tuple]) -> tp.Tuple:
    """
    Validate the shape of the input and output.

    Parameters
    ----------
    shape_in : sequence of ints
        Shape of the last two dimension of the input image.
    shape_out : sequence of ints or None
        Shape of the output image.

    Returns
    -------
    shape_out : tuple
        processed output shape

    """
    if shape_in[0] != shape_in[1]:
        raise ValueError('The input image must be square!')
    elif shape_out is None:
        shape_out = shape_in
    K, L = shape_out[-2:]
    if K != L:
        print('Warning: Output of different size in each dimension; enforcing squared output.')
        shape_out = (max(K, L), max(K, L))
    return shape_out

def _create_w_phase(start: float, end: float, steps: int, include_end: bool) -> float:
    """
    Create the W factor.

    Parameters
    ----------
    start : float
        start point of the sampling
    end : float
        end point of the sampling
    steps : int
        number of sampling steps
    include_end : bool
        whether to include the end point

    Returns
    -------
    w_phase : float
        W factor
    """
    points = steps - 1 if include_end else steps
    w_phase = (end - start) / points
    return w_phase

def _apply_fftshift(x: torch.Tensor, shape_out: tp.Optional[tp.Tuple], k_start: float,
                    K: int, N: int, w_phase: float, a_phase: float) -> torch.Tensor:
    """
    Apply fftshift on the input image.

    Parameters
    ----------
    x : torch.Tensor
        Input 2D image.
    shape_out : tuple or None
        shape of the output image.
    k_start : float
        start point of sampling.
    K : int
        size of the output image.
    N : int
        size of the input image.
    w_phase : float
        W factor
    a_phase : float
        A factor.

    Returns
    -------
    output: torch.Tensor
        fftshifted image.
    """
    k = torch.arange(K)
    kx, ky = torch.meshgrid(k, k, indexing='ij')
    center_correction = torch.exp(1j * (N - 1) / 2 * (2 * k_start - w_phase * (kx + ky))).to(x.device)
    return czt2d(x, shape_out, w_phase, a_phase) * center_correction


[docs] def custom_fft2(x: torch.Tensor, shape_out=None, k_start: float = 0.0, k_end: float = 2 * torch.pi, norm: str = 'ortho', fftshift_input: bool = False, include_end: bool = False) -> torch.Tensor: r""" Custom 2D FFT that allows to zoom on the region of interest in the Fourier plane. The output image is square. Parameters ---------- x : torch.Tensor Input square image shape_out : sequence of ints, optional Shape of the output image. If None, same as the shape of the input. k_start : float Start point of sampling on the complex circle. Default is `0.0`. k_end : float End point of sampling on the complex circle. Default is :math:`2\pi`. norm : {"ortho", "forward", "backward"}, optional Normalization mode. Default is "ortho". fftshift_input : bool, optional Whether to apply fftshift on the input image. Default is "False". include_end : bool, optional Whether to include the end point of sampling. Default is "False". Returns ------- output : torch.Tensor Custom 2D FFT of the input image. """ shape_out = _validate_shape(x.shape[-2:], shape_out) K = shape_out[0] N = x.shape[-2] w_phase = _create_w_phase(k_start, k_end, K, include_end) a_phase = k_start if fftshift_input: result = _apply_fftshift(x, shape_out, k_start, K, N, w_phase, a_phase) else: result = czt2d(x, shape_out, w_phase, a_phase) if norm =='ortho': return result / K elif norm == 'forward': return result / K**2 elif norm == 'backward': return result
[docs] def custom_ifft2(x: torch.Tensor, shape_out=None, k_start: float = 0.0, k_end: float = 2 * torch.pi, norm: str = 'ortho', fftshift_input: bool = False, include_end: bool = False) -> torch.Tensor: r""" Custom 2D inverse FFT that allows to zoom on the region of interest in the Fourier plane. The output image is square. Parameters ---------- x : torch.Tensor Input square image. shape_out : sequence of ints, optional Shape of the output image. If None, same as the shape of the input. k_start : float Start point of sampling on the complex circle. Default is `0.0`. k_end : float End point of sampling on the complex circle. Default is :math:`2\pi`. norm : {"ortho", "forward", "backward"}, optional Normalization mode. Default is "ortho". fftshift_input : bool, optional Whether to apply fftshift on the input image. Default is "False". include_end : bool, optional Whether to include the end point of sampling. Default is "False". Returns ------- output : torch.Tensor Custom 2D inverse FFT of the input image. """ shape_out = _validate_shape(x.shape[-2:], shape_out) K = shape_out[0] N = x.shape[-2] w_phase = _create_w_phase(k_start, k_end, K, include_end) a_phase = - k_start if fftshift_input: result = _apply_fftshift(x, shape_out, k_start, K, N, w_phase, a_phase) angle = 2 * (N - 1) * k_end if not isinstance(angle, torch.Tensor): angle = torch.tensor(angle) result *= torch.exp(1j * angle) else: result = czt2d(x, shape_out, w_phase, a_phase) if norm =='ortho': return result / K elif norm == 'forward': return result elif norm == 'backward': return result / K**2
def czt1d(x: torch.Tensor, shape_out=None, w_phase=None, a_phase: torch.Tensor = 0.0) -> torch.Tensor: """ 1D chirp Z-transform implemented in PyTorch. Parameters ---------- x : torch.Tensor Input 1D image. shape_out : int, optional Length of the output image. If None, same as the length of the input. w_phase : torch.Tensor, optional W factor in the definition of the CZT. a_phase : torch.Tensor, optional A factor in the definition of the CZT Returns ------- output: torch.Tensor CZT of the input 1D image. References ---------- .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.CZT.html """ shape_in = x.shape[-1] if shape_out is None: shape_out = shape_in if w_phase is None: w_phase = - 2 * torch.pi / shape_out max_dim = max(shape_in, shape_out) fft_dim = int(2 ** torch.ceil(torch.log2(torch.tensor(shape_in + shape_out - 1)))) device = x.device k = torch.arange(max_dim, device=device) wk2 = torch.exp(1j * w_phase * k ** 2 / 2).to(device) aw_factor = torch.exp(- 1j * a_phase * k[:shape_in]).to(device) * wk2[:shape_in] second_factor = fft( 1 / torch.hstack((torch.flip(wk2[1:shape_in], dims=(0,)), wk2[:shape_out])), fft_dim) idx = slice(shape_in - 1, shape_in + shape_out - 1) output = ifft(fft(x * aw_factor, n=fft_dim) * second_factor) output = wk2[:shape_out] * output[..., idx] return output def czt2d(x: torch.Tensor, shape_out=None, w_phase=None, a_phase: float = 0.0) -> torch.Tensor: """ 2D chirp Z-transform implemented in PyTorch. Parameters ---------- x : torch.Tensor Input 2D image. shape_out : sequence of ints, optional Shape of the output image. If None, same as the length of the input. w_phase : float, optional W factor in the definition of the CZT. a_phase : float, optional A factor in the definition of the CZT Returns ------- output: torch.Tensor CZT of the input 2D image. """ shape_out = _validate_shape(x.shape[-2:], shape_out) K = shape_out[0] N = x.shape[-2] if w_phase is None: w_phase = - 2 * torch.pi / K max_dim = max(N, K) fft_dim = int(2 ** torch.ceil(torch.log2(torch.tensor(N + K - 1)))) device = x.device k = torch.arange(max_dim).to(device) kx, ky = torch.meshgrid(k, k, indexing='ij') wk2 = torch.exp(1j * w_phase * (kx**2+ky**2) / 2).to(device) aw_factor = torch.exp(- 1j * a_phase * (kx[:N, :N]+ky[:N, :N])).to(device) * wk2[:N, :N] second_factor = fft2(1 / torch.hstack( (torch.vstack( (torch.flip(wk2[1:N, 1:N], dims=(0,1)), torch.flip(wk2[:K, 1:N], dims=(1,))), ), torch.vstack( (torch.flip(wk2[1:N, :K], dims=(0,)), wk2[:K, :K]), )), ), s=(fft_dim,fft_dim)) idx = slice(N - 1, N + K - 1) output = ifft2(fft2(x * aw_factor, s=(fft_dim,fft_dim)) * second_factor) output = wk2[:K, :K] * output[..., idx, idx] return output