# Source code for pylorenzmie.theory.jaxLorenzMie

'''Lorenz-Mie hologram and Jacobian computed with JAX.

Provides a JIT-compiled forward pass and the analytical Jacobian for
all five particle parameters via forward-mode automatic differentiation
(:func:`jax.jacfwd`).

Raises :exc:`ImportError` if JAX is not installed, or a JAX
:exc:`RuntimeError` if the active backend fails a smoke test
(e.g., version mismatch between JAX and jax-metal / jax-cuda).
'''

from pylorenzmie.theory.LorenzMie import LorenzMie
from pylorenzmie.theory.Sphere import Sphere
from pylorenzmie.lib.lmtypes import Coefficients, Coordinates, Field
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)
# Verify the active backend can execute a trivial float64 op.
# jax-metal built against older StableHLO raises JaxRuntimeError here,
# which propagates out and causes __init__.py to skip this backend.
# The probe performs real arithmetic (an identity function may compile
# to a computation that launches no kernel).  block_until_ready() forces
# the result to materialise: JAX dispatches asynchronously, so without it
# an execution-time failure could surface later, at the first hologram()
# call, instead of here at import time.
jax.jit(lambda x: x + 1.)(jnp.array(1.0)).block_until_ready()


def _downward_rec(z, nmax):
    '''Logarithmic derivative D1_n(z) by downward recurrence.

    Starting from D1[nmax] = 0, repeatedly applies
    D1[n-1] = n/z - 1 / (D1[n] + n/z) for n = nmax, nmax-1, ..., 1
    (Bohren & Huffman Eq. 4.88 / Yang 2003 Eq. 16b).

    Parameters
    ----------
    z : complex JAX scalar
        Argument: x*m for the sphere interior, x for the medium.
    nmax : int
        Highest order computed.  Must be a concrete Python int because
        it fixes the scan length.

    Returns
    -------
    D1 : jnp.ndarray, shape (nmax+1,), complex128 (with x64 enabled)
        D1[n] for n = 0 .. nmax in ascending order; D1[nmax] is 0.
    '''
    def lower(d1_upper, order):
        # Given D1[order], produce D1[order - 1].
        n_over_z = order / z
        d1_lower = n_over_z - 1. / (d1_upper + n_over_z)
        return d1_lower, d1_lower

    orders = jnp.arange(1, nmax + 1, dtype=complex)
    # reverse=True visits orders nmax, ..., 1 but stores each output at
    # the index of its input, so entry n-1 already holds D1[n-1]: no flip.
    _, d1_low = jax.lax.scan(lower, 0. + 0.j, orders, reverse=True)
    return jnp.concatenate([d1_low, jnp.array([0. + 0.j])])


def _upward_rec(x, D1_med, nmax):
    '''Upward recurrences for ψ_n, ζ_n and D3_n (Yang 2003 Eqs. 18–21).

    Parameters
    ----------
    x : complex JAX scalar
        Size parameter, passed as complex so AD flows through sin/exp.
    D1_med : jnp.ndarray, shape (nmax+1,), complex128
        Logarithmic derivative D1 evaluated at z = x (medium).
    nmax : int
        Number of terms; concrete Python int (sets the scan length).

    Returns
    -------
    Psi, zeta, D3 : jnp.ndarray, each shape (nmax+1,), complex128
        Values for orders 0 .. nmax.
    '''
    # Closed-form n = 0 starting values.
    psi0 = jnp.sin(x)
    zeta0 = -1.j * jnp.exp(1.j * x)
    psizeta0 = 0.5 * (1. - jnp.exp(2.j * x))
    # Adding 0*x turns the Python constant into a JAX value derived
    # from x (traces x so AD flows through dtype/shape).
    d3_0 = 1.j + 0. * x

    def advance(state, order):
        psi, zeta, psizeta, d3 = state
        ratio = jnp.asarray(order, dtype=complex) / x
        via_d1 = ratio - D1_med[order - 1]
        via_d3 = ratio - d3
        psi = psi * via_d1
        zeta = zeta * via_d3
        psizeta = psizeta * via_d1 * via_d3
        d3 = D1_med[order] + 1.j / psizeta
        return (psi, zeta, psizeta, d3), (psi, zeta, d3)

    _, tails = jax.lax.scan(
        advance, (psi0, zeta0, psizeta0, d3_0), jnp.arange(1, nmax + 1))

    # Prepend the order-0 value to each sequence.
    Psi, zeta, D3 = (jnp.concatenate([head[None], tail])
                     for head, tail in zip((psi0, zeta0, d3_0), tails))
    return Psi, zeta, D3


