# Source code for ptyrodactyl.electrons.inverse

import jax
import jax.numpy as jnp
from beartype.typing import Callable, Dict, NamedTuple, Optional, Tuple
from jaxtyping import Array, Complex, Float

import ptyrodactyl.electrons as pte
import ptyrodactyl.tools as ptt

# Registry of optimizers selectable by name. Each entry wraps an ``init_*``
# function and the corresponding ``*_update`` function in a ``ptt.Optimizer``.
OPTIMIZERS: Dict[str, ptt.Optimizer] = {
    name: ptt.Optimizer(init_fn, update_fn)
    for name, init_fn, update_fn in (
        ("adam", ptt.init_adam, ptt.adam_update),
        ("adagrad", ptt.init_adagrad, ptt.adagrad_update),
        ("rmsprop", ptt.init_rmsprop, ptt.rmsprop_update),
    )
}


def get_optimizer(optimizer_name: str) -> ptt.Optimizer:
    """
    Look up a registered optimizer by name.

    Args:
        optimizer_name: Key into the module-level ``OPTIMIZERS`` registry.

    Returns:
        The ``ptt.Optimizer`` stored under ``optimizer_name``.

    Raises:
        ValueError: If ``optimizer_name`` is not a key of ``OPTIMIZERS``.
    """
    if optimizer_name in OPTIMIZERS:
        return OPTIMIZERS[optimizer_name]
    raise ValueError(f"Unknown optimizer: {optimizer_name}")
def single_slice_ptychography(
    experimental_4dstem: Float[Array, "P H W"],
    initial_pot_slice: Complex[Array, "H W"],
    initial_beam: Complex[Array, "H W"],
    pos_list: Float[Array, "P 2"],
    slice_thickness: Float[Array, "*"],
    voltage_kV: Float[Array, "*"],
    calib_ang: Float[Array, "*"],
    num_iterations: int = 1000,
    learning_rate: float = 0.001,
    loss_type: str = "mse",
    optimizer_name: str = "adam",
) -> Tuple[Complex[Array, "H W"], Complex[Array, "H W"]]:
    """
    Create and run an optimization routine for 4D-STEM reconstruction.

    A loss is built by ``ptt.create_loss_function`` from a forward model that
    calls ``pte.stem_4d``. The potential slice and the beam are then updated
    for ``num_iterations`` steps with the selected optimizer, both using
    ``learning_rate``. The loss is printed every 100 iterations.

    Args:
        experimental_4dstem: Experimental 4D-STEM data.
        initial_pot_slice: Initial guess for potential slice.
        initial_beam: Initial guess for electron beam.
        pos_list: List of probe positions.
        slice_thickness: Thickness of each slice.
        voltage_kV: Accelerating voltage.
        calib_ang: Calibration in angstroms.
        num_iterations: Number of optimization iterations.
        learning_rate: Learning rate for optimization.
        loss_type: Type of loss function to use; forwarded unchanged to
            ``ptt.create_loss_function``.
        optimizer_name: Key of the optimizer in ``OPTIMIZERS``
            ("adam", "adagrad" or "rmsprop").

    Returns:
        Optimized potential slice and beam.

    Raises:
        ValueError: If ``optimizer_name`` is not a registered optimizer.
    """

    # Forward model: both inputs get a leading singleton axis before being
    # handed to ``pte.stem_4d``.
    # NOTE(review): ``single_slice_poscorrected`` in this module passes an
    # additional ``devices`` argument to ``pte.stem_4d``; confirm which of the
    # two calls matches the current ``stem_4d`` signature.
    def forward_fn(pot_slice, beam):
        return pte.stem_4d(
            pot_slice[None, ...],
            beam[None, ...],
            pos_list,
            slice_thickness,
            voltage_kV,
            calib_ang,
        )

    # Loss comparing the forward model against the experimental data.
    loss_func = ptt.create_loss_function(forward_fn, experimental_4dstem, loss_type)

    # Loss value plus gradients with respect to both optimised arrays.
    # NOTE(review): the inputs are complex, so the gradients follow JAX's
    # complex-differentiation convention; confirm that ``optimizer.update``
    # applies whatever conjugation that convention requires.
    @jax.jit
    def loss_and_grad(
        pot_slice: Complex[Array, "H W"], beam: Complex[Array, "H W"]
    ) -> Tuple[Float[Array, ""], Dict[str, Complex[Array, "H W"]]]:
        loss, grads = jax.value_and_grad(loss_func, argnums=(0, 1))(pot_slice, beam)
        return loss, {"pot_slice": grads[0], "beam": grads[1]}

    optimizer = get_optimizer(optimizer_name)
    # One optimizer state per optimised array, initialised from its shape.
    pot_slice_state = optimizer.init(initial_pot_slice.shape)
    beam_state = optimizer.init(initial_beam.shape)

    pot_slice = initial_pot_slice
    beam = initial_beam

    @jax.jit
    def update_step(pot_slice, beam, pot_slice_state, beam_state):
        loss, grads = loss_and_grad(pot_slice, beam)
        pot_slice, pot_slice_state = optimizer.update(
            pot_slice, grads["pot_slice"], pot_slice_state, learning_rate
        )
        beam, beam_state = optimizer.update(
            beam, grads["beam"], beam_state, learning_rate
        )
        return pot_slice, beam, pot_slice_state, beam_state, loss

    for i in range(num_iterations):
        pot_slice, beam, pot_slice_state, beam_state, loss = update_step(
            pot_slice, beam, pot_slice_state, beam_state
        )

        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}")

    return pot_slice, beam
