Source code for ptyrodactyl.electrons.forward

from functools import partial

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from beartype.typing import NamedTuple, Optional, Tuple, Union
from jax import lax
from jaxtyping import Array, Complex, Float, Int, Num, jaxtyped

import ptyrodactyl.electrons as pte

jax.config.update("jax_enable_x64", True)


[docs] def transmission_func( pot_slice: Float[Array, "#a #b"], voltage_kV: Num[Array, ""] ) -> Complex[Array, ""]: """ Description ----------- Calculates the complex transmission function from a single potential slice at a given electron accelerating voltage. Because this is JAX - you assume that the input is clean, and you don't need to check for negative or NaN values. Your preprocessing steps should check for them - not the function itself. Parameters ---------- - `pot_slice` (Float[Array, "#a #b"]): potential slice in Kirkland units - `voltage_kV` (scalar_number): microscope operating voltage in kilo electronVolts Returns ------- - `trans` (Complex[Array, "#a #b"]): The transmission function of a single crystal slice Flow ---- - Calculate the electron energy in electronVolts - Calculate the wavelength in angstroms - Calculate the Einstein energy - Calculate the sigma value, which is the constant for the phase shift - Calculate the transmission function as a complex exponential """ voltage: Float[Array, ""] = jnp.multiply(voltage_kV, 1000.0) m_e: Float[Array, ""] = 9.109383e-31 # mass of an electron e_e: Float[Array, ""] = 1.602177e-19 # charge of an electron c: Float[Array, ""] = 299792458.0 # speed of light eV: Float[Array, ""] = jnp.multiply(e_e, voltage) lambda_angstrom: Float[Array, ""] = pte.wavelength_ang( voltage_kV ) # wavelength in angstroms einstein_energy = jnp.multiply(m_e, jnp.square(c)) # Einstein energy sigma: Float[Array, ""] = ( (2 * jnp.pi / (lambda_angstrom * voltage)) * (einstein_energy + eV) ) / ((2 * einstein_energy) + eV) trans: Complex[Array, "#a #b"] = jnp.exp(1j * sigma * pot_slice) return trans
[docs] def propagation_func( imsize_y: int, imsize_x: int, thickness_ang: Num[Array, ""], voltage_kV: Num[Array, ""], calib_ang: Float[Array, ""], ) -> Complex[Array, "H W"]: """ Description ----------- Calculates the complex propagation function that results in the phase shift of the exit wave when it travels from one slice to the next in the multislice algorithm Parameters ---------- - `imsize_y`, (int): Size of the image of the propagator in y axis - `imsize_x`, (int): Size of the image of the propagator in x axis - `thickness_ang`, (scalar_number): Distance between the slices in angstroms - `voltage_kV`, (scalar_number): Accelerating voltage in kilovolts - `calib_ang`, (scalar_number): Calibration or pixel size in angstroms Returns ------- - `prop` (Complex[Array, "H W"]): The propagation function of the same size given by imsize Flow ---- - Generate frequency arrays directly using fftfreq - Create 2D meshgrid of frequencies - Calculate squared sum of frequencies - Calculate wavelength - Compute the propagation function """ # Generate frequency arrays directly using fftfreq qy: Num[Array, "H"] = jnp.fft.fftfreq(imsize_y, d=calib_ang) qx: Num[Array, "W"] = jnp.fft.fftfreq(imsize_x, d=calib_ang) # Create 2D meshgrid of frequencies Lya: Num[Array, "H W"] Lxa: Num[Array, "H W"] Lya, Lxa = jnp.meshgrid(qy, qx, indexing="ij") # Calculate squared sum of frequencies L_sq: Num[Array, "H W"] = jnp.square(Lxa) + jnp.square(Lya) # Calculate wavelength lambda_angstrom: Float[Array, ""] = pte.wavelength_ang(voltage_kV) # Compute the propagation function prop: Complex[Array, "H W"] = jnp.exp( (-1j) * jnp.pi * lambda_angstrom * thickness_ang * L_sq ) return prop
[docs] def fourier_coords(calibration: float, image_size: Int[Array, "2"]) -> NamedTuple: """ Description ----------- Return the Fourier coordinates Parameters ---------- - `calibration` (float): The pixel size in angstroms in real space - `image_size`, (Int[Array, "2"]): The size of the beam in pixels Returns ------- - A NamedTuple with the following fields: - `array` (Any[Array, "* *"]): The array values - `calib_y` (float): Calibration along the first axis - `calib_x` (float): Calibration along the second axis Flow ---- - Calculate the real space field of view in y and x - Generate the inverse space array y and x - Shift the inverse space array y and x - Create meshgrid of shifted inverse space arrays - Calculate the inverse array - Calculate the calibration in y and x - Return the calibrated array """ real_fov_y: float = image_size[0] * calibration # real space field of view in y real_fov_x: float = image_size[1] * calibration # real space field of view in x inverse_arr_y: Float[Array, "H"] = ( jnp.arange((-image_size[0] / 2), (image_size[0] / 2), 1) ) / real_fov_y # inverse space array y inverse_arr_x: Float[Array, "W"] = ( jnp.arange((-image_size[1] / 2), (image_size[1] / 2), 1) ) / real_fov_x # inverse space array x shifter_y: float = image_size[0] // 2 shifter_x: float = image_size[1] // 2 inverse_shifted_y: Float[Array, "H"] = jnp.roll( inverse_arr_y, shifter_y ) # shifted inverse space array y inverse_shifted_x: Float[Array, "W"] = jnp.roll( inverse_arr_x, shifter_x ) # shifted inverse space array y inverse_xx: Float[Array, "H W"] inverse_yy: Float[Array, "H W"] inverse_xx, inverse_yy = jnp.meshgrid(inverse_shifted_x, inverse_shifted_y) inv_squared = jnp.multiply(inverse_yy, inverse_yy) + jnp.multiply( inverse_xx, inverse_xx ) inverse_array: Float[Array, "H W"] = inv_squared**0.5 calib_inverse_y: float = inverse_arr_y[1] - inverse_arr_y[0] calib_inverse_x: float = inverse_arr_x[1] - inverse_arr_x[0] calibrated_array = NamedTuple( "array_with_calibrations", [("array", Num[Array, "* *"]), ("calib_y", float), ("calib_x", float)], ) return calibrated_array(inverse_array, calib_inverse_y, calib_inverse_x)
[docs] def fourier_calib( real_space_calib: float | Float[Array, "*"], sizebeam: Int[Array, "2"], ) -> Float[Array, "2"]: """ Description ----------- Generate the Fourier calibration for the beam Parameters ---------- - `real_space_calib` (float | Float[Array, "*"]): The pixel size in angstroms in real space - `sizebeam` (Int[Array, "2"]): The size of the beam in pixels Returns ------- - `inverse_space_calib` (Float[Array, "2"]): The Fourier calibration in angstroms Flow ---- - Calculate the field of view in real space - Calculate the inverse space calibration """ field_of_view: Float[Array, "*"] = jnp.multiply( jnp.float64(sizebeam), real_space_calib ) inverse_space_calib = 1 / field_of_view return inverse_space_calib
[docs] @jax.jit def make_probe( aperture: Num[Array, ""], voltage: Union[float, int], image_size: Int[Array, "2"], calibration_pm: float, defocus: float = 0, c3: float = 0, c5: float = 0, ) -> Complex[Array, "H W"]: """ Description ----------- This calculates an electron probe based on the size and the estimated Fourier co-ordinates with the option of adding spherical aberration in the form of defocus, C3 and C5 Parameters ---------- - `aperture` (Union[float, int]): The aperture size in milliradians - `voltage` (Union[float, int]): The microscope accelerating voltage in kilo electronVolts - `image_size`, (Int[Array, "2"]): The size of the beam in pixels - `calibration_pm` (float): The calibration in picometers - `defocus` (float): The defocus value in angstroms - `c3` (float): The C3 value in angstroms - `c5` (float): The C5 value in angstroms Returns ------- - `probe_real_space` (Complex[Array, "H W"]): The calculated electron probe in real space Flow ---- - Convert the aperture to radians - Calculate the wavelength in angstroms - Calculate the maximum L value - Calculate the field of view in x and y - Generate the inverse space array y and x - Shift the inverse space array y and x - Create meshgrid of shifted inverse space arrays - Calculate the inverse array - Calculate the calibration in y and x - Calculate the probe in real space """ aperture = aperture / 1000 wavelength = pte.wavelength_ang(voltage) LMax = aperture / wavelength image_y, image_x = image_size x_FOV = image_x * 0.01 * calibration_pm y_FOV = image_y * 0.01 * calibration_pm qx = (jnp.arange((-image_x / 2), (image_x / 2), 1)) / x_FOV x_shifter = image_x // 2 qy = (jnp.arange((-image_y / 2), (image_y / 2), 1)) / y_FOV y_shifter = image_y // 2 Lx = jnp.roll(qx, x_shifter) Ly = jnp.roll(qy, y_shifter) Lya, Lxa = jnp.meshgrid(Lx, Ly) L2 = jnp.multiply(Lxa, Lxa) + jnp.multiply(Lya, Lya) inverse_real_matrix = L2**0.5 Adist = jnp.asarray(inverse_real_matrix <= LMax, dtype=jnp.complex128) chi_probe = pte.aberration(inverse_real_matrix, wavelength, defocus, c3, c5) Adist *= jnp.exp(-1j * chi_probe) probe_real_space = jnp.fft.ifftshift(jnp.fft.ifft2(Adist)) return probe_real_space
[docs] @jax.jit def aberration( fourier_coord: Float[Array, "H W"], lambda_angstrom: Num[Array, ""], defocus: Optional[Float[Array, ""]] = jnp.asarray(0.0), c3: Optional[Float[Array, ""]] = jnp.asarray(0.0), c5: Optional[Float[Array, ""]] = jnp.asarray(0.0), ) -> Float[Array, "H W"]: """ Description ----------- This calculates the aberration function for the electron probe based on the Fourier co-ordinates Parameters ---------- - `fourier_coord` (Float[Array, "H W"]): The Fourier co-ordinates - `lambda_angstrom` (Num[Array, ""]): The wavelength in angstroms - `defocus` (Float[Array, ""]): The defocus value in angstroms. Default is 0.0 - `c3` (Float[Array, ""]): The C3 value in angstroms. Default is 0.0 - `c5` (Float[Array, ""]): The C5 value in angstroms. Default is 0.0 Returns ------- - `chi_probe` (Float[Array, "H W"]): The calculated aberration function Flow ---- - Calculate the phase shift - Calculate the chi value - Calculate the chi probe value """ p_matrix = lambda_angstrom * fourier_coord chi: Float[Array, "H W"] = ( ((defocus * jnp.power(p_matrix, 2)) / 2) + ((c3 * (1e7) * jnp.power(p_matrix, 4)) / 4) + ((c5 * (1e7) * jnp.power(p_matrix, 6)) / 6) ) chi_probe: Float[Array, "H W"] = (2 * jnp.pi * chi) / lambda_angstrom return chi_probe
[docs] @jaxtyped(typechecker=typechecker) def wavelength_ang(voltage_kV: Num[Array, "#a"]) -> Float[Array, "#a"]: """ Description ----------- Calculates the relativistic electron wavelength in angstroms based on the microscope accelerating voltage. Because this is JAX - you assume that the input is clean, and you don't need to check for negative or NaN values. Your preprocessing steps should check for them - not the function itself. Parameters ---------- - `voltage_kV` (num_type | Float[Array, "#a"]): The microscope accelerating voltage in kilo electronVolts Returns ------- - `in_angstroms (Float[Array, "*"]): The electron wavelength in angstroms Flow ---- - Calculate the electron wavelength in meters - Convert the wavelength to angstroms """ m: float = 9.109383e-31 # mass of an electron e: float = 1.602177e-19 # charge of an electron c: float = 299792458.0 # speed of light h: float = 6.62607e-34 # Planck's constant eV: Float[Array, "#a"] = ( jnp.float64(voltage_kV) * jnp.float64(1000.0) * jnp.float64(e) ) numerator: Float[Array, ""] = jnp.float64(h * c) denominator: Float[Array, "#a"] = jnp.multiply(eV, ((2 * m * jnp.square(c)) + eV)) wavelength_meters: Float[Array, "#a"] = jnp.sqrt( numerator / denominator ) # in meters lambda_angstroms: Float[Array, "#a"] = 1e10 * wavelength_meters # in angstroms return lambda_angstroms
[docs] def cbed( pot_slice: Complex[Array, "H W #S"], beam: Complex[Array, "H W #M"], slice_thickness: Float[Array, ""], voltage_kV: Float[Array, "#v"], calib_ang: Float[Array, ""], ) -> Float[Array, "H W"]: """ Description ----------- Calculates the CBED pattern for single/multiple slices and single/multiple beam modes. This function computes the Convergent Beam Electron Diffraction (CBED) pattern by propagating one or more beam modes through one or more potential slices. Parameters ---------- - `pot_slice` (Complex[Array, "H W #S"]): The potential slice(s). H and W are height and width, S is the number of slices (optional). - `beam` (Complex[Array, "H W #M"]): The electron beam mode(s). M is the number of modes (optional). - `slice_thickness` (Float[Array, ""]): The thickness of each slice in angstroms. - `voltage_kV` (Float[Array, "#v"]): The accelerating voltage(s) in kilovolts. - `calib_ang` (Float[Array, ""]): The calibration in angstroms. Returns ------- - `cbed_pattern` (Float[Array, "H W"]): The calculated CBED pattern. Flow ---- - Ensure 3D arrays even for single slice/mode - Calculate the transmission function for a single slice - Initialize the convolution state - Scan over all slices - Compute the Fourier transform - Compute the intensity for each mode - Sum the intensities across all modes. """ # Ensure 3D arrays even for single slice/mode pot_slice: Complex[Array, "H W #S"] = jnp.atleast_3d(pot_slice) beam: Complex[Array, "H W #M"] = jnp.atleast_3d(beam) # Calculate the transmission function for a single slice slice_transmission: Complex[Array, "H W"] = pte.propagation_func( beam.shape[0], beam.shape[1], slice_thickness, voltage_kV, calib_ang ) def process_slice( carry: Tuple[Complex[Array, "H W *M"], Complex[Array, "H W *M"]], pot_slice_i ): convolve, _ = carry this_slice = jnp.multiply(pot_slice_i[:, :, None], beam) propagated_slice = jnp.fft.ifft2( jnp.multiply(jnp.fft.fft2(this_slice), slice_transmission[:, :, None]) ) new_convolve = jnp.multiply(convolve, propagated_slice) return (new_convolve, beam), None # Initialize the convolution state initial_carry = (jnp.ones_like(beam, dtype=jnp.complex128), beam) # Scan over all slices (real_space_convolve, _), _ = lax.scan( process_slice, initial_carry, pot_slice.transpose(2, 0, 1) ) # Compute the Fourier transform fourier_space_modes = jnp.fft.fftshift( jnp.fft.fft2(real_space_convolve), axes=(0, 1) ) # Compute the intensity for each mode cbed_mode: Float[Array, "H W M"] = jnp.square(jnp.abs(fourier_space_modes)) # Sum the intensities across all modes cbed_pattern: Float[Array, "H W"] = jnp.sum(cbed_mode, axis=-1) return cbed_pattern
[docs] def cbed_no_slice( pot_slice: Complex[Array, "H W *S"], beam: Complex[Array, "H W *M"], slice_transmission: Complex[Array, "*"], ) -> Float[Array, "H W"]: """ Description ----------- Calculates the CBED pattern for single/multiple slices and single/multiple beam modes. This function computes the Convergent Beam Electron Diffraction (CBED) pattern by propagating one or more beam modes through one or more potential slices. This version takes in a pre-calculated transmission function for going from one slice to the next, which is useful for calculating the CBED pattern multiple times where the transmission function remains the same., example is 4D-STEM. Parameters ---------- - `pot_slice` (Complex[Array, "H W *S"]): The potential slice(s). H and W are height and width, S is the number of slices (optional). - `beam` (Complex[Array, "H W *M"]): The electron beam mode(s). M is the number of modes (optional). - `slice_transmission` (Complex[Array, "*"]): The pre-calculated transmission function for going from one slice to the next. Returns ------- - `cbed_pattern` (Float[Array, "H W"]): The calculated CBED pattern. Flow ---- - Ensure 3D arrays even for single slice/mode - Initialize the convolution state - Scan over all slices - Compute the Fourier transform - Compute the intensity for each mode - Sum the intensities across all modes """ # Ensure 3D arrays even for single slice/mode pot_slice: Complex[Array, "H W *S"] = jnp.atleast_3d(pot_slice) beam: Complex[Array, "H W *M"] = jnp.atleast_3d(beam) def process_slice( carry: Tuple[Complex[Array, "H W *M"], Complex[Array, "H W *M"]], pot_slice_i ): convolve, _ = carry this_slice = jnp.multiply(pot_slice_i[:, :, None], beam) propagated_slice = jnp.fft.ifft2( jnp.multiply(jnp.fft.fft2(this_slice), slice_transmission[:, :, None]) ) new_convolve = jnp.multiply(convolve, propagated_slice) return (new_convolve, beam), None # Initialize the convolution state initial_carry = (jnp.ones_like(beam, dtype=jnp.complex128), beam) # Scan over all slices (real_space_convolve, _), _ = lax.scan( process_slice, initial_carry, pot_slice.transpose(2, 0, 1) ) # Compute the Fourier transform fourier_space_modes = jnp.fft.fftshift( jnp.fft.fft2(real_space_convolve), axes=(0, 1) ) # Compute the intensity for each mode cbed_mode: Float[Array, "H W M"] = jnp.square(jnp.abs(fourier_space_modes)) # Sum the intensities across all modes cbed_pattern: Float[Array, "H W"] = jnp.sum(cbed_mode, axis=-1) return cbed_pattern
[docs] def shift_beam_fourier( beam: Complex[Array, "H W M"], pos: Float[Array, "... 2"], calib_ang: Float[Array, "*"], ) -> Complex[Array, "... H W M"]: """ Description ----------- Shifts the beam to new position(s) using Fourier shifting. Parameters ---------- - beam (Complex[Array, "H W M"]): The electron beam modes. - pos (Float[Array, "... 2"]): The (y, x) position(s) to shift to in pixels. Can be a single position [2] or multiple [..., 2]. - calib_ang (Float[Array, "*"]): The calibration in angstroms. Returns ------- - shifted_beams (Complex[Array, "... H W M"]): The shifted beam(s) for all position(s) and mode(s). Flow ---- - Convert positions from real space to Fourier space - Create phase ramps in Fourier space for all positions - Apply shifts to each mode for all positions """ H, W, _ = beam.shape # Ensure pos is at least 2D, even for a single position pos = jnp.atleast_2d(pos) # Convert positions from real space to Fourier space fy = pos[..., 0] / (H * calib_ang) fx = pos[..., 1] / (W * calib_ang) # Create phase ramps in Fourier space for all positions y, x = jnp.meshgrid(jnp.fft.fftfreq(H), jnp.fft.fftfreq(W), indexing="ij") phase_ramps = jnp.exp( -2j * jnp.pi * (fy[..., None, None] * y + fx[..., None, None] * x) ) # Apply shifts to each mode for all positions fft_beam = jnp.fft.fft2(beam, axes=(0, 1)) shifted_fft_beams = fft_beam * phase_ramps[..., None] shifted_beams = jnp.fft.ifft2(shifted_fft_beams, axes=(-3, -2)) return shifted_beams
[docs] @partial(jax.jit, static_argnames=["slice_thickness", "voltage_kV", "calib_ang"]) def stem_4d( pot_slice: Complex[Array, "H W S"], beam: Complex[Array, "H W M"], pos_list: Float[Array, "P 2"], slice_thickness: float, voltage_kV: float, calib_ang: float, ) -> Float[Array, "P H W"]: """ Description ----------- Calculates the 4D-STEM pattern for multiple probe positions with sharding. Parameters ---------- - `pot_slice` (Complex[Array, "H W S"]): The potential slices. - `beam` (Complex[Array, "H W M"]): The electron beam modes. - `pos_list` (Float[Array, "P 2"]): List of (y, x) probe positions in pixels. - `slice_thickness` (float): The thickness of each slice in angstroms. - `voltage_kV` (float): The accelerating voltage in kilovolts. - `calib_ang` (float): The calibration in angstroms. - `mesh` (Mesh): The device mesh for sharding. Returns ------- - stem_pattern (Float[Array, "P H W"]): The calculated 4D-STEM pattern. Flow ---- - Calculate the transmission function once - Shift the beam to all positions - Calculate CBED patterns for all positions """ # Calculate the transmission function once slice_transmission: Complex[Array, "H W"] = pte.propagation_func( beam.shape[0], beam.shape[1], slice_thickness, voltage_kV, calib_ang ) # Shift the beam to all positions shifted_beams = pte.shift_beam_fourier(beam, pos_list, calib_ang) # Calculate CBED patterns for all positions def calc_cbed(electron_beam): return cbed_no_slice(pot_slice, electron_beam, slice_transmission) stem_pattern = jax.vmap(calc_cbed, in_axes=0, out_axes=0)(shifted_beams) return stem_pattern
[docs] def stem_4d_multi( pot_slices: Complex[Array, "H W S"], beam: Complex[Array, "H W M"], pos_list: Float[Array, "P 2"], slice_thickness: Float[Array, ""], voltage_kV: Float[Array, ""], calib_ang: Float[Array, ""], ) -> Float[Array, "P H W"]: """ Calculates 4D-STEM pattern for multiple slices. This function propagates the beam through multiple slices before calculating the final diffraction pattern. """ # Calculate transmission function once (same for all slices) slice_transmission = pte.propagation_func( beam.shape[0], beam.shape[1], slice_thickness, voltage_kV, calib_ang ) # Shift beam to all positions shifted_beams = pte.shift_beam_fourier(beam, pos_list, calib_ang) # Calculate patterns for all positions def calc_multi_slice_cbed(electron_beam): # Start with initial beam wave = electron_beam # Propagate through all slices for i in range(pot_slices.shape[-1]): # Apply transmission function for current slice wave = wave * pot_slices[..., i : i + 1] # Propagate to next slice if i < pot_slices.shape[-1] - 1: # Don't propagate after last slice wave = jnp.fft.ifft2(jnp.fft.fft2(wave) * slice_transmission[..., None]) # Calculate final diffraction pattern fourier = jnp.fft.fftshift(jnp.fft.fft2(wave), axes=(0, 1)) intensity = jnp.sum(jnp.abs(fourier) ** 2, axis=-1) return intensity # Apply to all probe positions stem_pattern = jax.vmap(calc_multi_slice_cbed, in_axes=0, out_axes=0)(shifted_beams) return stem_pattern
[docs] def stem_4d_mixed_state( pot_slices: Complex[Array, "H W S"], modes: Complex[Array, "H W M"], mode_weights: Float[Array, "M"], pos_list: Float[Array, "P 2"], slice_thickness: Float[Array, ""], voltage_kV: Float[Array, ""], calib_ang: Float[Array, ""], ) -> Float[Array, "P H W"]: """ Calculates 4D-STEM pattern for multiple modes and slices with mixed state. """ # Calculate transmission function slice_transmission = pte.propagation_func( modes.shape[0], modes.shape[1], slice_thickness, voltage_kV, calib_ang ) # Shift all modes to all positions shifted_modes = pte.shift_beam_fourier(modes, pos_list, calib_ang) def calc_mixed_state_cbed(shifted_mode_set): # Process each mode def process_mode(mode): wave = mode # Propagate through slices for i in range(pot_slices.shape[-1]): wave = wave * pot_slices[..., i] if i < pot_slices.shape[-1] - 1: wave = jnp.fft.ifft2(jnp.fft.fft2(wave) * slice_transmission) return wave # Apply to all modes waves = jax.vmap(process_mode)(shifted_mode_set) # Calculate diffraction patterns for all modes fourier = jnp.fft.fftshift(jnp.fft.fft2(waves), axes=(0, 1)) mode_intensities = jnp.abs(fourier) ** 2 # Weight and sum the mode intensities weighted_sum = jnp.sum(mode_intensities * mode_weights[:, None, None], axis=0) return weighted_sum # Apply to all probe positions stem_pattern = jax.vmap(calc_mixed_state_cbed)(shifted_modes) return stem_pattern
[docs] def initialize_random_modes( shape: Tuple[int, int], num_modes: int, dtype=jnp.complex128 ) -> Complex[Array, "H W M"]: """Initialize random orthogonal modes.""" key = jax.random.PRNGKey(0) modes = jax.random.normal(key, (shape[0], shape[1], num_modes), dtype=dtype) # Orthogonalize modes using QR decomposition modes_flat = modes.reshape(-1, num_modes) q, r = jnp.linalg.qr(modes_flat) modes = q.reshape(shape[0], shape[1], num_modes) return modes
[docs] def normalize_mode_weights(weights: Float[Array, "M"]) -> Float[Array, "M"]: """Normalize mode weights to sum to 1.""" return weights / jnp.sum(weights)