def _jax_mie_ab(a_p, n_p, k_p, n_m, wavelength, nmax):
    '''Mie scattering coefficients a_n, b_n for a homogeneous sphere (JAX).

    Follows the Wiscombe–Yang formulation.  *nmax* must be a concrete
    Python int so that the ``jax.lax.scan`` recurrences compile at a
    fixed length; every other input may be traced and differentiated.

    Parameters
    ----------
    a_p : JAX float scalar
        Particle radius, μm.
    n_p : JAX float scalar
        Particle refractive index.
    k_p : JAX float scalar
        Particle absorption coefficient.
    n_m : JAX float scalar
        Medium refractive index.
    wavelength : JAX float scalar
        Vacuum wavelength, μm.
    nmax : int
        Number of partial-wave terms (static).

    Returns
    -------
    ab : jnp.ndarray, shape (nmax+1, 2), complex128
        Column 0 holds a_n, column 1 holds b_n; row 0 is zero.
    '''
    # Wavenumber in the medium (rad/μm), size parameter, relative index.
    wavenumber = 2. * jnp.pi * jnp.real(n_m) / wavelength
    size = wavenumber * a_p
    rel_index = (n_p + 1.j * k_p) / n_m

    # Single layer, so Ha = Hb = D1(m x).
    d1_inside = _downward_rec(size * rel_index, nmax)
    d1_outside = _downward_rec(size + 0.j, nmax)
    psi, zeta, _ = _upward_rec(size + 0.j, d1_outside, nmax)

    # Shifted copies: element n holds the order n-1 value.  Element 0
    # wraps around to order nmax, which is why row 0 is zeroed below.
    psi_prev = jnp.roll(psi, 1)
    zeta_prev = jnp.roll(zeta, 1)
    n_over_x = jnp.arange(nmax + 1, dtype=complex) / size

    def coefficient(factor):
        return (factor * psi - psi_prev) / (factor * zeta - zeta_prev)

    a_n = coefficient(d1_inside / rel_index + n_over_x)    # Eq. 5
    b_n = coefficient(d1_inside * rel_index + n_over_x)    # Eq. 6

    return jnp.stack([a_n, b_n], axis=1).at[0].set(0. + 0.j)