def single_slice_poscorrected(
    experimental_4dstem: Float[Array, "P H W"],
    initial_pot_slice: Complex[Array, "H W"],
    initial_beam: Complex[Array, "H W"],
    initial_pos_list: Float[Array, "P 2"],
    slice_thickness: Float[Array, "*"],
    voltage_kV: Float[Array, "*"],
    calib_ang: Float[Array, "*"],
    devices: jax.Array,
    num_iterations: int = 1000,
    learning_rate: float = 0.001,
    pos_learning_rate: float = 0.1,  # Separate learning rate for positions
    loss_type: str = "mse",
    optimizer_name: str = "adam",
) -> Tuple[Complex[Array, "H W"], Complex[Array, "H W"], Float[Array, "P 2"]]:
    """
    Create and run an optimization routine for 4D-STEM reconstruction with
    position correction.

    The potential slice, the beam and the probe positions are all optimised.
    The slice and beam use ``learning_rate``; the positions use
    ``pos_learning_rate``. The loss is printed every 100 iterations.

    Args:
        experimental_4dstem: Experimental 4D-STEM data.
        initial_pot_slice: Initial guess for potential slice.
        initial_beam: Initial guess for electron beam.
        initial_pos_list: Initial list of probe positions.
        slice_thickness: Thickness of each slice.
        voltage_kV: Accelerating voltage.
        calib_ang: Calibration in angstroms.
        devices: Array of devices for sharding; passed as the last argument
            to ``pte.stem_4d``.
        num_iterations: Number of optimization iterations.
        learning_rate: Learning rate for potential slice and beam
            optimization.
        pos_learning_rate: Learning rate for position optimization.
        loss_type: Type of loss function to use; forwarded unchanged to
            ``ptt.create_loss_function``.
        optimizer_name: Key of the optimizer in ``OPTIMIZERS``
            ("adam", "adagrad" or "rmsprop"). The same optimizer is used for
            all three optimised arrays.

    Returns:
        Optimized potential slice, beam, and corrected positions.

    Raises:
        ValueError: If ``optimizer_name`` is not a registered optimizer.
    """

    # Forward model: slice and beam get a leading singleton axis.
    # NOTE(review): ``single_slice_ptychography`` in this module calls
    # ``pte.stem_4d`` without ``devices``; confirm which call matches the
    # current ``stem_4d`` signature.
    def forward_fn(pot_slice, beam, pos_list):
        return pte.stem_4d(
            pot_slice[None, ...],
            beam[None, ...],
            pos_list,
            slice_thickness,
            voltage_kV,
            calib_ang,
            devices,
        )

    # Loss comparing the forward model against the experimental data.
    loss_func = ptt.create_loss_function(forward_fn, experimental_4dstem, loss_type)

    # Loss value plus gradients with respect to slice, beam and positions.
    # NOTE(review): slice and beam are complex, so their gradients follow
    # JAX's complex-differentiation convention; confirm that
    # ``optimizer.update`` applies whatever conjugation that requires.
    @jax.jit
    def loss_and_grad(
        pot_slice: Complex[Array, "H W"],
        beam: Complex[Array, "H W"],
        pos_list: Float[Array, "P 2"],
    ) -> Tuple[Float[Array, ""], Dict[str, Array]]:
        loss, grads = jax.value_and_grad(loss_func, argnums=(0, 1, 2))(
            pot_slice, beam, pos_list
        )
        return loss, {"pot_slice": grads[0], "beam": grads[1], "pos_list": grads[2]}

    optimizer = get_optimizer(optimizer_name)
    # One optimizer state per optimised array, initialised from its shape.
    pot_slice_state = optimizer.init(initial_pot_slice.shape)
    beam_state = optimizer.init(initial_beam.shape)
    pos_state = optimizer.init(initial_pos_list.shape)

    @jax.jit
    def update_step(pot_slice, beam, pos_list, pot_slice_state, beam_state, pos_state):
        loss, grads = loss_and_grad(pot_slice, beam, pos_list)
        pot_slice, pot_slice_state = optimizer.update(
            pot_slice, grads["pot_slice"], pot_slice_state, learning_rate
        )
        beam, beam_state = optimizer.update(
            beam, grads["beam"], beam_state, learning_rate
        )
        # Positions are stepped with their own learning rate.
        pos_list, pos_state = optimizer.update(
            pos_list, grads["pos_list"], pos_state, pos_learning_rate
        )
        return pot_slice, beam, pos_list, pot_slice_state, beam_state, pos_state, loss

    pot_slice = initial_pot_slice
    beam = initial_beam
    pos_list = initial_pos_list

    for i in range(num_iterations):
        pot_slice, beam, pos_list, pot_slice_state, beam_state, pos_state, loss = (
            update_step(
                pot_slice, beam, pos_list, pot_slice_state, beam_state, pos_state
            )
        )

        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}")

    return pot_slice, beam, pos_list
