API Reference
Electrons Module
- class ptyrodactyl.electrons.classes.Complex(*args, **kwargs)
Bases:
AbstractDtype
- class ptyrodactyl.electrons.classes.MixedQuantumStates(states, probabilities)
Bases:
NamedTuple
Description
PyTree structure for mixed probe quantum states.
- Parameters:
states (Complex[Array, 'H W N'])
probabilities (Float[Array, 'N'])
- `states` (Complex[Array, "H W N"]):
N different states
- `probabilities` (Float[Array, "N"]):
Occupation probabilities
- probabilities: Float[Array, 'N']
Alias for field number 1
- class ptyrodactyl.electrons.classes.MixedStateParams(num_modes, mode_weights)
Bases:
NamedTuple
Description
PyTree structure for the parameters of a mixed probe state.
- Parameters:
num_modes (Int[Array, ''])
mode_weights (Float[Array, 'M'])
- `num_modes` (Int[Array, ""]):
Number of modes
- `mode_weights` (Float[Array, "M"]):
Weights for each mode
- mode_weights: Float[Array, 'M']
Alias for field number 1
- num_modes: Int[Array, '']
Alias for field number 0
- ptyrodactyl.electrons.classes.NamedTuple(typename, fields=None, /, **kwargs)
Typed version of namedtuple.
Usage:
class Employee(NamedTuple):
    name: str
    id: int
This is equivalent to:
Employee = collections.namedtuple('Employee', ['name', 'id'])
The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. (The field names are also in the _fields attribute, which is part of the namedtuple API.) An alternative equivalent functional syntax is also accepted:
Employee = NamedTuple('Employee', [('name', str), ('id', int)])
- class ptyrodactyl.electrons.classes.ProbeState(modes, weights)
Bases:
NamedTuple
Description
PyTree structure for electron probe state.
- Parameters:
modes (Complex[Array, 'H W M'])
weights (Float[Array, 'M'])
- `modes` (Complex[Array, "H W M"]):
Probe modes, where M is the number of modes
- `weights` (Float[Array, "M"]):
Mode occupation numbers
- weights: Float[Array, 'M']
Alias for field number 1
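Because ProbeState (like MixedQuantumStates and MixedStateParams) is a NamedTuple, JAX already treats it as a pytree, so an instance can pass through jax.jit and jax.grad without any registration. The sketch below assumes nothing about the real class beyond its two fields: ProbeStateSketch is a hypothetical stand-in, and total_intensity is an illustrative function, not part of the library.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ProbeStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as ProbeState."""
    modes: jax.Array    # complex, shape (H, W, M)
    weights: jax.Array  # real, shape (M,)


@jax.jit
def total_intensity(state: ProbeStateSketch) -> jax.Array:
    """Weighted sum of the per-mode intensities (illustrative only)."""
    per_mode = jnp.sum(jnp.abs(state.modes) ** 2, axis=(0, 1))  # shape (M,)
    return jnp.sum(state.weights * per_mode)


state = ProbeStateSketch(
    modes=jnp.ones((4, 4, 2), dtype=jnp.complex64),
    weights=jnp.array([0.75, 0.25]),
)
leaves = jax.tree_util.tree_leaves(state)  # one leaf per field
value = total_intensity(state)             # the NamedTuple passes through jit
grads = jax.grad(total_intensity)(state)   # the gradient has the same structure
```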
- ptyrodactyl.electrons.classes.register_pytree_node_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface.
- Args:
cls: a type to register as a pytree
- Returns:
The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_node_class to be used as a decorator.
- See also:
register_static(): simpler API for registering a static pytree.
register_dataclass(): simpler API for registering a dataclass.
- Examples:
Here we’ll define a custom container that will be compatible with jax.jit() and other JAX transformations:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
- Parameters:
cls (Typ)
- Return type:
Typ
- ptyrodactyl.electrons.forward.aberration(fourier_coord, lambda_angstrom, defocus=Array(0., dtype=float64, weak_type=True), c3=Array(0., dtype=float64, weak_type=True), c5=Array(0., dtype=float64, weak_type=True))
Description
Calculates the aberration function of the electron probe from the Fourier coordinates.
Parameters
- fourier_coord (Float[Array, “H W”]):
The Fourier co-ordinates
- lambda_angstrom (Num[Array, “”]):
The wavelength in angstroms
- defocus (Float[Array, “”]):
The defocus value in angstroms. Default is 0.0
- c3 (Float[Array, “”]):
The C3 (third-order spherical aberration) coefficient in angstroms. Default is 0.0
- c5 (Float[Array, “”]):
The C5 (fifth-order spherical aberration) coefficient in angstroms. Default is 0.0
Returns
- chi_probe (Float[Array, “H W”]):
The calculated aberration function
Flow
Calculate the phase shift
Calculate the chi value
Calculate the chi probe value
- Parameters:
fourier_coord (Float[Array, 'H W'])
lambda_angstrom (Num[Array, ''])
defocus (Float[Array, ''] | None)
c3 (Float[Array, ''] | None)
c5 (Float[Array, ''] | None)
- Return type:
Float[Array, 'H W']
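A minimal sketch of an aberration function of this kind, assuming the common convention χ(q) = πλΔf q² + (π/2)C3 λ³ q⁴ + (π/3)C5 λ⁵ q⁶, with q in inverse angstroms and λ, Δf, C3, C5 in angstroms. Sign conventions for the defocus term differ between texts, so this is not necessarily the exact expression inside aberration; aberration_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def aberration_sketch(q, wavelength, defocus=0.0, c3=0.0, c5=0.0):
    """Aberration phase chi(q) in radians (hypothetical helper).

    q: spatial-frequency magnitude in 1/angstrom (scalar or array).
    wavelength, defocus, c3, c5: in angstroms.
    Convention: chi = pi*lam*df*q^2 + (pi/2)*C3*lam^3*q^4 + (pi/3)*C5*lam^5*q^6.
    """
    q2 = jnp.asarray(q) ** 2
    return (
        jnp.pi * wavelength * defocus * q2
        + 0.5 * jnp.pi * c3 * wavelength ** 3 * q2 ** 2
        + (jnp.pi / 3.0) * c5 * wavelength ** 5 * q2 ** 3
    )


# Example: chi on a 64 x 64 frequency grid for 0.1 angstrom pixels at 200 kV
# (lambda ~ 0.0251 angstrom), with -50 angstrom defocus and C3 = 1 mm = 1e7 angstrom.
freqs = jnp.fft.fftfreq(64, d=0.1)
q_grid = jnp.sqrt(freqs[:, None] ** 2 + freqs[None, :] ** 2)
chi = aberration_sketch(q_grid, 0.0251, defocus=-50.0, c3=1.0e7)
```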
- ptyrodactyl.electrons.forward.cbed(pot_slice, beam, slice_thickness, voltage_kV, calib_ang)
Description
Calculates the CBED pattern for single/multiple slices and single/multiple beam modes. This function computes the Convergent Beam Electron Diffraction (CBED) pattern by propagating one or more beam modes through one or more potential slices.
Parameters
- pot_slice (Complex[Array, “H W #S”]):
The potential slice(s). H and W are height and width, S is the number of slices (optional).
- beam (Complex[Array, “H W #M”]):
The electron beam mode(s). M is the number of modes (optional).
- slice_thickness (Float[Array, “”]):
The thickness of each slice in angstroms.
- voltage_kV (Float[Array, “#v”]):
The accelerating voltage(s) in kilovolts.
- calib_ang (Float[Array, “”]):
The calibration in angstroms.
Returns
- cbed_pattern (Float[Array, “H W”]):
The calculated CBED pattern.
Flow
Ensure 3D arrays even for single slice/mode
Calculate the transmission function for a single slice
Initialize the convolution state
Scan over all slices
Compute the Fourier transform
Compute the intensity for each mode
Sum the intensities across all modes.
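The Flow steps above follow the standard multislice scheme: transmit through a slice, propagate to the next, repeat, then take the far-field intensity of each mode and add them incoherently. A minimal sketch, assuming precomputed per-slice transmission functions and a single Fourier-space propagator (cbed itself derives these from pot_slice, slice_thickness, voltage_kV and calib_ang); cbed_sketch is a hypothetical name. This version propagates after every slice, including the last, which leaves the final intensity unchanged because the propagator has unit modulus.

```python
import jax
import jax.numpy as jnp


def cbed_sketch(transmissions, propagator, beam_modes):
    """Multislice CBED intensity (hypothetical helper, not the library code).

    transmissions: complex (H, W, S), one transmission function per slice.
    propagator:    complex (H, W), slice-to-slice propagator in Fourier space.
    beam_modes:    complex (H, W, M), incident probe modes in real space.
    Returns the mode-summed far-field intensity, real (H, W), DC centred.
    """

    def one_slice(wave, trans):
        # Transmit through the slice, then propagate to the next slice.
        wave = wave * trans[..., None]
        wave_k = jnp.fft.fft2(wave, axes=(0, 1)) * propagator[..., None]
        return jnp.fft.ifft2(wave_k, axes=(0, 1)), None

    # lax.scan iterates over the leading axis, so put the slice axis first.
    exit_wave, _ = jax.lax.scan(one_slice, beam_modes, jnp.moveaxis(transmissions, -1, 0))
    far_field = jnp.fft.fftshift(jnp.fft.fft2(exit_wave, axes=(0, 1)), axes=(0, 1))
    # Modes are mutually incoherent, so their intensities add.
    return jnp.sum(jnp.abs(far_field) ** 2, axis=-1)


h = w = 8
trans = jnp.exp(1j * jax.random.normal(jax.random.PRNGKey(0), (h, w, 3)))
prop = jnp.exp(1j * jax.random.normal(jax.random.PRNGKey(1), (h, w)))
beam = jnp.ones((h, w, 2), dtype=jnp.complex64)
pattern = cbed_sketch(trans, prop, beam)
```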
- ptyrodactyl.electrons.forward.cbed_no_slice(pot_slice, beam, slice_transmission)
Description
Calculates the CBED pattern for single/multiple slices and single/multiple beam modes.
This function computes the Convergent Beam Electron Diffraction (CBED) pattern by propagating one or more beam modes through one or more potential slices. This version takes a pre-calculated transmission function for going from one slice to the next, which is useful when the CBED pattern is calculated many times with the same transmission function, as in 4D-STEM.
Parameters
- pot_slice (Complex[Array, “H W *S”]):
The potential slice(s). H and W are height and width, S is the number of slices (optional).
- beam (Complex[Array, “H W *M”]):
The electron beam mode(s). M is the number of modes (optional).
- slice_transmission (Complex[Array, “*”]):
The pre-calculated transmission function for going from one slice to the next.
Returns
- cbed_pattern (Float[Array, “H W”]):
The calculated CBED pattern.
Flow
Ensure 3D arrays even for single slice/mode
Initialize the convolution state
Scan over all slices
Compute the Fourier transform
Compute the intensity for each mode
Sum the intensities across all modes
- ptyrodactyl.electrons.forward.fourier_calib(real_space_calib, sizebeam)
Description
Generates the Fourier-space calibration for the beam.
Parameters
- real_space_calib (float | Float[Array, “*”]):
The pixel size in angstroms in real space
- sizebeam (Int[Array, “2”]):
The size of the beam in pixels
Returns
- inverse_space_calib (Float[Array, “2”]):
The Fourier-space calibration in inverse angstroms (1/Å)
Flow
Calculate the field of view in real space
Calculate the inverse space calibration
- Parameters:
real_space_calib (float | Float[Array, '*'])
sizebeam (Int[Array, '2'])
- Return type:
Float[Array, '2']
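The two Flow steps amount to taking the reciprocal of the real-space field of view along each axis. A sketch under that assumption (fourier_calib_sketch is a hypothetical name):

```python
import jax.numpy as jnp


def fourier_calib_sketch(real_space_calib, sizebeam):
    """Reciprocal-space pixel size in 1/angstrom (hypothetical helper)."""
    # Field of view in real space (angstroms) along each axis.
    field_of_view = jnp.asarray(real_space_calib) * jnp.asarray(sizebeam)
    # One Fourier pixel corresponds to one period across the field of view.
    return 1.0 / field_of_view


inverse_calib = fourier_calib_sketch(0.1, jnp.array([256, 128]))
```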
- ptyrodactyl.electrons.forward.fourier_coords(calibration, image_size)
Description
Returns the Fourier coordinates.
Parameters
- calibration (float):
The pixel size in angstroms in real space
- image_size (Int[Array, “2”]):
The size of the beam in pixels
Returns
- A NamedTuple with the following fields:
- array (Any[Array, “* *”]):
The array values
- calib_y (float):
Calibration along the first axis
- calib_x (float):
Calibration along the second axis
Flow
Calculate the real space field of view in y and x
Generate the inverse space array y and x
Shift the inverse space array y and x
Create meshgrid of shifted inverse space arrays
Calculate the inverse array
Calculate the calibration in y and x
Return the calibrated array
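A sketch of the same Flow using jnp.fft.fftfreq, assuming the returned array holds the magnitude |q| of the spatial frequency in inverse angstroms (centred with fftshift) and that each calibration is the reciprocal field of view. CalibratedArraySketch and fourier_coords_sketch are hypothetical stand-ins for the documented NamedTuple and function.

```python
from typing import NamedTuple

import jax.numpy as jnp


class CalibratedArraySketch(NamedTuple):
    """Hypothetical stand-in for the NamedTuple documented above."""
    array: jnp.ndarray
    calib_y: float
    calib_x: float


def fourier_coords_sketch(calibration, image_size):
    ny, nx = int(image_size[0]), int(image_size[1])
    # Real-space field of view along y and x, in angstroms.
    fov_y, fov_x = ny * calibration, nx * calibration
    # Spatial frequencies in 1/angstrom, shifted so that q = 0 is at the centre.
    qy = jnp.fft.fftshift(jnp.fft.fftfreq(ny, d=calibration))
    qx = jnp.fft.fftshift(jnp.fft.fftfreq(nx, d=calibration))
    grid_y, grid_x = jnp.meshgrid(qy, qx, indexing="ij")
    # Magnitude of the spatial-frequency vector at each pixel.
    inverse_array = jnp.sqrt(grid_y ** 2 + grid_x ** 2)
    return CalibratedArraySketch(inverse_array, 1.0 / fov_y, 1.0 / fov_x)


coords = fourier_coords_sketch(0.5, (4, 8))
```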
- Parameters:
calibration (float)
image_size (Int[Array, '2'])
- Return type:
- ptyrodactyl.electrons.forward.initialize_random_modes(shape, num_modes, dtype=<class 'jax.numpy.complex128'>)
Initialize random orthogonal modes.
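The docstring does not say how orthogonality is enforced. One common approach, assumed here, is to draw complex Gaussian noise and orthonormalize the flattened modes with a reduced QR decomposition. random_orthonormal_modes_sketch is a hypothetical name; under JAX's default 32-bit mode it returns complex64 rather than the documented complex128 default, and it needs H * W >= num_modes.

```python
import jax
import jax.numpy as jnp


def random_orthonormal_modes_sketch(key, shape, num_modes):
    """Random complex modes of shape (H, W, num_modes), orthonormal over pixels."""
    h, w = shape
    key_re, key_im = jax.random.split(key)
    raw = (jax.random.normal(key_re, (h * w, num_modes))
           + 1j * jax.random.normal(key_im, (h * w, num_modes)))
    # Reduced QR: the columns of q are orthonormal, i.e. q^H q = I.
    q, _ = jnp.linalg.qr(raw)
    return q.reshape(h, w, num_modes)


modes = random_orthonormal_modes_sketch(jax.random.PRNGKey(0), (8, 8), 3)
```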
- ptyrodactyl.electrons.forward.jaxtyped(fn=sentinel, *, typechecker=sentinel)
Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.
!!! Example
```python
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                      ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]

# Type-check a dataclass
from dataclasses import dataclass

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]
```
Arguments:
- fn: The function or dataclass to decorate. In practice if you want to use dataclasses with JAX, then [equinox.Module](https://docs.kidger.site/equinox/api/module/module/) is our recommended approach:
```python
import equinox as eqx

@jaxtyped(typechecker=typechecker)
class MyModule(eqx.Module):
    ...
```
- typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.
```python
@typechecker
def f(x: int):
    pass

f("a string is not an integer")  # this line should raise an exception
```
Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. Can also be set as typechecker=None to skip automatic runtime type-checking, but still support manual isinstance checks inside the function body:
```python
@jaxtyped(typechecker=None)
def f(x):
    assert isinstance(x, Float[Array, "batch channel"])
```
Returns:
If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.
If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.
!!! Info "Old syntax"
jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:
```python
@jaxtyped
@typechecker
def f(...): ...
```
This is still supported, but will now raise a warning recommending the jaxtyped(typechecker=typechecker) syntax discussed above. (Which will produce easier-to-debug error messages: under the hood, the new syntax more carefully manipulates the typechecker so as to determine where a type-check error arises.)
??? Info "Notes for advanced users"
Dynamic contexts:
Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns then this inner context is popped from the stack, and the previous context is returned to.
isinstance:
Binding of a value against a name is done with an isinstance check, for example isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the run-time typechecker typechecker that is supplied as an argument.
This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether – i.e. typechecker=None – and only do your own manual isinstance checks.)
Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).
Decoupling contexts from function calls:
If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.
Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":
```python
with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])
```
This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it’s mostly only useful when working at the global scope.
- ptyrodactyl.electrons.forward.make_probe(aperture, voltage, image_size, calibration_pm, defocus=0, c3=0, c5=0)
Description
Calculates an electron probe from the image size and the Fourier coordinates, optionally adding aberrations: defocus and the third- and fifth-order spherical aberration coefficients C3 and C5.
Parameters
- aperture (Union[float, int]):
The aperture size in milliradians
- voltage (Union[float, int]):
The microscope accelerating voltage in kilovolts
- image_size (Int[Array, “2”]):
The size of the beam in pixels
- calibration_pm (float):
The calibration in picometers
- defocus (float):
The defocus value in angstroms
- c3 (float):
The C3 value in angstroms
- c5 (float):
The C5 value in angstroms
Returns
- probe_real_space (Complex[Array, “H W”]):
The calculated electron probe in real space
Flow
Convert the aperture to radians
Calculate the wavelength in angstroms
Calculate the maximum L value
Calculate the field of view in x and y
Generate the inverse space array y and x
Shift the inverse space array y and x
Create meshgrid of shifted inverse space arrays
Calculate the inverse array
Calculate the calibration in y and x
Calculate the probe in real space
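A compact sketch of the Flow above: build the spatial-frequency grid, keep the frequencies inside the aperture cutoff q_max = α/λ, apply the aberration phase exp(-iχ), and inverse-transform to real space. It assumes the aberration convention χ = πλΔf q² + (π/2)C3 λ³ q⁴ + (π/3)C5 λ⁵ q⁶, a hard-edged aperture, and normalization to unit total intensity; none of these details is confirmed by the docstring, and make_probe_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def _wavelength_ang(voltage_kv):
    # Relativistic electron wavelength in angstroms; voltage in kV.
    v = voltage_kv * 1e3
    return 12.2643 / jnp.sqrt(v * (1.0 + 0.978476e-6 * v))


def make_probe_sketch(aperture_mrad, voltage_kv, image_size, calibration_pm,
                      defocus=0.0, c3=0.0, c5=0.0):
    """Real-space probe on an (ny, nx) grid (hypothetical helper)."""
    ny, nx = image_size
    calib_ang = calibration_pm / 100.0        # 1 angstrom = 100 pm
    lam = _wavelength_ang(voltage_kv)
    q_max = (aperture_mrad * 1e-3) / lam      # aperture cutoff in 1/angstrom
    qy = jnp.fft.fftfreq(ny, d=calib_ang)
    qx = jnp.fft.fftfreq(nx, d=calib_ang)
    grid_y, grid_x = jnp.meshgrid(qy, qx, indexing="ij")
    q2 = grid_y ** 2 + grid_x ** 2
    chi = (jnp.pi * lam * defocus * q2
           + 0.5 * jnp.pi * c3 * lam ** 3 * q2 ** 2
           + (jnp.pi / 3.0) * c5 * lam ** 5 * q2 ** 3)
    aperture = (q2 <= q_max ** 2).astype(jnp.float32)  # hard-edged aperture
    probe_k = aperture * jnp.exp(-1j * chi)
    probe = jnp.fft.fftshift(jnp.fft.ifft2(probe_k))    # move the probe to the centre
    return probe / jnp.sqrt(jnp.sum(jnp.abs(probe) ** 2))  # unit total intensity


probe = make_probe_sketch(20.0, 200.0, (64, 64), 10.0)
defocused = make_probe_sketch(20.0, 200.0, (64, 64), 10.0, defocus=200.0)
```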
- ptyrodactyl.electrons.forward.normalize_mode_weights(weights)
Normalize mode weights to sum to 1.
- Parameters:
weights (Float[Array, 'M'])
- Return type:
Float[Array, 'M']
- class ptyrodactyl.electrons.forward.partial
Bases:
object
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
- args
tuple of arguments to future partial calls
- func
function object to use in future partial calls
- keywords
dictionary of keyword arguments to future partial calls
- ptyrodactyl.electrons.forward.propagation_func(imsize_y, imsize_x, thickness_ang, voltage_kV, calib_ang)
Description
Calculates the complex propagation function that describes the phase shift the exit wave acquires as it travels from one slice to the next in the multislice algorithm.
Parameters
- imsize_y (int):
Size of the propagator image along the y axis
- imsize_x (int):
Size of the propagator image along the x axis
- thickness_ang (scalar_number):
Distance between the slices in angstroms
- voltage_kV (scalar_number):
Accelerating voltage in kilovolts
- calib_ang (scalar_number):
Calibration or pixel size in angstroms
Returns
- prop (Complex[Array, “H W”]):
The propagation function, with shape (imsize_y, imsize_x)
Flow
Generate frequency arrays directly using fftfreq
Create 2D meshgrid of frequencies
Calculate squared sum of frequencies
Calculate wavelength
Compute the propagation function
- ptyrodactyl.electrons.forward.shift_beam_fourier(beam, pos, calib_ang)
Description
Shifts the beam to new position(s) using Fourier shifting.
Parameters
- beam (Complex[Array, “H W M”]):
The electron beam modes.
- pos (Float[Array, “… 2”]):
The (y, x) position(s) to shift to, in pixels: either a single position of shape [2] or multiple positions of shape [..., 2].
- calib_ang (Float[Array, “*”]):
The calibration in angstroms.
Returns
- shifted_beams (Complex[Array, “… H W M”]):
The shifted beam(s) for all position(s) and mode(s).
Flow
Convert positions from real space to Fourier space
Create phase ramps in Fourier space for all positions
Apply shifts to each mode for all positions
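A sketch of Fourier shifting based on the shift theorem: multiplying the spectrum by exp(-2πi(f_y·y + f_x·x)) moves the wave by (y, x). Because the positions here are in pixels and the frequencies from jnp.fft.fftfreq are in cycles per pixel, this sketch needs no calib_ang; that is a simplification of the documented signature, and shift_beam_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def shift_beam_sketch(beam, pos):
    """Shift (H, W, M) beam modes by (y, x) pixel offsets.

    pos: shape (2,) or (P, 2); returns (H, W, M) or (P, H, W, M).
    """
    h, w = beam.shape[:2]
    fy = jnp.fft.fftfreq(h)  # cycles per pixel
    fx = jnp.fft.fftfreq(w)
    beam_k = jnp.fft.fft2(beam, axes=(0, 1))

    def shift_one(p):
        # Linear phase ramp in Fourier space = translation in real space.
        ramp = jnp.exp(-2j * jnp.pi * (fy[:, None] * p[0] + fx[None, :] * p[1]))
        return jnp.fft.ifft2(beam_k * ramp[..., None], axes=(0, 1))

    pos = jnp.asarray(pos, dtype=jnp.float32)
    if pos.ndim == 1:
        return shift_one(pos)
    return jax.vmap(shift_one)(pos)  # one shifted copy per position


beam = jnp.zeros((8, 8, 1), dtype=jnp.complex64).at[0, 0, 0].set(1.0)
moved = shift_beam_sketch(beam, jnp.array([2.0, 3.0]))
batch = shift_beam_sketch(beam, jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 3.0]]))
```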
- ptyrodactyl.electrons.forward.stem_4d(pot_slice, beam, pos_list, slice_thickness, voltage_kV, calib_ang)
Description
Calculates the 4D-STEM pattern for multiple probe positions with sharding.
Parameters
- pot_slice (Complex[Array, “H W S”]):
The potential slices.
- beam (Complex[Array, “H W M”]):
The electron beam modes.
- pos_list (Float[Array, “P 2”]):
List of (y, x) probe positions in pixels.
- slice_thickness (float):
The thickness of each slice in angstroms.
- voltage_kV (float):
The accelerating voltage in kilovolts.
- calib_ang (float):
The calibration in angstroms.
- mesh (Mesh):
The device mesh for sharding.
Returns
- stem_pattern (Float[Array, “P H W”]):
The calculated 4D-STEM pattern.
Flow
Calculate the transmission function once
Shift the beam to all positions
Calculate CBED patterns for all positions
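A single-slice sketch of the Flow above, assuming a thin object (one transmission function) and vectorizing over probe positions with jax.vmap instead of the sharding used by stem_4d; stem_4d_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def stem_4d_sketch(transmission, beam, pos_list):
    """One diffraction pattern per probe position (hypothetical helper).

    transmission: complex (H, W), a single slice (thin-object approximation).
    beam:         complex (H, W, M), probe modes.
    pos_list:     real (P, 2), (y, x) probe positions in pixels.
    Returns real (P, H, W).
    """
    h, w = beam.shape[:2]
    fy = jnp.fft.fftfreq(h)[:, None]  # cycles per pixel
    fx = jnp.fft.fftfreq(w)[None, :]
    beam_k = jnp.fft.fft2(beam, axes=(0, 1))

    def pattern_at(pos):
        # Fourier-shift the probe, transmit through the slice, go to the far field.
        ramp = jnp.exp(-2j * jnp.pi * (fy * pos[0] + fx * pos[1]))
        probe = jnp.fft.ifft2(beam_k * ramp[..., None], axes=(0, 1))
        exit_wave = probe * transmission[..., None]
        far = jnp.fft.fftshift(jnp.fft.fft2(exit_wave, axes=(0, 1)), axes=(0, 1))
        return jnp.sum(jnp.abs(far) ** 2, axis=-1)  # incoherent sum over modes

    return jax.jit(jax.vmap(pattern_at))(pos_list)


h = w = 16
r2 = (jnp.arange(h)[:, None] - 8.0) ** 2 + (jnp.arange(w)[None, :] - 8.0) ** 2
beam = jnp.exp(-r2 / 8.0)[..., None].astype(jnp.complex64)  # one Gaussian mode
trans = jnp.exp(0.3j * jax.random.normal(jax.random.PRNGKey(0), (h, w)))
positions = jnp.array([[0.0, 0.0], [1.0, -2.0], [3.0, 3.0]])
patterns = stem_4d_sketch(trans, beam, positions)
```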
- ptyrodactyl.electrons.forward.stem_4d_mixed_state(pot_slices, modes, mode_weights, pos_list, slice_thickness, voltage_kV, calib_ang)
Calculates the 4D-STEM pattern for multiple probe modes and potential slices, treating the probe as a mixed state with the given mode weights.
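For a mixed state, the recorded intensity is the weighted incoherent sum of the per-mode diffraction intensities, Σ_m w_m |FFT(ψ_m)|². A sketch of that final step only, assuming the exit waves have already been propagated through the slices and rescaling the weights to sum to 1 (the same normalization normalize_mode_weights describes); mixed_state_pattern_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def mixed_state_pattern_sketch(exit_modes, mode_weights):
    """Weighted incoherent sum over modes (hypothetical helper).

    exit_modes:   complex (H, W, M) exit waves, one per mode.
    mode_weights: real (M,), rescaled here to sum to 1.
    """
    weights = mode_weights / jnp.sum(mode_weights)
    far = jnp.fft.fftshift(jnp.fft.fft2(exit_modes, axes=(0, 1)), axes=(0, 1))
    return jnp.einsum("hwm,m->hw", jnp.abs(far) ** 2, weights)


flat_mode = jnp.ones((4, 4), dtype=jnp.complex64)                      # all flux at q = 0
point_mode = jnp.zeros((4, 4), dtype=jnp.complex64).at[0, 0].set(1.0)  # flat spectrum
modes = jnp.stack([flat_mode, point_mode], axis=-1)
pattern = mixed_state_pattern_sketch(modes, jnp.array([3.0, 1.0]))
```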
- ptyrodactyl.electrons.forward.stem_4d_multi(pot_slices, beam, pos_list, slice_thickness, voltage_kV, calib_ang)
Calculates the 4D-STEM pattern for multiple slices.
This function propagates the beam through multiple slices before calculating the final diffraction pattern.
- ptyrodactyl.electrons.forward.transmission_func(pot_slice, voltage_kV)
Description
Calculates the complex transmission function from a single potential slice at a given electron accelerating voltage.
Because this is a JAX function, the input is assumed to be clean, so it does not check for negative or NaN values. Those checks belong in the preprocessing steps, not in the function itself.
Parameters
- pot_slice (Float[Array, “#a #b”]):
Potential slice in Kirkland units
- voltage_kV (scalar_number):
Microscope operating voltage in kilovolts
Returns
- trans (Complex[Array, “#a #b”]):
The transmission function of a single crystal slice
Flow
Calculate the electron energy in electronVolts
Calculate the wavelength in angstroms
Calculate the Einstein energy
Calculate the sigma value, which is the constant for the phase shift
Calculate the transmission function as a complex exponential
- Parameters:
pot_slice (Float[Array, '#a #b'])
voltage_kV (Num[Array, ''])
- Return type:
Complex[Array, '']
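A sketch of the Flow above using Kirkland's interaction constant σ = (2π/(λV))·(m₀c² + eV)/(2m₀c² + eV), which gives σ in rad/(V·Å) when V is in volts and λ in angstroms, so the projected potential is assumed to be in V·Å. Whether transmission_func uses exactly these constants is an assumption, and all names below are hypothetical.

```python
import jax.numpy as jnp

M0C2_EV = 510998.95  # electron rest energy m0*c^2 in eV


def wavelength_ang_sketch(voltage_kv):
    # Relativistic electron wavelength in angstroms; voltage in kV.
    v = voltage_kv * 1e3
    return 12.2643 / jnp.sqrt(v * (1.0 + v / (2.0 * M0C2_EV)))


def interaction_constant(voltage_kv):
    """sigma in rad/(V*angstrom) = 2*pi/(lambda*V) * (m0c^2 + eV) / (2*m0c^2 + eV)."""
    v = voltage_kv * 1e3  # with energies in eV, eV is numerically equal to V
    lam = wavelength_ang_sketch(voltage_kv)
    return (2.0 * jnp.pi / (lam * v)) * (M0C2_EV + v) / (2.0 * M0C2_EV + v)


def transmission_func_sketch(pot_slice, voltage_kv):
    """Phase-object transmission exp(i * sigma * V_z), with V_z in V*angstrom."""
    return jnp.exp(1j * interaction_constant(voltage_kv) * pot_slice)


sigma = interaction_constant(200.0)
trans = transmission_func_sketch(jnp.array([[0.0, 10.0], [50.0, 100.0]]), 200.0)
```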
- ptyrodactyl.electrons.forward.typechecker(obj=None, *, conf=BeartypeConf())
Decorate the passed beartypeable (i.e., pure-Python callable or class) with optimal type-checking dynamically generated unique to that beartypeable under the passed beartype configuration.
This decorator supports two distinct (albeit equally efficient) modes of operation:
- Decoration mode. The caller activates this mode by passing this decorator a type-checkable object via the obj parameter; this decorator then creates and returns a new callable wrapping that object with optimal type-checking. Specifically:
  - If this object is a callable, this decorator creates and returns a new runtime type-checker (i.e., pure-Python function validating all parameters and returns of all calls to that callable against all PEP-compliant type hints annotating those parameters and returns). The type-checker returned by this decorator is:
    - Optimized uniquely for the passed callable.
    - Guaranteed to:
      - Run in O(1) constant-time with negligible constant factors.
      - Type-check effectively instantaneously.
      - Add effectively no runtime overhead to the passed callable.
  - If the passed object is a class, this decorator iteratively applies itself to all annotated methods of this class by dynamically wrapping each such method with a runtime type-checker (as described previously).
- Configuration mode. The caller activates this mode by passing this decorator a beartype configuration via the conf parameter; this decorator then creates and returns a new beartype decorator enabling that configuration. That decorator may then be called (in decoration mode) to create and return a new callable wrapping the passed type-checkable object with optimal type-checking configured by that configuration.
If optimizations are enabled by the active Python interpreter (e.g., due to option -O passed to this interpreter), this decorator silently reduces to a noop.
Parameters
- obj (Optional[BeartypeableT]):
Beartypeable (i.e., pure-Python callable or class) to be decorated. Defaults to None, in which case this decorator is in configuration rather than decoration mode. In configuration mode, this decorator creates and returns an efficiently cached private decorator that generically applies the passed beartype configuration to any beartypeable object passed to that decorator. Look… It just works.
- conf (BeartypeConf, optional):
Beartype configuration (i.e., self-caching dataclass encapsulating all settings configuring type-checking for the passed object). Defaults to BeartypeConf(), the default O(1) constant-time configuration.
Returns
- BeartypeReturn:
Either:
- If in decoration mode (i.e., obj is not None while conf is None) and:
  - If obj is a callable, a new callable wrapping that callable with dynamically generated type-checking.
  - If obj is a class, this existing class embellished with dynamically generated type-checking.
- If in configuration mode (i.e., obj is None while conf is not None), a new beartype decorator enabling this configuration.
- Raises:
BeartypeConfException – If the passed configuration is not actually a configuration (i.e., instance of the BeartypeConf class).
BeartypeDecorHintException – If any annotation on this callable is neither:
- A PEP-compliant type (i.e., instance or class complying with a PEP supported by beartype), nor
- A PEP-noncompliant type (i.e., instance or class complying with beartype-specific semantics rather than a PEP), including:
  - Fully-qualified forward references (i.e., strings specified as fully-qualified classnames).
  - Tuple unions (i.e., tuples containing one or more classes and/or forward references).
BeartypePep563Exception – If PEP 563 is active for this callable and evaluating a postponed annotation (i.e., annotation whose value is a string) on this callable raises an exception (e.g., due to that annotation referring to local state no longer accessible from this deferred evaluation).
BeartypeDecorParamNameException – If the name of any parameter declared on this callable is prefixed by the reserved substring __beartype_.
BeartypeDecorWrappeeException – If this callable is either:
- Uncallable.
- A class, which beartype currently fails to support.
- A C-based callable (e.g., builtin, third-party C extension).
BeartypeDecorWrapperException – If this decorator erroneously generates a syntactically invalid wrapper function. This should never happen, but here we are, so this probably happened. Please submit an upstream issue with our issue tracker if you ever see this. (Thanks and abstruse apologies!)
- Parameters:
obj (BeartypeableT | None)
conf (BeartypeConf)
- Return type:
BeartypeableT | Callable[[BeartypeableT], BeartypeableT]
- ptyrodactyl.electrons.forward.wavelength_ang(voltage_kV)
Description
Calculates the relativistic electron wavelength in angstroms from the microscope accelerating voltage.
Because this is a JAX function, the input is assumed to be clean, so it does not check for negative or NaN values. Those checks belong in the preprocessing steps, not in the function itself.
Parameters
- voltage_kV (num_type | Float[Array, “#a”]):
The microscope accelerating voltage in kilovolts
Returns
- in_angstroms (Float[Array, “*”]):
The electron wavelength in angstroms
Flow
Calculate the electron wavelength in meters
Convert the wavelength to angstroms
- Parameters:
voltage_kV (Num[Array, '#a'])
- Return type:
Float[Array, '#a']
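A sketch of the relativistic wavelength λ = h / sqrt(2m₀eV(1 + eV/(2m₀c²))), converted to angstroms. The physical constants are combined in Python double precision first, so no intermediate product underflows in float32; wavelength_ang_sketch is a hypothetical name and is not guaranteed to match wavelength_ang digit for digit.

```python
import math

import jax.numpy as jnp

# CODATA 2018 constants in SI units.
H_PLANCK = 6.62607015e-34      # J s
M_ELECTRON = 9.1093837015e-31  # kg
E_CHARGE = 1.602176634e-19     # C
C_LIGHT = 299792458.0          # m / s

# Combine the constants in float64 before any float32 arithmetic.
_PREFACTOR_ANG = H_PLANCK / math.sqrt(2.0 * M_ELECTRON * E_CHARGE) * 1e10  # ~12.2643
_REL_CORR = E_CHARGE / (2.0 * M_ELECTRON * C_LIGHT ** 2)                  # ~9.785e-7 per volt


def wavelength_ang_sketch(voltage_kv):
    """Relativistic electron wavelength in angstroms for a voltage in kV."""
    v = jnp.asarray(voltage_kv) * 1e3  # volts
    return _PREFACTOR_ANG / jnp.sqrt(v * (1.0 + _REL_CORR * v))


wavelengths = wavelength_ang_sketch(jnp.array([100.0, 200.0, 300.0]))
```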
- ptyrodactyl.electrons.inverse.initialize_mixed_states(base_state, num_states, energy_spread=0.5, random_seed=42)
Initialize mixed states for partial temporal coherence.
- Args:
base_state: Base state (e.g., probe)
num_states: Number of states in mixture
energy_spread: Energy spread in eV (FWHM)
random_seed: Random seed for initialization
- Returns:
MixedState with initialized states and probabilities
- Parameters:
- Return type:
- ptyrodactyl.electrons.inverse.initialize_probe_modes(base_probe, num_modes, random_seed=42)
Initialize multiple probe modes from a base probe.
- Args:
base_probe: Base probe function
num_modes: Number of modes to generate
random_seed: Random seed for initialization
- Returns:
ProbeState with initialized modes and weights
- Parameters:
- Return type:
- ptyrodactyl.electrons.inverse.multi_mode_ptychography(experimental_4dstem, initial_pot_slices, initial_modes, mode_weights, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam', scheduler_fn=None)
Multi-mode ptychographic reconstruction with optional mixed state.
- Args:
experimental_4dstem: Experimental 4D-STEM data
initial_pot_slices: Initial guess for potential slices
initial_modes: Initial guess for probe modes
mode_weights: Initial weights for each mode
pos_list: List of probe positions
slice_thickness: Thickness of each slice
voltage_kV: Accelerating voltage
calib_ang: Calibration in angstroms
num_iterations: Number of optimization iterations
learning_rate: Initial learning rate
loss_type: Type of loss function
optimizer_name: Name of optimizer to use
scheduler_fn: Optional learning rate scheduler
- Returns:
Tuple of optimized potential slices, probe modes, and mode weights
- Parameters:
experimental_4dstem (Float[Array, 'P H W'])
initial_pot_slices (Complex[Array, 'H W S'])
initial_modes (Complex[Array, 'H W M'])
mode_weights (Float[Array, 'M'])
pos_list (Float[Array, 'P 2'])
slice_thickness (Float[Array, ''])
voltage_kV (Float[Array, ''])
calib_ang (Float[Array, ''])
num_iterations (int)
learning_rate (float)
loss_type (str)
optimizer_name (str)
scheduler_fn (Callable[[LRSchedulerState], tuple[float, LRSchedulerState]] | None)
- Return type:
tuple[Complex[Array, 'H W S'], Complex[Array, 'H W M'], Float[Array, 'M']]
- ptyrodactyl.electrons.inverse.multi_slice_ptychography(experimental_4dstem, initial_pot_slices, initial_beam, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam', scheduler_fn=None)
Multi-slice ptychographic reconstruction.
- Args:
experimental_4dstem: Experimental 4D-STEM data
initial_pot_slices: Initial guess for potential slices
initial_beam: Initial guess for electron beam
pos_list: List of probe positions
slice_thickness: Thickness of each slice
voltage_kV: Accelerating voltage
calib_ang: Calibration in angstroms
num_iterations: Number of optimization iterations
learning_rate: Initial learning rate
loss_type: Type of loss function
optimizer_name: Name of optimizer to use
scheduler_fn: Optional learning rate scheduler
- Returns:
Tuple of optimized potential slices and beam
- Parameters:
experimental_4dstem (Float[Array, 'P H W'])
initial_pot_slices (Complex[Array, 'H W S'])
initial_beam (Complex[Array, 'H W'])
pos_list (Float[Array, 'P 2'])
slice_thickness (Float[Array, ''])
voltage_kV (Float[Array, ''])
calib_ang (Float[Array, ''])
num_iterations (int)
learning_rate (float)
loss_type (str)
optimizer_name (str)
scheduler_fn (Callable[[LRSchedulerState], tuple[float, LRSchedulerState]] | None)
- Return type:
- ptyrodactyl.electrons.inverse.single_slice_poscorrected(experimental_4dstem, initial_pot_slice, initial_beam, initial_pos_list, slice_thickness, voltage_kV, calib_ang, devices, num_iterations=1000, learning_rate=0.001, pos_learning_rate=0.1, loss_type='mse', optimizer_name='adam')
Create and run an optimization routine for 4D-STEM reconstruction with position correction.
Args:
- experimental_4dstem (Float[Array, “P H W”]):
Experimental 4D-STEM data.
- initial_pot_slice (Complex[Array, “H W”]):
Initial guess for potential slice.
- initial_beam (Complex[Array, “H W”]):
Initial guess for electron beam.
- initial_pos_list (Float[Array, “P 2”]):
Initial list of probe positions.
- slice_thickness (Float[Array, “*”]):
Thickness of each slice.
- voltage_kV (Float[Array, “*”]):
Accelerating voltage.
- calib_ang (Float[Array, “*”]):
Calibration in angstroms.
- devices (jax.Array):
Array of devices for sharding.
- num_iterations (int):
Number of optimization iterations.
- learning_rate (float):
Learning rate for potential slice and beam optimization.
- pos_learning_rate (float):
Learning rate for position optimization.
- loss_type (str):
Type of loss function to use.
- optimizer_name (str):
Name of the optimizer to use.
Returns:
- Tuple[Complex[Array, “H W”], Complex[Array, “H W”], Float[Array, “P 2”]]:
Optimized potential slice, beam, and corrected positions.
- Parameters:
experimental_4dstem (Float[Array, 'P H W'])
initial_pot_slice (Complex[Array, 'H W'])
initial_beam (Complex[Array, 'H W'])
initial_pos_list (Float[Array, 'P 2'])
slice_thickness (Float[Array, '*'])
voltage_kV (Float[Array, '*'])
calib_ang (Float[Array, '*'])
devices (Array)
num_iterations (int)
learning_rate (float)
pos_learning_rate (float)
loss_type (str)
optimizer_name (str)
- Return type:
tuple[Complex[Array, 'H W'], Complex[Array, 'H W'], Float[Array, 'P 2']]
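The reconstruction routines in this module minimize a loss between measured and simulated 4D-STEM patterns. The following is a deliberately small sketch of that idea for a single slice with a known probe and fixed positions: plain gradient descent on a sum-of-squares intensity loss, not the named optimizers, loss options, or position correction these routines expose. All function names are hypothetical, and the learning rate is chosen for this toy problem only.

```python
import jax
import jax.numpy as jnp


def phase_ramps(positions, h, w):
    """Fourier phase ramps (P, H, W) that shift a wave by (y, x) pixels."""
    pos = jnp.asarray(positions, dtype=jnp.float32)
    fy = jnp.fft.fftfreq(h)[None, :, None]  # cycles per pixel
    fx = jnp.fft.fftfreq(w)[None, None, :]
    py = pos[:, 0, None, None]
    px = pos[:, 1, None, None]
    return jnp.exp(-2j * jnp.pi * (fy * py + fx * px))


def forward_patterns(obj, probe, ramps):
    """Far-field intensities (P, H, W) of shifted probes times a thin object."""
    shifted = jnp.fft.ifft2(jnp.fft.fft2(probe)[None] * ramps)
    return jnp.abs(jnp.fft.fft2(shifted * obj[None], norm="ortho")) ** 2


def reconstruct_object(data, probe, ramps, num_iterations=100, learning_rate=0.5):
    """Plain gradient descent on a sum-of-squares intensity loss."""

    def loss_fn(obj):
        return jnp.sum((forward_patterns(obj, probe, ramps) - data) ** 2)

    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    obj = jnp.ones(probe.shape, dtype=jnp.complex64)  # flat initial guess
    losses = []
    for _ in range(num_iterations):
        loss, grad = value_and_grad(obj)
        # For a real loss of a complex input, JAX returns the conjugate of the
        # steepest-ascent direction, so descend along -conj(grad).
        obj = obj - learning_rate * jnp.conj(grad)
        losses.append(float(loss))
    return obj, losses


# Toy problem: a Gaussian probe scanned over a weak random phase object.
h = w = 16
yy, xx = jnp.meshgrid(jnp.arange(h) - h / 2, jnp.arange(w) - w / 2, indexing="ij")
probe = jnp.exp(-(yy ** 2 + xx ** 2) / 8.0).astype(jnp.complex64)
probe = probe / jnp.sqrt(jnp.sum(jnp.abs(probe) ** 2))
true_obj = jnp.exp(0.5j * jax.random.normal(jax.random.PRNGKey(0), (h, w)))
ramps = phase_ramps([(-2, -2), (-2, 2), (2, -2), (2, 2)], h, w)
data = forward_patterns(true_obj, probe, ramps)
obj, losses = reconstruct_object(data, probe, ramps)
```

With intensity-only data the object is recovered at best up to the usual ptychographic ambiguities (for example a global phase), so a falling loss is the meaningful check here rather than exact agreement with true_obj.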
- ptyrodactyl.electrons.inverse.single_slice_ptychography(experimental_4dstem, initial_pot_slice, initial_beam, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam')
Create and run an optimization routine for 4D-STEM reconstruction.
Args:
- experimental_4dstem (Float[Array, “P H W”]):
Experimental 4D-STEM data.
- initial_pot_slice (Complex[Array, “H W”]):
Initial guess for potential slice.
- initial_beam (Complex[Array, “H W”]):
Initial guess for electron beam.
- pos_list (Float[Array, “P 2”]):
List of probe positions.
- slice_thickness (Float[Array, “*”]):
Thickness of each slice.
- voltage_kV (Float[Array, “*”]):
Accelerating voltage.
- calib_ang (Float[Array, “*”]):
Calibration in angstroms.
- num_iterations (int):
Number of optimization iterations.
- learning_rate (float):
Learning rate for optimization.
- loss_type (str):
Type of loss function to use.
- optimizer_name (str):
Name of the optimizer to use.
Returns:
- Tuple[Complex[Array, “H W”], Complex[Array, “H W”]]:
Optimized potential slice and beam.
- Parameters:
experimental_4dstem (Float[Array, 'P H W'])
initial_pot_slice (Complex[Array, 'H W'])
initial_beam (Complex[Array, 'H W'])
pos_list (Float[Array, 'P 2'])
slice_thickness (Float[Array, '*'])
voltage_kV (Float[Array, '*'])
calib_ang (Float[Array, '*'])
num_iterations (int)
learning_rate (float)
loss_type (str)
optimizer_name (str)
- Return type:
tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Light Microscopy Module
- class ptyrodactyl.optics.epie.Complex(*args, **kwargs)
Bases:
AbstractDtype
- ptyrodactyl.optics.epie.NamedTuple(typename, fields=None, /, **kwargs)[source]
Typed version of namedtuple.
Usage:
class Employee(NamedTuple):
    name: str
    id: int
This is equivalent to:
Employee = collections.namedtuple('Employee', ['name', 'id'])
The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. (The field names are also in the _fields attribute, which is part of the namedtuple API.) An alternative equivalent functional syntax is also accepted:
Employee = NamedTuple('Employee', [('name', str), ('id', int)])
- ptyrodactyl.optics.epie.jaxtyped(fn=sentinel, *, typechecker=sentinel)[source]
Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.
Example:
```python
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                        ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]

# Type-check a dataclass
from dataclasses import dataclass

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]
```
Arguments:
- fn: The function or dataclass to decorate. In practice if you want to use dataclasses with JAX, then [equinox.Module](https://docs.kidger.site/equinox/api/module/module/) is our recommended approach:
```python
import equinox as eqx

@jaxtyped(typechecker=typechecker)
class MyModule(eqx.Module):
    ...
```
- typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.
```python
@typechecker
def f(x: int):
    pass

f("a string is not an integer")  # this line should raise an exception
```
Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. Can also be set as typechecker=None to skip automatic runtime type-checking, but still support manual isinstance checks inside the function body:
```python
@jaxtyped(typechecker=None)
def f(x):
    assert isinstance(x, Float[Array, "batch channel"])
```
Returns:
If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.
If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.
Old syntax:
jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:
```python
@jaxtyped
@typechecker
def f(...): ...
```
This is still supported, but will now raise a warning recommending the jaxtyped(typechecker=typechecker) syntax discussed above. (The new syntax produces easier-to-debug error messages: under the hood, it more carefully manipulates the typechecker so as to determine where a type-check error arises.)
Notes for advanced users:
Dynamic contexts:
Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns, this inner context is popped from the stack, and the previous context is returned to.
isinstance:
Binding of a value against a name is done with an isinstance check, for example isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the runtime typechecker supplied as the typechecker argument.
This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether, i.e. typechecker=None, and only do your own manual isinstance checks.)
Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).
Decoupling contexts from function calls:
If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.
Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":
```python
with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])
```
This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it's mostly only useful when working at the global scope.
- ptyrodactyl.optics.epie.register_pytree_node_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface.
- Args:
cls: a type to register as a pytree
- Returns:
The input class cls is returned unchanged after being added to JAX's pytree registry. This return value allows register_pytree_node_class to be used as a decorator.
- See also:
register_static(): simpler API for registering a static pytree.
register_dataclass(): simpler API for registering a dataclass.
- Examples:
Here we'll define a custom container that will be compatible with jax.jit() and other JAX transformations:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
- Parameters:
cls (Typ)
- Return type:
Typ
- class ptyrodactyl.optics.lenses.Complex(*args, **kwargs)
Bases:
AbstractDtype
- class ptyrodactyl.optics.lenses.LensParams(focal_length, diameter, n, center_thickness, R1, R2)[source]
Bases:
NamedTuple
Description
PyTree structure for lens parameters
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
R1 (Float[Array, ''])
R2 (Float[Array, ''])
- - `focal_length` (Float[Array, ""])
Focal length of the lens in meters
- - `diameter` (Float[Array, ""])
Diameter of the lens in meters
- - `n` (Float[Array, ""])
Refractive index of the lens material
- - `center_thickness` (Float[Array, ""])
Thickness at the center of the lens in meters
- - `R1` (Float[Array, ""])
Radius of curvature of the first surface in meters (positive for convex)
- - `R2` (Float[Array, ""])
Radius of curvature of the second surface in meters (positive for convex)
Notes
This class is registered as a PyTree node, making it compatible with JAX transformations like jit, grad, and vmap. The auxiliary data in tree_flatten is None as all relevant data is stored in JAX arrays.
- R1: Float[Array, '']
Alias for field number 4
- R2: Float[Array, '']
Alias for field number 5
- center_thickness: Float[Array, '']
Alias for field number 3
- diameter: Float[Array, '']
Alias for field number 1
- focal_length: Float[Array, '']
Alias for field number 0
- n: Float[Array, '']
Alias for field number 2
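Because LensParams is a NamedTuple of scalar arrays, JAX can flatten it into six leaves and return gradients in the same structure, which is what allows lens parameters to be optimized with jax.grad. A minimal sketch of that behaviour, using a hypothetical stand-in class LensParamsSketch that mirrors the documented fields; the objective aperture_area is a toy function chosen only for illustration:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LensParamsSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented LensParams fields (meters)."""
    focal_length: jax.Array
    diameter: jax.Array
    n: jax.Array
    center_thickness: jax.Array
    R1: jax.Array
    R2: jax.Array


def aperture_area(p: LensParamsSketch) -> jax.Array:
    # Toy objective: clear-aperture area pi * (D / 2)^2
    return jnp.pi * (p.diameter / 2) ** 2


params = LensParamsSketch(*(jnp.asarray(v) for v in (0.1, 0.025, 1.5, 0.005, 0.1, 0.1)))
leaves = jax.tree_util.tree_leaves(params)
grads = jax.jit(jax.grad(aperture_area))(params)
```

grads is itself a LensParamsSketch; fields the objective does not use receive a zero gradient.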
- ptyrodactyl.optics.lenses.create_lens_phase(X, Y, params, wavelength)[source]
Description
Create the phase profile and transmission mask for a lens.
Parameters
- X (Float[Array, "H W"]):
X coordinates grid
- Y (Float[Array, "H W"]):
Y coordinates grid
- params (LensParams):
Lens parameters
- wavelength (Float[Array, ""]):
Wavelength of light
Returns
- phase_profile (Float[Array, "H W"]):
Phase profile of the lens
- transmission (Float[Array, "H W"]):
Transmission mask of the lens
Flow
Calculate radial coordinates
Calculate thickness profile
Calculate phase profile
Create transmission mask
Return phase and transmission
- Parameters:
X (Float[Array, 'H W'])
Y (Float[Array, 'H W'])
params (LensParams)
wavelength (Float[Array, ''])
- Return type:
tuple[Float[Array, 'H W'], Float[Array, 'H W']]
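The phase and mask steps of the Flow above can be sketched as follows, assuming the thin-element relation phi = (2*pi/lambda) * (n - 1) * t(r) and a binary circular aperture of radius diameter/2. Both are common textbook choices, not confirmed from the library source; the helper lens_phase_and_mask_sketch is hypothetical and takes a precomputed thickness map instead of a LensParams.

```python
import jax.numpy as jnp


def lens_phase_and_mask_sketch(X, Y, thickness, n, diameter, wavelength):
    """Thin-lens phase k(n - 1)t and binary aperture mask (hypothetical helper)."""
    r = jnp.sqrt(X**2 + Y**2)                              # radial coordinate
    transmission = (r <= diameter / 2).astype(thickness.dtype)
    k = 2 * jnp.pi / wavelength                            # vacuum wavenumber
    phase = k * (n - 1.0) * thickness * transmission       # zero outside the aperture
    return phase, transmission
```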
- ptyrodactyl.optics.lenses.double_concave_lens(focal_length, diameter, n, center_thickness, R_ratio=Array(1., dtype=float64, weak_type=True))[source]
Description
Create parameters for a double concave lens.
Parameters
- focal_length (Float[Array, ""]):
Desired focal length
- diameter (Float[Array, ""]):
Lens diameter
- n (Float[Array, ""]):
Refractive index
- center_thickness (Float[Array, ""]):
Center thickness
- R_ratio (Optional[Float[Array, ""]]):
Ratio R2/R1 (default: 1.0, i.e. a symmetric lens)
Returns
- params (LensParams):
Lens parameters
Flow
Calculate R1 using lensmaker’s equation
Calculate R2 using R_ratio
Create and return LensParams
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
R_ratio (Float[Array, ''] | None)
- Return type:
LensParams
- ptyrodactyl.optics.lenses.double_convex_lens(focal_length, diameter, n, center_thickness, R_ratio=Array(1., dtype=float64, weak_type=True))[source]
Description
Create parameters for a double convex lens.
Parameters
- focal_length (Float[Array, ""]):
Desired focal length
- diameter (Float[Array, ""]):
Lens diameter
- n (Float[Array, ""]):
Refractive index
- center_thickness (Float[Array, ""]):
Center thickness
- R_ratio (Optional[Float[Array, ""]]):
Ratio R2/R1 (default: 1.0, i.e. a symmetric lens)
Returns
- params (LensParams):
Lens parameters
Flow
Calculate R1 using lensmaker’s equation
Calculate R2 using R_ratio
Create and return LensParams
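The two Flow steps (R1 from the lensmaker's equation, then R2 from R_ratio) can be illustrated by inverting the thin-lens form 1/f = (n - 1)(1/R1 + 1/R2) with R2 = R_ratio * R1, which gives R1 = f(n - 1)(1 + 1/R_ratio). This assumes the positive-for-convex convention documented for LensParams and no thick-lens correction; double_convex_radii_sketch is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def double_convex_radii_sketch(focal_length, n, R_ratio=1.0):
    """Radii (R1, R2) of a thin double-convex lens for a target focal length."""
    # Invert 1/f = (n - 1) * (1/R1 + 1/R2) with R2 = R_ratio * R1
    R1 = focal_length * (n - 1.0) * (1.0 + 1.0 / R_ratio)
    R2 = R_ratio * R1
    return R1, R2


R1, R2 = double_convex_radii_sketch(jnp.asarray(0.1), jnp.asarray(1.5))
```

Under the same assumed convention, a negative focal_length yields two negative radii, i.e. a double-concave shape.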
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
R_ratio (Float[Array, ''] | None)
- Return type:
LensParams
- ptyrodactyl.optics.lenses.lens_focal_length(n, R1, R2)[source]
Description
Calculate the focal length of a lens using the lensmaker’s equation.
Parameters
- n (Float[Array, ""]):
Refractive index of the lens material
- R1 (Float[Array, ""]):
Radius of curvature of the first surface (positive for convex)
- R2 (Float[Array, ""]):
Radius of curvature of the second surface (positive for convex)
Returns
- f (Float[Array, ""]):
Focal length of the lens
Flow
Apply the lensmaker’s equation
Return the calculated focal length
- Parameters:
n (Float[Array, ''])
R1 (Float[Array, ''])
R2 (Float[Array, ''])
- Return type:
Float[Array, '']
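In the "positive for convex" convention documented above, the thin-lens form of the lensmaker's equation reads 1/f = (n - 1)(1/R1 + 1/R2), and a flat surface (R = infinity) contributes nothing. The signature takes no thickness, which suggests the thin form, but that is an assumption; lens_focal_length_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def lens_focal_length_sketch(n, R1, R2):
    """Thin-lens focal length; both radii positive for convex surfaces (assumed)."""
    power = (n - 1.0) * (1.0 / R1 + 1.0 / R2)  # optical power in 1/m
    return 1.0 / power


f_plano = lens_focal_length_sketch(1.5, 0.1, jnp.inf)  # flat second surface
```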
- ptyrodactyl.optics.lenses.lens_thickness_profile(r, R1, R2, center_thickness, diameter)[source]
Description
Calculate the thickness profile of a lens.
Parameters
- r (Float[Array, "H W"]):
Radial distance from the optical axis
- R1 (Float[Array, ""]):
Radius of curvature of the first surface
- R2 (Float[Array, ""]):
Radius of curvature of the second surface
- center_thickness (Float[Array, ""]):
Thickness at the center of the lens
- diameter (Float[Array, ""]):
Diameter of the lens
Returns
- thickness (Float[Array, "H W"]):
Thickness profile of the lens
Flow
Calculate surface sag for both surfaces
Combine sags with center thickness
Apply aperture mask
Return thickness profile
- Parameters:
r (Float[Array, 'H W'])
R1 (Float[Array, ''])
R2 (Float[Array, ''])
center_thickness (Float[Array, ''])
diameter (Float[Array, ''])
- Return type:
Float[Array, 'H W']
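One way to realize the Flow above (sag of each surface, combined with the center thickness, then an aperture mask) uses the exact spherical sag written as c*r^2 / (1 + sqrt(1 - c^2 r^2)) with curvature c = 1/R. This form stays finite for a flat surface (R = infinity gives c = 0 and zero sag), which matters for the plano lenses below. The sign convention (convex surfaces positive, removing material away from the axis) and the clamp to non-negative thickness are assumptions; both function names are hypothetical.

```python
import jax.numpy as jnp


def spherical_sag(r, R):
    """Signed sag of a spherical surface; R = inf (flat) gives zero sag."""
    c = 1.0 / R  # curvature; 1/inf == 0
    return c * r**2 / (1.0 + jnp.sqrt(jnp.maximum(1.0 - (c * r) ** 2, 0.0)))


def lens_thickness_profile_sketch(r, R1, R2, center_thickness, diameter):
    # Convex surfaces (R > 0) remove material away from the axis; concave ones add it
    t = center_thickness - spherical_sag(r, R1) - spherical_sag(r, R2)
    # Zero outside the clear aperture, never negative inside it
    return jnp.where(r <= diameter / 2, jnp.maximum(t, 0.0), 0.0)
```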
- ptyrodactyl.optics.lenses.meniscus_lens(focal_length, diameter, n, center_thickness, R_ratio, convex_first=Array(True, dtype=bool))[source]
Description
Create parameters for a meniscus (concavo-convex) lens. For a meniscus lens, one surface is convex (positive R) and one is concave (negative R).
Parameters
- focal_length (Float[Array, ""]):
Desired focal length in meters
- diameter (Float[Array, ""]):
Lens diameter in meters
- n (Float[Array, ""]):
Refractive index of lens material
- center_thickness (Float[Array, ""]):
Center thickness in meters
- R_ratio (Float[Array, ""]):
Absolute ratio of R2/R1
- convex_first (Bool[Array, ""]):
If True, first surface is convex (default: True)
Returns
- params (LensParams):
Lens parameters
Flow
Calculate magnitude of R1 using lensmaker’s equation
Calculate R2 magnitude using R_ratio
Assign correct signs based on convex_first
Create and return LensParams
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
R_ratio (Float[Array, ''])
convex_first (Bool[Array, ''] | None)
- Return type:
LensParams
- ptyrodactyl.optics.lenses.plano_concave_lens(focal_length, diameter, n, center_thickness, concave_first=Array(True, dtype=bool))[source]
Description
Create parameters for a plano-concave lens.
Parameters
- focal_length (Float[Array, ""]):
Desired focal length
- diameter (Float[Array, ""]):
Lens diameter
- n (Float[Array, ""]):
Refractive index
- center_thickness (Float[Array, ""]):
Center thickness
- concave_first (Optional[Bool[Array, ""]]):
If True, first surface is concave (default: True)
Returns
- params (LensParams):
Lens parameters
Flow
Calculate R for curved surface
Set other R to infinity (flat surface)
Create and return LensParams
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
concave_first (Bool[Array, ''] | None)
- Return type:
LensParams
- ptyrodactyl.optics.lenses.plano_convex_lens(focal_length, diameter, n, center_thickness, convex_first=Array(True, dtype=bool))[source]
Description
Create parameters for a plano-convex lens.
Parameters
- focal_length (Float[Array, ""]):
Desired focal length
- diameter (Float[Array, ""]):
Lens diameter
- n (Float[Array, ""]):
Refractive index
- center_thickness (Float[Array, ""]):
Center thickness
- convex_first (Optional[Bool[Array, ""]]):
If True, first surface is convex (default: True)
Returns
- params (LensParams):
Lens parameters
Flow
Calculate R for curved surface
Set other R to infinity (flat surface)
Create and return LensParams
- Parameters:
focal_length (Float[Array, ''])
diameter (Float[Array, ''])
n (Float[Array, ''])
center_thickness (Float[Array, ''])
convex_first (Bool[Array, ''] | None)
- Return type:
LensParams
- ptyrodactyl.optics.lenses.propagate_through_lens(field, phase_profile, transmission)[source]
Description
Propagate a field through a lens.
Parameters
- field (Complex[Array, "H W"]):
Input complex field
- phase_profile (Float[Array, "H W"]):
Phase profile of the lens
- transmission (Float[Array, "H W"]):
Transmission mask of the lens
Returns
- output_field (Complex[Array, "H W"]):
Field after passing through the lens
Flow
Apply transmission mask
Add phase profile
Return modified field
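For a thin element, the Flow (apply the mask, add the phase) amounts to a pointwise product. A sketch assuming the exp(+i*phi) sign convention; some codes use exp(-i*phi), and the library's choice is not stated here. The function name is hypothetical.

```python
import jax.numpy as jnp


def propagate_through_lens_sketch(field, phase_profile, transmission):
    """Apply the aperture mask and the lens phase factor exp(i * phi) pointwise."""
    return field * transmission * jnp.exp(1j * phase_profile)
```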
Tools Module
- class ptyrodactyl.tools.loss_functions.Any(*args, **kwargs)[source]
Bases:
object
Special type indicating an unconstrained type.
Any is compatible with every type.
Any is assumed to have all methods.
All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
- class ptyrodactyl.tools.loss_functions.Complex(*args, **kwargs)
Bases:
AbstractDtype
- class ptyrodactyl.tools.loss_functions.PyTree(*args, **kwargs)
Bases:
object
Represents a PyTree.
Annotations of the following sorts are supported:
```python
a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]
```
These correspond to:
- A plain PyTree can be used as an annotation, in which case PyTree is simply a suggestively-named alternative to Any. ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
- PyTree[LeafType] denotes a PyTree all of whose leaves match LeafType. For example, PyTree[int] or PyTree[Union[str, Float32[Array, "b c"]]].
- A structure name can also be passed. In this case jax.tree_util.tree_structure(...) will be called, and bound to the structure name. This can be used to mark that multiple PyTrees all have the same structure:
```python
def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
    ...
```
Structures are bound to names in the same way as array shape annotations, i.e. within the thread-local dynamic context of a jaxtyping.jaxtyped decorator.
- A composite structure can be declared. In this case the variable must have a PyTree structure equal to the composition of multiple previously-bound PyTree structures. For example:
```python
def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
    ...

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}  # structure is the composition of the structures of y and x
f(x, y, z)
```
When performing runtime type-checking, all the individual pieces must have already been bound to structures, otherwise the composite structure check will throw an error.
- A structure can begin with a ..., to denote that the lower levels of the PyTree must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.
- A structure can end with a ..., to denote that the PyTree must be a prefix of the declared structure, but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.
- ptyrodactyl.tools.loss_functions.create_loss_function(forward_function, experimental_data, loss_type='mae')[source]
Create a JIT-compatible loss function for comparing model output with experimental data.
This function returns a new function that computes the loss between the output of a forward model and experimental data. The returned function is JIT-compatible and can be used with various optimization algorithms.
Args: - forward_function (Callable[..., Array]):
The forward model function (e.g., stem_4d).
- experimental_data (Array):
The experimental data to compare against.
- loss_type (Literal["mae", "mse", "rmse"]):
The type of loss to use. Options are "mae" (Mean Absolute Error), "mse" (Mean Squared Error), or "rmse" (Root Mean Squared Error). Default is "mae".
Loss Functions: - mae_loss:
Mean Absolute Error loss function.
- mse_loss:
Mean Squared Error loss function.
- rmse_loss:
Root Mean Squared Error loss function.
Returns: - loss_fn (Callable[[PyTree, ...], Float[Array, ""]]):
A JIT-compatible function that computes the loss given the model parameters and any additional arguments required by the forward function.
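A closure-based sketch of the pattern described above: the experimental data and the chosen reduction are captured once, and the returned function takes the model parameters plus any extra forward-model arguments. The name create_loss_function_sketch is hypothetical and the body is an assumption, not the library's code.

```python
from typing import Callable, Literal

import jax
import jax.numpy as jnp


def create_loss_function_sketch(
    forward_function: Callable[..., jax.Array],
    experimental_data: jax.Array,
    loss_type: Literal["mae", "mse", "rmse"] = "mae",
) -> Callable[..., jax.Array]:
    def mae_loss(pred):
        return jnp.mean(jnp.abs(pred - experimental_data))

    def mse_loss(pred):
        return jnp.mean((pred - experimental_data) ** 2)

    def rmse_loss(pred):
        return jnp.sqrt(mse_loss(pred))

    # Pick the reduction in Python, before tracing, so loss_fn has no string branch
    reduce = {"mae": mae_loss, "mse": mse_loss, "rmse": rmse_loss}[loss_type]

    @jax.jit
    def loss_fn(params, *args):
        return reduce(forward_function(params, *args))

    return loss_fn
```

Because the returned function is an ordinary JAX function of params, it can be passed straight to jax.grad.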
- class ptyrodactyl.tools.optimizers.Complex(*args, **kwargs)
Bases:
AbstractDtype
- class ptyrodactyl.tools.optimizers.LRSchedulerState(step, learning_rate, initial_lr)[source]
Bases:
NamedTuple
State maintained by learning rate schedulers.
- class ptyrodactyl.tools.optimizers.Optimizer(init, update)[source]
Bases:
NamedTuple
- class ptyrodactyl.tools.optimizers.OptimizerState(m, v, step)[source]
Bases:
NamedTuple
- class ptyrodactyl.tools.optimizers.Sequence
Bases:
Reversible, Collection
All the operations on a read-only sequence.
Concrete subclasses must override __new__ or __init__, __getitem__, and __len__.
- count(value) integer -- return number of occurrences of value
- index(value[, start[, stop]]) integer -- return first index of value.
Raises ValueError if the value is not present.
Supporting start and stop arguments is optional, but recommended.
- ptyrodactyl.tools.optimizers.adagrad_update(params, grads, state, learning_rate=0.01, eps=1e-08)[source]
- Parameters:
params (Complex[Array, '...'])
grads (Complex[Array, '...'])
state (OptimizerState)
learning_rate (float)
eps (float)
- Return type:
tuple[Complex[Array, '...'], OptimizerState]
- ptyrodactyl.tools.optimizers.adam_update(params, grads, state, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08)[source]
- Parameters:
- Return type:
tuple[Complex[Array, '...'], OptimizerState]
- ptyrodactyl.tools.optimizers.complex_adagrad(params, grads, state, learning_rate=0.01, eps=1e-08)[source]
Complex-valued Adagrad optimizer based on Wirtinger derivatives.
This function performs one step of the Adagrad optimization algorithm for complex-valued parameters.
Args: - params (Complex[Array, "..."]):
Current complex-valued parameters.
- grads (Complex[Array, "..."]):
Complex-valued gradients.
- state (Complex[Array, "..."]):
Optimizer state (accumulated squared gradients).
- learning_rate (float):
Learning rate (default: 0.01).
- eps (float):
Small value to avoid division by zero (default: 1e-8).
Returns: - new_params (Complex[Array, "..."]):
Updated parameters.
- new_state (Complex[Array, "..."]):
Updated optimizer state.
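A sketch of one complex Adagrad step, assuming grads holds the Wirtinger gradient dL/dz* so that descent is params - lr * grads. The accumulator is kept real (|g|^2) here, which is an assumption since the docstring above types the state as complex; the function name is hypothetical.

```python
import jax.numpy as jnp


def complex_adagrad_sketch(params, grads, state, learning_rate=0.01, eps=1e-8):
    """One Adagrad step on complex parameters; grads = dL/dz* (assumed)."""
    new_state = state + jnp.abs(grads) ** 2  # real running sum of |g|^2
    new_params = params - learning_rate * grads / (jnp.sqrt(new_state) + eps)
    return new_params, new_state
```

For example, minimizing |z - c|^2 (whose Wirtinger gradient is z - c) drives z toward c.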
- ptyrodactyl.tools.optimizers.complex_adam(params, grads, state, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08)[source]
Complex-valued Adam optimizer based on Wirtinger derivatives.
This function performs one step of the Adam optimization algorithm for complex-valued parameters.
Args: - params (Complex[Array, "..."]):
Current complex-valued parameters.
- grads (Complex[Array, "..."]):
Complex-valued gradients.
- state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):
Optimizer state (first moment, second moment, timestep).
- learning_rate (float):
Learning rate (default: 0.001).
- beta1 (float):
Exponential decay rate for first moment estimates (default: 0.9).
- beta2 (float):
Exponential decay rate for second moment estimates (default: 0.999).
- eps (float):
Small value to avoid division by zero (default: 1e-8).
Returns: - new_params (Complex[Array, "..."]):
Updated parameters.
- new_state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):
Updated optimizer state.
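A sketch of one complex Adam step, again assuming grads = dL/dz*. The first moment is complex, the second moment accumulates the real quantity |g|^2 so that its square root is well defined, and both are bias-corrected as in standard Adam. The state layout (m, v, t) follows the docstring above; the function name and the real-valued second moment are assumptions.

```python
import jax.numpy as jnp


def complex_adam_sketch(params, grads, state, learning_rate=0.001,
                        beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step on complex parameters; grads = dL/dz* (assumed)."""
    m, v, t = state
    t = t + 1
    m = beta1 * m + (1 - beta1) * grads                 # complex first moment
    v = beta2 * v + (1 - beta2) * jnp.abs(grads) ** 2   # real second moment
    m_hat = m / (1 - beta1**t)                          # bias corrections
    v_hat = v / (1 - beta2**t)
    new_params = params - learning_rate * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, (m, v, t)
```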
- ptyrodactyl.tools.optimizers.complex_rmsprop(params, grads, state, learning_rate=0.001, decay_rate=0.9, eps=1e-08)[source]
Complex-valued RMSprop optimizer based on Wirtinger derivatives.
This function performs one step of the RMSprop optimization algorithm for complex-valued parameters.
Args: - params (Complex[Array, "..."]):
Current complex-valued parameters.
- grads (Complex[Array, "..."]):
Complex-valued gradients.
- state (Complex[Array, "..."]):
Optimizer state (moving average of squared gradients).
- learning_rate (float):
Learning rate (default: 0.001).
- decay_rate (float):
Decay rate for moving average (default: 0.9).
- eps (float):
Small value to avoid division by zero (default: 1e-8).
Returns: - new_params (Complex[Array, "..."]):
Updated parameters.
- new_state (Complex[Array, "..."]):
Updated optimizer state.
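RMSprop replaces Adagrad's running sum of |g|^2 with an exponential moving average controlled by decay_rate. A sketch under the same assumptions (grads = dL/dz*, real accumulator, hypothetical function name):

```python
import jax.numpy as jnp


def complex_rmsprop_sketch(params, grads, state, learning_rate=0.001,
                           decay_rate=0.9, eps=1e-8):
    """One RMSprop step on complex parameters; grads = dL/dz* (assumed)."""
    new_state = decay_rate * state + (1 - decay_rate) * jnp.abs(grads) ** 2
    new_params = params - learning_rate * grads / (jnp.sqrt(new_state) + eps)
    return new_params, new_state
```

With a constant learning rate, the iterate typically settles into a small neighbourhood of the minimum whose size scales with learning_rate.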
- ptyrodactyl.tools.optimizers.create_cosine_scheduler(total_steps, final_lr_factor=0.01)[source]
Creates a cosine learning rate scheduler.
- Args:
total_steps: Total number of optimization steps
final_lr_factor: Final learning rate as a fraction of initial
- Parameters:
- Return type:
Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
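A sketch of a scheduler with the documented shape, a function from scheduler state to (learning_rate, new_state). LRSchedulerStateSketch is a hypothetical mirror of LRSchedulerState(step, learning_rate, initial_lr), and the decay lr = initial_lr * (final_lr_factor + (1 - final_lr_factor) * 0.5 * (1 + cos(pi * step / total_steps))) is the usual cosine-annealing form, assumed rather than taken from the library.

```python
import math
from typing import Callable, NamedTuple


class LRSchedulerStateSketch(NamedTuple):
    """Hypothetical mirror of LRSchedulerState(step, learning_rate, initial_lr)."""
    step: int
    learning_rate: float
    initial_lr: float


def create_cosine_scheduler_sketch(
    total_steps: int, final_lr_factor: float = 0.01
) -> Callable[[LRSchedulerStateSketch], tuple[float, LRSchedulerStateSketch]]:
    def scheduler(state: LRSchedulerStateSketch):
        progress = min(state.step / total_steps, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
        lr = state.initial_lr * (final_lr_factor + (1.0 - final_lr_factor) * cosine)
        return lr, LRSchedulerStateSketch(state.step + 1, lr, state.initial_lr)

    return scheduler
```

This version uses plain Python math, so it is meant to run outside jax.jit; a traced step counter would need jnp operations instead.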
- ptyrodactyl.tools.optimizers.create_step_scheduler(step_size, gamma=0.1)[source]
Creates a step decay scheduler that multiplies the learning rate by gamma every step_size steps.
- Args:
step_size: Number of steps between learning rate drops
gamma: Multiplicative factor for learning rate decay
- Parameters:
- Return type:
Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
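Step decay as described corresponds to lr = initial_lr * gamma ** (step // step_size). A sketch with the same (state) -> (lr, new_state) shape, using a hypothetical mirror of LRSchedulerState:

```python
from typing import Callable, NamedTuple


class LRSchedulerStateSketch(NamedTuple):
    """Hypothetical mirror of LRSchedulerState(step, learning_rate, initial_lr)."""
    step: int
    learning_rate: float
    initial_lr: float


def create_step_scheduler_sketch(
    step_size: int, gamma: float = 0.1
) -> Callable[[LRSchedulerStateSketch], tuple[float, LRSchedulerStateSketch]]:
    def scheduler(state: LRSchedulerStateSketch):
        # Number of completed decay periods sets the exponent of gamma
        lr = state.initial_lr * gamma ** (state.step // step_size)
        return lr, LRSchedulerStateSketch(state.step + 1, lr, state.initial_lr)

    return scheduler
```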
- ptyrodactyl.tools.optimizers.create_warmup_cosine_scheduler(total_steps, warmup_steps, final_lr_factor=0.01)[source]
Creates a scheduler with linear warmup followed by cosine decay.
- Args:
total_steps: Total number of optimization steps.
warmup_steps: Number of warmup steps.
final_lr_factor: Final learning rate as a fraction of the initial learning rate.
- Parameters:
- Return type:
Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
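A sketch of the warmup-then-cosine rule as a function of the step index. It assumes a linear ramp that reaches `initial_lr` at the last warmup step (using `step + 1` so step 0 is not zero) and a cosine decay over the remaining `total_steps - warmup_steps` steps; both choices, and the function name, are assumptions rather than the library's documented behaviour.

```python
import math


def warmup_cosine_lr_sketch(step, initial_lr, total_steps, warmup_steps,
                            final_lr_factor=0.01):
    """Linear warmup followed by cosine decay (hypothetical sketch)."""
    if step < warmup_steps:
        # Linear ramp; (step + 1) avoids a zero learning rate at step 0 (assumption).
        return initial_lr * (step + 1) / warmup_steps
    # Cosine decay over the steps that remain after warmup.
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return initial_lr * (final_lr_factor + (1.0 - final_lr_factor) * cosine)
```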
- ptyrodactyl.tools.optimizers.init_scheduler_state(initial_lr)[source]
Initialize the scheduler state with the given initial learning rate.
- Parameters:
initial_lr (float)
- Return type:
- ptyrodactyl.tools.optimizers.rmsprop_update(params, grads, state, learning_rate=0.001, decay_rate=0.9, eps=1e-08)[source]
- Parameters:
- Return type:
tuple[Complex[Array, '...'], OptimizerState]
- ptyrodactyl.tools.optimizers.wirtinger_grad(f, argnums=0)[source]
Compute the Wirtinger gradient of a real-valued function of complex-valued arguments.
This function returns a new function that computes the Wirtinger gradient of the input function f with respect to the specified argument(s).
Args:
- f (Callable[..., Float[Array, "..."]]):
The function to differentiate; it takes complex-valued arguments and returns a real-valued output.
- argnums (Union[int, Sequence[int]]):
The argument(s) to differentiate with respect to, given as an int or a sequence of ints. Default is 0.
Returns:
- grad_f (Callable[..., Union[Complex[Array, "..."], Tuple[Complex[Array, "..."], ...]]]):
A function that computes the Wirtinger gradient of f with respect to the specified argument(s).
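A minimal sketch of how a Wirtinger gradient can be built from ordinary real-valued `jax.grad`: split z = x + iy, differentiate with respect to x and y, and combine them as ∂f/∂z̄ = ½(∂f/∂x + i ∂f/∂y). It handles a single complex argument only (the `argnums=0` case), and returning ∂f/∂z̄ rather than ∂f/∂z is an assumption about the convention `wirtinger_grad` uses; the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def wirtinger_grad_sketch(f):
    """Return z -> df/d(conj z) for a real-valued f of one complex array."""
    def real_view(x, y):
        # Treat f as a function of two real arrays (real and imaginary parts).
        return f(x + 1j * y)

    # Real-valued partial derivatives with respect to x and y.
    partials = jax.grad(real_view, argnums=(0, 1))

    def grad_f(z):
        z = jnp.asarray(z)
        dfdx, dfdy = partials(jnp.real(z), jnp.imag(z))
        # Conjugate Wirtinger derivative: 0.5 * (df/dx + i df/dy).
        return 0.5 * (dfdx + 1j * dfdy)

    return grad_f


# |z|^2 has df/d(conj z) = z, so the gradient at 3+4j is 3+4j.
loss = lambda z: jnp.sum(jnp.real(z * jnp.conj(z)))
print(wirtinger_grad_sketch(loss)(jnp.array([3.0 + 4.0j])))
```

Subtracting a multiple of ∂f/∂z̄ from z is a descent step for a real-valued loss, which is why this convention pairs naturally with the complex optimizers above.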
- class ptyrodactyl.tools.parallel.Mesh(devices, axis_names, *, axis_types=None)[source]
Bases:
ContextDecorator
Declare the hardware resources available in the scope of this manager.
In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
If you are compiling in multiple threads, make sure that the `with Mesh` context manager is inside the function that the threads will execute.
- Args:
devices: A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names: A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
Examples:
>>> import jax
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
- Parameters:
- property abstract_mesh
- property device_ids
- property empty
- property is_multi_process
- property local_devices
- property local_mesh
- property shape
- property shape_tuple
- property size
- class ptyrodactyl.tools.parallel.NamedSharding(*args, **kwargs)
Bases:
Sharding
A NamedSharding expresses sharding using named axes.
A NamedSharding is a pair of a Mesh of devices and a PartitionSpec, which describes how to shard an array across that mesh.
A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.
A PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.
The Distributed arrays and automatic parallelization (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how Mesh and PartitionSpec are used.
- Args:
mesh: A jax.sharding.Mesh object.
spec: A jax.sharding.PartitionSpec object.
Examples:
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property addressable_devices: set[Device]
The set of devices in the Sharding that are addressable by the current process.
- property device_set: set[Device]
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- property mesh
(self) -> object
- property spec
(self) -> object
- with_memory_kind(kind)[source]
Returns a new Sharding instance with the specified memory kind.
- Parameters:
kind (str)
- Return type:
- with_spec(spec)[source]
- Parameters:
spec (PartitionSpec | Sequence[Any])
- Return type:
- class ptyrodactyl.tools.parallel.PartitionSpec(*partitions)[source]
Bases:
tuple
Tuple describing how to partition an array across a mesh of devices.
Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.
This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.
- UNCONSTRAINED = UNCONSTRAINED
- ptyrodactyl.tools.parallel.shard_array(input_array, shard_axes, devices=None)[source]
Shards an array across specified axes and devices.
Args:
- input_array (Array):
The array to be sharded.
- shard_axes (Union[int, Sequence[int]]):
The axis or axes to shard along. Use -1, or a sequence of -1s, to avoid sharding along any axis.
- devices (Sequence[jax.Device], optional):
The devices to shard across. If None, uses all available devices.
Returns:
- sharded_array (Array):
The sharded array.
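A minimal sketch of sharding one axis of an array across devices with the JAX primitives documented in this module (`Mesh`, `NamedSharding`, `PartitionSpec`) plus `jax.device_put`. It is not ptyrodactyl's `shard_array`: it accepts a single integer axis only, uses a one-dimensional mesh whose axis name `"d"` is a hypothetical choice, treats -1 as "replicate" following the description above, and requires the sharded dimension to be divisible by the number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_along_axis_sketch(x, axis, devices=None):
    """Shard `x` along one axis over a 1D device mesh (hypothetical sketch)."""
    devices = jax.devices() if devices is None else list(devices)
    # One-dimensional mesh whose single axis is named "d".
    mesh = Mesh(np.array(devices), ("d",))
    # None for every array dimension means "replicate along this dimension".
    spec = [None] * x.ndim
    if axis != -1:
        # Split the chosen dimension across the "d" mesh axis.
        spec[axis] = "d"
    return jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec)))


n = len(jax.devices())
x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)
y = shard_along_axis_sketch(x, 0)
```

On a single device every sharding is trivially replicated; the split along axis 0 only becomes visible (e.g. via `y.addressable_shards`) when several devices are available.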