def _jax_lorenzmie(ab, kdr, bohren=True):
    '''JAX partial-wave sum for the scattered electric field.

    Implements Bohren & Huffman §4.4 (Eqs. 4.45–4.50) using
    jax.lax.scan over multipole orders.  Fully differentiable
    w.r.t. ab and kdr.

    Parameters
    ----------
    ab : jnp.ndarray, shape (norders, 2), complex128
        Mie scattering coefficients: column 0 is a_n, column 1 is b_n.
        Row 0 is never read; the sum runs over n = 1 .. norders-1.
    kdr : jnp.ndarray, shape (3, npts), float64
        Wavenumber-scaled displacement from particle to coordinates.
        The axial component is negated internally.
    bohren : bool
        Selects h_n^(1) (True, default) or h_n^(2) (False).  It is tested
        with a plain Python conditional, so it must be a concrete bool,
        not a traced value.

    Returns
    -------
    field : jnp.ndarray, shape (3, npts), complex128
        Cartesian (x, y, z) scattered field at each coordinate.

    Notes
    -----
    A coordinate with kr == 0 (coinciding with the particle centre)
    divides by zero in the radial recurrence and in the geometric
    factors, producing non-finite values there.
    '''
    kx = kdr[0]
    ky = kdr[1]
    kz = -kdr[2]   # flip: particle above focal plane is +z
    npts = kx.shape[0]

    krho = jnp.hypot(kx, ky)
    kr = jnp.hypot(krho, kz)
    # Compute cosφ/sinφ directly to avoid arctan2 singularity at kx=ky=0.
    # jnp.where selects the constant branch (gradient=0) at the on-axis
    # pixel, preventing 0/0 NaN in the analytical Jacobian.  krho_safe
    # keeps the *unselected* branch finite too, since jnp.where still
    # differentiates both branches.
    krho_safe = jnp.where(krho > 0., krho, 1.)
    cosphi = jnp.where(krho > 0., kx / krho_safe, 1.)
    sinphi = jnp.where(krho > 0., ky / krho_safe, 0.)
    theta = jnp.arctan2(krho, kz)
    costheta = jnp.cos(theta)
    sintheta = jnp.sin(theta)
    sinkr = jnp.sin(kr)
    coskr = jnp.cos(kr)

    # Sign of the imaginary part of the starting functions; follows
    # sign(kz), so it is 0 wherever kz == 0.
    sgn = 1.j * jnp.sign(kz) if bohren else -1.j * jnp.sign(kz)
    xi_nm2 = (coskr + sgn * sinkr).astype(complex)  # xi_{-1}(kr)
    xi_nm1 = (sinkr - sgn * coskr).astype(complex)  # xi_0(kr)

    norders = ab.shape[0]

    # Scan carry: (pi_{n-1}, pi_n, xi_{n-2}, xi_{n-1}, Es, i^(n-1)).
    init = (
        jnp.zeros(npts, dtype=complex),    # pi_nm1
        jnp.ones(npts, dtype=complex),     # pi_n
        xi_nm2,
        xi_nm1,
        jnp.zeros((3, npts), dtype=complex),  # Es (spherical: r, θ, φ)
        jnp.ones((), dtype=complex),           # En_factor (tracks i^n)
    )

    def step(carry, n_int):
        # One multipole order n = n_int: add its term to Es and advance
        # the angular and radial recurrences to the next order.
        pi_nm1, pi_n, xi_nm2, xi_nm1, Es, En_fac = carry
        n = jnp.asarray(n_int, dtype=complex)

        # Legendre upward recurrence (Wiscombe 1980)
        swisc = pi_n * costheta
        twisc = swisc - pi_nm1
        tau_n = pi_nm1 - n * twisc

        # Riccati-Bessel upward recurrence
        xi_n = (2. * n - 1.) * xi_nm1 / kr - xi_nm2
        Dn = n * xi_n / kr - xi_nm1

        # vector spherical harmonics (geometric factors divided out)
        Mo1n_t = pi_n * xi_n
        Mo1n_p = tau_n * xi_n
        Ne1n_r = (n * n + n) * Mo1n_t
        Ne1n_t = tau_n * Dn
        Ne1n_p = pi_n * Dn

        # prefactor E_n = i^n (2n+1) / (n(n+1))
        En_fac = En_fac * 1.j
        En = En_fac * (2. * n + 1.) / (n * n + n)
        an = 1.j * En * ab[n_int, 0]
        bn = En * ab[n_int, 1]

        new_Es = Es.at[0].add(an * Ne1n_r)
        new_Es = new_Es.at[1].add(an * Ne1n_t - bn * Mo1n_t)
        new_Es = new_Es.at[2].add(an * Ne1n_p - bn * Mo1n_p)

        # pi_{n+1} = swisc + ((n+1)/n) * twisc
        new_pi_n = (swisc + (1. + 1. / n) * twisc).astype(complex)
        return (pi_n, new_pi_n, xi_nm1, xi_n, new_Es, En_fac), None

    (_, _, _, _, Es, _), _ = jax.lax.scan(
        step, init, jnp.arange(1, norders))

    # restore geometric factors
    Es = Es.at[0].multiply(cosphi * sintheta / (kr * kr))
    Es = Es.at[1].multiply(cosphi / kr)
    Es = Es.at[2].multiply(sinphi / kr)

    # project spherical → Cartesian
    return jnp.stack([
        Es[0] * sintheta * cosphi + Es[1] * costheta * cosphi - Es[2] * sinphi,
        Es[0] * sintheta * sinphi + Es[1] * costheta * sinphi + Es[2] * cosphi,
        Es[0] * costheta - Es[1] * sintheta,
    ])