def multi_slice_ptychography(
    experimental_4dstem: Float[Array, "P H W"],
    initial_pot_slices: Complex[Array, "H W S"],  # S is number of slices
    initial_beam: Complex[Array, "H W"],
    pos_list: Float[Array, "P 2"],
    slice_thickness: Float[Array, ""],
    voltage_kV: Float[Array, ""],
    calib_ang: Float[Array, ""],
    num_iterations: int = 1000,
    learning_rate: float = 0.001,
    loss_type: str = "mse",
    optimizer_name: str = "adam",
    scheduler_fn: Optional[ptt.SchedulerFn] = None,
) -> Tuple[Complex[Array, "H W S"], Complex[Array, "H W"]]:
    """
    Multi-slice ptychographic reconstruction.

    Optimises a stack of potential slices and a single beam against the
    experimental data, using ``pte.stem_4d_multi`` as the forward model.
    Progress (loss and learning rate) is printed every 100 iterations.

    Args:
        experimental_4dstem: Experimental 4D-STEM data
        initial_pot_slices: Initial guess for potential slices
        initial_beam: Initial guess for electron beam
        pos_list: List of probe positions
        slice_thickness: Thickness of each slice
        voltage_kV: Accelerating voltage
        calib_ang: Calibration in angstroms
        num_iterations: Number of optimization iterations
        learning_rate: Initial learning rate; used for every step when no
            scheduler is given
        loss_type: Type of loss function, forwarded to
            ``ptt.create_loss_function``
        optimizer_name: Name of optimizer to use (a key of ``OPTIMIZERS``)
        scheduler_fn: Optional learning rate scheduler. When given, it is
            called once per iteration with the scheduler state and must
            return ``(learning_rate, new_state)``

    Returns:
        Tuple of optimized potential slices and beam
    """

    def simulate(slices: Complex[Array, "H W S"], probe: Complex[Array, "H W"]):
        # The beam is given a leading singleton axis before the call.
        return pte.stem_4d_multi(
            slices,
            probe[None, ...],
            pos_list,
            slice_thickness,
            voltage_kV,
            calib_ang,
        )

    objective = ptt.create_loss_function(simulate, experimental_4dstem, loss_type)

    # Returns ``(loss, (d_slices, d_beam))``.
    loss_and_grad = jax.jit(jax.value_and_grad(objective, argnums=(0, 1)))

    optimizer = get_optimizer(optimizer_name)

    @jax.jit
    def update_step(slices, probe, slices_opt, probe_opt, lr):
        loss_value, (slices_grad, probe_grad) = loss_and_grad(slices, probe)
        slices, slices_opt = optimizer.update(slices, slices_grad, slices_opt, lr)
        probe, probe_opt = optimizer.update(probe, probe_grad, probe_opt, lr)
        return slices, probe, slices_opt, probe_opt, loss_value

    # Optimizer states are created from the shapes of the initial guesses.
    pot_slices_state = optimizer.init(initial_pot_slices.shape)
    beam_state = optimizer.init(initial_beam.shape)

    use_scheduler = scheduler_fn is not None
    if use_scheduler:
        scheduler_state = ptt.init_scheduler_state(learning_rate)

    pot_slices, beam = initial_pot_slices, initial_beam
    current_lr = learning_rate

    for i in range(num_iterations):
        if use_scheduler:
            # Ask the scheduler for this iteration's learning rate.
            current_lr, scheduler_state = scheduler_fn(scheduler_state)

        pot_slices, beam, pot_slices_state, beam_state, loss = update_step(
            pot_slices, beam, pot_slices_state, beam_state, current_lr
        )

        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}, LR: {current_lr}")

    return pot_slices, beam
