Source code for psf_generator.utils.bessel

# Copyright Biomedical Imaging Group, EPFL 2025

"""
A collection of custom Bessel functions with gradient tracking.

These functions are autograd-enabled replacements for the PyTorch built-in `bessel_j0` and `bessel_j1`,
which do not support gradient tracking as of v1.13.1.

"""

# Public API of this module: the two differentiable Bessel function classes defined below.
__all__ =['BesselJ0', 'BesselJ1']

from typing import Any

import torch
from torch.autograd import Function
from torch.special import (
    bessel_j0,  # as __bessel_j0
    bessel_j1,  # as __bessel_j1
)


class BesselJ0(Function):
    """
    Differentiable version of PyTorch's `bessel_j0(x)`.

    Both derivative rules use the identity d/dx J0(x) = -J1(x), evaluated with
    the built-in `bessel_j1`. Use it through `BesselJ0.apply(x)`.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate J0 element-wise and store `x` for the derivative rules.

        :param ctx: autograd context object supplied by `Function.apply`.
        :param x: input tensor.
        :return: `bessel_j0(x)`.
        """
        # The input is stored twice so that it is available both to the
        # reverse-mode rule (`vjp`) and to the forward-mode rule (`jvp`).
        ctx.save_for_backward(x)
        ctx.save_for_forward(x)
        return bessel_j0(x)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def vjp(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Vector-Jacobian product, for reverse-mode adjoint (`backward()`).

        Marked once-differentiable: the result is built from the built-in
        `bessel_j1`, so it cannot itself be differentiated again.

        :param ctx: autograd context holding the tensor saved in `forward`.
        :param grad_output: gradient with respect to the output of `forward`.
        :return: gradient with respect to `x`, i.e. `-J1(x) * grad_output`.
        """
        x, = ctx.saved_tensors
        return -bessel_j1(x) * grad_output

    @staticmethod
    def jvp(ctx: Any, grad_input: torch.Tensor) -> torch.Tensor:
        """
        Jacobian-vector product, for forward-mode adjoint.

        :param ctx: autograd context holding the tensor saved in `forward`.
        :param grad_input: tangent associated with `x`.
        :return: tangent of the output, i.e. `-J1(x) * grad_input`.
        """
        x, = ctx.saved_tensors
        return -bessel_j1(x) * grad_input
class BesselJ1(Function):
    """
    Differentiable version of `bessel_j1(x)`.

    Both derivative rules use the identity d/dx J1(x) = J0(x) - J1(x) / x,
    with the quotient replaced by its limit 1/2 at x = 0. Use it through
    `BesselJ1.apply(x)`.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate J1 element-wise and store `x` and the result for the derivative rules.

        :param ctx: autograd context object supplied by `Function.apply`.
        :param x: input tensor.
        :return: `bessel_j1(x)`.
        """
        result = bessel_j1(x)
        # J1(x) appears in its own derivative, so the output is saved alongside
        # the input to avoid re-evaluating it in `vjp` / `jvp`.
        ctx.save_for_backward(x, result)
        ctx.save_for_forward(x, result)
        return result

    @staticmethod
    def _jacobian(x: torch.Tensor, j1: torch.Tensor) -> torch.Tensor:
        """
        Element-wise derivative of J1, shared by `vjp` and `jvp`.

        :param x: input tensor saved in `forward`.
        :param j1: `bessel_j1(x)` saved in `forward`.
        :return: `J0(x) - J1(x) / x`, using 1/2 for the quotient where `x == 0`.
        """
        # j1 / x evaluates to 0 / 0 = NaN at x == 0; `torch.where` discards that
        # entry and substitutes the limit of J1(x) / x as x -> 0, which is 1/2.
        j1_norm_x = torch.where(x == 0.0, 0.5, j1 / x)
        return bessel_j0(x) - j1_norm_x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def vjp(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Vector-Jacobian product, for reverse-mode adjoint (`backward()`).

        Marked once-differentiable: the result is built from the built-in
        Bessel functions, so it cannot itself be differentiated again.

        :param ctx: autograd context holding the tensors saved in `forward`.
        :param grad_output: gradient with respect to the output of `forward`.
        :return: gradient with respect to `x`.
        """
        x, j1 = ctx.saved_tensors
        return BesselJ1._jacobian(x, j1) * grad_output

    @staticmethod
    def jvp(ctx: Any, grad_input: torch.Tensor) -> torch.Tensor:
        """
        Jacobian-vector product, for forward-mode adjoint.

        :param ctx: autograd context holding the tensors saved in `forward`.
        :param grad_input: tangent associated with `x`.
        :return: tangent of the output.
        """
        x, j1 = ctx.saved_tensors
        return BesselJ1._jacobian(x, j1) * grad_input