def _jax_hologram_core(r_p, a_p, n_p, k_p, k, n_m, wavelength,
                       coordinates, nmax):
    '''Normalized hologram of one sphere, written purely in JAX.

    Differentiable with respect to every argument except *nmax*, which
    must be a concrete Python int.

    Parameters
    ----------
    r_p : jnp.ndarray, shape (3,)
        Particle position (pixels).
    a_p, n_p, k_p : JAX float scalars
        Sphere radius (μm), refractive index, absorption.
    k : JAX float scalar
        Wavenumber in rad/pixel.
    n_m : JAX float scalar
        Medium refractive index.
    wavelength : JAX float scalar
        Vacuum wavelength, μm.
    coordinates : jnp.ndarray, shape (3, npts)
        Pixel coordinates at which to evaluate the hologram.
    nmax : int
        Number of Mie terms (static).

    Returns
    -------
    hologram : jnp.ndarray, shape (npts,), float64
        |E|^2 of the unit-amplitude incident field (x component) plus
        the scattered field, summed over Cartesian components.
    '''
    coefficients = _jax_mie_ab(a_p, n_p, k_p, n_m, wavelength, nmax)
    # Wavenumber-scaled displacement from the particle to each pixel.
    kdr = k * (coordinates - r_p[:, None])
    # Scattered field, phase-referenced to the particle's axial position.
    scattered = _jax_lorenzmie(coefficients, kdr) * jnp.exp(-1.j * k * r_p[2])
    # Superpose the unit-amplitude incident field on the x component.
    total = scattered.at[0].add(1.)
    return (total.real**2 + total.imag**2).sum(axis=0)


# nmax (arg index 8) is static: controls lax.scan length at compile time,
# so each distinct nmax value compiles (and caches) its own executable.
_jax_hologram_jit = jax.jit(_jax_hologram_core, static_argnums=(8,))

# Jacobian w.r.t. r_p (0), a_p (1), n_p (2); nmax (8) still static.
# The result is a tuple (dH/dr_p, dH/da_p, dH/dn_p) with shapes
# (npts, 3), (npts,) and (npts,) for an (npts,) hologram.  k_p, k, n_m,
# wavelength and coordinates are inputs only and are not differentiated.
# Module-level JIT avoids per-call retracing that a closure would cause.
_jax_hologram_jac = jax.jit(
    jax.jacfwd(_jax_hologram_core, argnums=(0, 1, 2)),
    static_argnums=(8,),
)