def initialize_probe_modes(
    base_probe: Complex[Array, "H W"], num_modes: int, random_seed: int = 42
) -> pte.ProbeState:
    """
    Initialize multiple probe modes from a base probe.

    Every mode is the base probe plus standard-normal noise scaled by 0.1,
    drawn with the dtype of ``base_probe``. Mode ``k`` receives a weight
    proportional to ``exp(-k)``; the weights are normalised to sum to one.

    Args:
        base_probe: Base probe function
        num_modes: Number of modes to generate
        random_seed: Seed for the JAX PRNG key used to draw the noise

    Returns:
        ProbeState with ``modes`` (mode index on the last axis) and
        ``weights``
    """
    rng_key = jax.random.PRNGKey(random_seed)
    n_rows, n_cols = base_probe.shape[0], base_probe.shape[1]

    noise = (
        jax.random.normal(
            rng_key,
            shape=(n_rows, n_cols, num_modes),
            dtype=base_probe.dtype,
        )
        * 0.1
    )
    # Broadcasting the trailing singleton axis copies the base probe into
    # every mode before the noise is added.
    modes = base_probe[..., None] + noise

    # Exponentially decreasing weights, normalised to sum to one.
    decay = jnp.exp(-jnp.arange(num_modes, dtype=jnp.float32))
    weights = decay / jnp.sum(decay)

    return pte.ProbeState(modes=modes, weights=weights)
# NOTE(review): a second function named ``multi_mode_ptychography`` is defined
# later in this module; at import time that later definition replaces this
# one, so this version is not reachable by name. Confirm whether it should be
# renamed or removed.
def multi_mode_ptychography(
    experimental_4dstem: Float[Array, "P H W"],
    initial_pot_slices: Complex[Array, "H W S"],
    initial_probe_state: pte.ProbeState,
    pos_list: Float[Array, "P 2"],
    slice_thickness: Float[Array, ""],
    voltage_kV: Float[Array, ""],
    calib_ang: Float[Array, ""],
    num_iterations: int = 1000,
    learning_rate: float = 0.001,
    weight_learning_rate: float = 0.0001,  # Separate LR for weights
    loss_type: str = "mse",
    optimizer_name: str = "adam",
    scheduler_fn: Optional[ptt.SchedulerFn] = None,
) -> Tuple[Complex[Array, "H W S"], pte.ProbeState]:
    """
    Multi-mode ptychographic reconstruction.

    Optimises the potential slices, the probe modes and the mode weights.
    Slices and modes share one learning rate (optionally scheduled); the
    weights use ``weight_learning_rate``. After each step the weights are
    replaced by their absolute values and renormalised to sum to one.
    Progress is printed every 100 iterations.

    Args:
        experimental_4dstem: Experimental 4D-STEM data
        initial_pot_slices: Initial guess for potential slices
        initial_probe_state: Initial probe modes and weights
        pos_list: List of probe positions
        slice_thickness: Thickness of each slice
        voltage_kV: Accelerating voltage
        calib_ang: Calibration in angstroms
        num_iterations: Number of optimization iterations
        learning_rate: Initial learning rate
        weight_learning_rate: Learning rate for mode weights
        loss_type: Type of loss function
        optimizer_name: Name of optimizer to use
        scheduler_fn: Optional learning rate scheduler

    Returns:
        Tuple of optimized potential slices and probe state
    """

    def forward_fn(
        pot_slices: Complex[Array, "H W S"],
        probe_modes: Complex[Array, "H W M"],
        mode_weights: Float[Array, "M"],
    ):
        patterns = pte.stem_4d_multi(
            pot_slices,
            probe_modes,
            pos_list,
            slice_thickness,
            voltage_kV,
            calib_ang,
        )
        # NOTE(review): ``patterns[..., None]`` has a trailing length-1 axis,
        # so every element is multiplied by each weight and the sum over the
        # last axis equals ``patterns * sum(mode_weights)``; the weights do
        # not act per mode here. Confirm this is intended.
        weighted = patterns[..., None] * mode_weights[None, None, None, :]
        return jnp.sum(weighted, axis=-1)

    loss_func = ptt.create_loss_function(forward_fn, experimental_4dstem, loss_type)

    # Differentiate with respect to slices, modes and weights.
    value_and_grads = jax.value_and_grad(loss_func, argnums=(0, 1, 2))

    @jax.jit
    def loss_and_grad(
        pot_slices: Complex[Array, "H W S"],
        probe_state: pte.ProbeState,
    ) -> Tuple[Float[Array, ""], dict]:
        loss_value, (slices_grad, modes_grad, weights_grad) = value_and_grads(
            pot_slices, probe_state.modes, probe_state.weights
        )
        return loss_value, {
            "pot_slices": slices_grad,
            "modes": modes_grad,
            "weights": weights_grad,
        }

    optimizer = get_optimizer(optimizer_name)
    # States are kept together as (slices, modes, weights).
    opt_states = (
        optimizer.init(initial_pot_slices.shape),
        optimizer.init(initial_probe_state.modes.shape),
        optimizer.init(initial_probe_state.weights.shape),
    )

    if scheduler_fn is not None:
        scheduler_state = ptt.init_scheduler_state(learning_rate)

    @jax.jit
    def update_step(pot_slices, probe_state, opt_states, lr, weight_lr):
        slices_opt, modes_opt, weights_opt = opt_states
        loss_value, grads = loss_and_grad(pot_slices, probe_state)

        pot_slices, slices_opt = optimizer.update(
            pot_slices, grads["pot_slices"], slices_opt, lr
        )
        new_modes, modes_opt = optimizer.update(
            probe_state.modes, grads["modes"], modes_opt, lr
        )
        new_weights, weights_opt = optimizer.update(
            probe_state.weights, grads["weights"], weights_opt, weight_lr
        )

        # Keep the weights non-negative and summing to one.
        new_weights = jnp.abs(new_weights)
        new_weights = new_weights / jnp.sum(new_weights)

        return (
            pot_slices,
            pte.ProbeState(modes=new_modes, weights=new_weights),
            (slices_opt, modes_opt, weights_opt),
            loss_value,
        )

    pot_slices = initial_pot_slices
    probe_state = initial_probe_state
    current_lr = learning_rate

    for i in range(num_iterations):
        if scheduler_fn is not None:
            current_lr, scheduler_state = scheduler_fn(scheduler_state)

        pot_slices, probe_state, opt_states, loss = update_step(
            pot_slices, probe_state, opt_states, current_lr, weight_learning_rate
        )

        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}, LR: {current_lr}")
            print(f"Mode weights: {probe_state.weights}")

    return pot_slices, probe_state