class jaxLorenzMie(LorenzMie):
    '''LorenzMie with JAX JIT compilation and analytical Jacobian.

    Overrides :meth:`hologram` with a JAX implementation compiled via
    :func:`jax.jit`.  Adds :meth:`jac` for the Jacobian of the hologram
    obtained by forward-mode automatic differentiation
    (:func:`jax.jacfwd`), covering all five particle parameters
    (``x_p``, ``y_p``, ``z_p``, ``a_p``, ``n_p``) as five tangent
    directions.

    The class attribute ``method = 'jax numpy'`` allows
    :class:`~pylorenzmie.analysis.Optimizer` to select a compatible
    model.

    Parameters
    ----------
    coordinates, particle, instrument
        Forwarded to :class:`LorenzMie`.

    Notes
    -----
    Requires JAX with 64-bit support.  ``jax_enable_x64`` is set at
    import time, which affects the global JAX session.

    The first :meth:`hologram` (or :meth:`jac`) call for a given *nmax*
    and coordinate-array shape triggers JIT compilation.  Later calls with
    the same *nmax* and the same argument shapes/dtypes reuse the compiled
    function from :func:`jax.jit`'s cache.  Changing the particle so that
    the Wiscombe–Yang *nmax* changes triggers a new compilation.

    Only single :class:`~pylorenzmie.theory.Sphere` particles are
    JAX-accelerated.  Any other :attr:`particle` (e.g. multi-particle
    configurations) makes :meth:`hologram` fall back to the inherited
    NumPy implementation, and makes :meth:`jac` raise :exc:`TypeError`.
    '''

    method: str = 'jax numpy'
    jac_params: frozenset = frozenset({'x_p', 'y_p', 'z_p', 'a_p', 'n_p'})

    def _compute_nmax(self) -> int:
        '''Number of partial-wave terms from the Wiscombe–Yang criterion.

        Uses the unscaled (rad/μm) wavenumber, as
        ``Sphere.mie_coefficients`` does, together with the complex
        relative refractive index of the particle.
        '''
        k = float(self.instrument.wavenumber(scaled=False))
        x = k * float(self.particle.a_p)
        m = (self.particle.n_p + 1.j * self.particle.k_p) / self.instrument.n_m
        return Sphere.wiscombe_yang(x, m)

    def _jax_inputs(self) -> tuple:
        '''Collect the positional arguments of the compiled JAX kernels.

        Returns
        -------
        args : tuple
            ``(r_p, a_p, n_p, k_p, k, n_m, wavelength, coordinates, nmax)``
            in the order expected by ``_jax_hologram_jit`` and
            ``_jax_hologram_jac``.  ``r_p`` and the scalars are float
            JAX arrays; ``coordinates`` keeps its own dtype; ``nmax`` is
            a Python int (the static argument).
        '''
        nmax = self._compute_nmax()
        particle = self.particle
        instrument = self.instrument
        # Cast explicitly: jax.jacfwd rejects integer-typed inputs, so an
        # integer-valued position array would otherwise break jac().
        r_p = jnp.asarray(particle.r_p + particle.r_0, dtype=float)
        return (
            r_p,
            jnp.asarray(float(particle.a_p)),
            jnp.asarray(float(particle.n_p)),
            jnp.asarray(float(particle.k_p)),
            jnp.asarray(float(instrument.wavenumber())),
            jnp.asarray(float(instrument.n_m)),
            jnp.asarray(float(instrument.wavelength)),
            jnp.asarray(self.coordinates),
            nmax,
        )

    def hologram(self, **kwargs) -> np.ndarray:
        '''JIT-compiled hologram for a single :class:`~pylorenzmie.theory.Sphere`.

        Falls back to the NumPy implementation for non-Sphere or
        multi-particle models.  The ``cartesian`` and ``bohren`` keyword
        arguments are ignored when the JAX path is active (both are
        always True).

        Returns
        -------
        hologram : numpy.ndarray, shape (npts,), float64
        '''
        if not isinstance(self.particle, Sphere):
            return super().hologram(**kwargs)
        return np.asarray(_jax_hologram_jit(*self._jax_inputs()))

    def jac(self) -> dict:
        '''Analytical Jacobian of the hologram w.r.t. particle parameters.

        Uses :func:`jax.jacfwd` (five forward-mode tangent directions) to
        compute dH/d(x_p, y_p, z_p, a_p, n_p) at the current model state.
        *nmax* is fixed at its current value (from the Wiscombe–Yang
        criterion); the result is exact at the current operating point
        and valid for local optimization.

        Returns
        -------
        J : dict of str → numpy.ndarray, shape (npts,), float64
            Keys: ``'x_p'``, ``'y_p'``, ``'z_p'``, ``'a_p'``, ``'n_p'``.

        Raises
        ------
        TypeError
            If :attr:`particle` is not a single :class:`~pylorenzmie.theory.Sphere`.
        '''
        if not isinstance(self.particle, Sphere):
            raise TypeError('jac() supports single Sphere particles only')

        J_r, J_a, J_n = _jax_hologram_jac(*self._jax_inputs())

        # J_r has shape (npts, 3): one column per Cartesian coordinate.
        J = {name: np.asarray(J_r[:, axis])
             for axis, name in enumerate(('x_p', 'y_p', 'z_p'))}
        J['a_p'] = np.asarray(J_a)
        J['n_p'] = np.asarray(J_n)
        return J


if __name__ == '__main__':  # pragma: no cover
    # Demo entry point.  ``example`` is not defined in this file; it is
    # presumably inherited from LorenzMie (TODO confirm there).
    jaxLorenzMie.example()