def initialize_mixed_states(
    base_state: Complex[Array, "H W"],
    num_states: int,
    energy_spread: float = 0.5,  # eV
    random_seed: int = 42,
) -> pte.MixedQuantumStates:
    """
    Initialize mixed states for partial temporal coherence.

    Energy offsets are drawn from a zero-mean Gaussian whose standard
    deviation is ``energy_spread / 2.355``. Each state is ``base_state``
    multiplied by ``exp(1j * offset)``, and each probability is the Gaussian
    weight of its offset, normalised so the probabilities sum to one.

    Args:
        base_state: Base state (e.g., probe)
        num_states: Number of states in mixture
        energy_spread: Energy spread in eV (FWHM). A value of 0 gives
            identical states with uniform probabilities.
        random_seed: Seed for the JAX PRNG key used to draw the offsets

    Returns:
        MixedQuantumStates with ``states`` (state index on the last axis)
        and ``probabilities``
    """
    key = jax.random.PRNGKey(random_seed)

    # Convert FWHM to sigma.
    sigma = energy_spread / 2.355

    # Draw unit-normal samples once and scale them to energy offsets.
    unit_samples = jax.random.normal(key, shape=(num_states,))
    energies = unit_samples * sigma

    # NOTE(review): the offset is used directly as the phase angle; confirm
    # the intended energy-to-phase conversion.
    phase_factors = jnp.exp(1j * energies[:, None, None])
    states = base_state[None, ...] * phase_factors

    # exp(-E**2 / (2 * sigma**2)) with E = z * sigma equals exp(-z**2 / 2).
    # Using the unit samples avoids the 0/0 (NaN probabilities) that the
    # explicit division produces when energy_spread is 0.
    probabilities = jnp.exp(-(unit_samples**2) / 2)
    probabilities = probabilities / jnp.sum(probabilities)

    # Move the state index from the first axis to the last.
    return pte.MixedQuantumStates(
        states=states.transpose(1, 2, 0), probabilities=probabilities
    )
# NOTE(review): an earlier function in this module has the same name; this
# later definition is the one bound to ``multi_mode_ptychography`` at import.
def multi_mode_ptychography(
    experimental_4dstem: Float[Array, "P H W"],
    initial_pot_slices: Complex[Array, "H W S"],
    initial_modes: Complex[Array, "H W M"],  # M different probe modes
    mode_weights: Float[Array, "M"],  # Weights for each mode
    pos_list: Float[Array, "P 2"],
    slice_thickness: Float[Array, ""],
    voltage_kV: Float[Array, ""],
    calib_ang: Float[Array, ""],
    num_iterations: int = 1000,
    learning_rate: float = 0.001,
    loss_type: str = "mse",
    optimizer_name: str = "adam",
    scheduler_fn: Optional[ptt.SchedulerFn] = None,
) -> Tuple[Complex[Array, "H W S"], Complex[Array, "H W M"], Float[Array, "M"]]:
    """
    Multi-mode ptychographic reconstruction with optional mixed state.

    Optimises potential slices, probe modes and mode weights with a shared
    (optionally scheduled) learning rate, using ``pte.stem_4d_mixed_state``
    as the forward model. The weights are normalised to sum to one before
    optimisation and again after every step. Progress is printed every 100
    iterations.

    Args:
        experimental_4dstem: Experimental 4D-STEM data
        initial_pot_slices: Initial guess for potential slices
        initial_modes: Initial guess for probe modes
        mode_weights: Initial weights for each mode
        pos_list: List of probe positions
        slice_thickness: Thickness of each slice
        voltage_kV: Accelerating voltage
        calib_ang: Calibration in angstroms
        num_iterations: Number of optimization iterations
        learning_rate: Initial learning rate
        loss_type: Type of loss function
        optimizer_name: Name of optimizer to use
        scheduler_fn: Optional learning rate scheduler

    Returns:
        Tuple of optimized potential slices, probe modes, and mode weights
    """
    # Start from weights that sum to one.
    mode_weights = mode_weights / jnp.sum(mode_weights)

    def simulate(
        pot_slices: Complex[Array, "H W S"],
        modes: Complex[Array, "H W M"],
        weights: Float[Array, "M"],
    ) -> Float[Array, "P H W"]:
        return pte.stem_4d_mixed_state(
            pot_slices,
            modes,
            weights,
            pos_list,
            slice_thickness,
            voltage_kV,
            calib_ang,
        )

    objective = ptt.create_loss_function(simulate, experimental_4dstem, loss_type)

    # Returns ``(loss, (d_slices, d_modes, d_weights))``.
    loss_and_grad = jax.jit(jax.value_and_grad(objective, argnums=(0, 1, 2)))

    optimizer = get_optimizer(optimizer_name)
    # Optimizer states in the same (slices, modes, weights) order as the
    # parameters they belong to.
    opt_states = (
        optimizer.init(initial_pot_slices.shape),
        optimizer.init(initial_modes.shape),
        optimizer.init(mode_weights.shape),
    )

    if scheduler_fn is not None:
        scheduler_state = ptt.init_scheduler_state(learning_rate)

    @jax.jit
    def update_step(params, states, lr):
        pot_slices, modes, weights = params
        slices_opt, modes_opt, weights_opt = states
        loss_value, (slices_grad, modes_grad, weights_grad) = loss_and_grad(
            pot_slices, modes, weights
        )

        pot_slices, slices_opt = optimizer.update(
            pot_slices, slices_grad, slices_opt, lr
        )
        modes, modes_opt = optimizer.update(modes, modes_grad, modes_opt, lr)
        weights, weights_opt = optimizer.update(
            weights, weights_grad, weights_opt, lr
        )
        # NOTE(review): unlike the earlier same-named function in this
        # module, no ``jnp.abs`` is applied before renormalising, so
        # individual weights may become negative. Confirm this is intended.
        weights = weights / jnp.sum(weights)

        return (
            (pot_slices, modes, weights),
            (slices_opt, modes_opt, weights_opt),
            loss_value,
        )

    params = (initial_pot_slices, initial_modes, mode_weights)
    current_lr = learning_rate

    for i in range(num_iterations):
        if scheduler_fn is not None:
            # Ask the scheduler for this iteration's learning rate.
            current_lr, scheduler_state = scheduler_fn(scheduler_state)

        params, opt_states, loss = update_step(params, opt_states, current_lr)

        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}, LR: {current_lr}")

    pot_slices, modes, weights = params
    return pot_slices, modes, weights