API Reference

Electrons Module

class ptyrodactyl.electrons.classes.Complex(*args, **kwargs)

Bases: AbstractDtype

dtypes: Literal[_any_dtype] | list[str | Pattern] = ('complex64', 'complex128')
class ptyrodactyl.electrons.classes.MixedQuantumStates(states, probabilities)

Bases: NamedTuple

Description

PyTree structure for mixed probe quantum states.

Parameters:
  • states (Complex[Array, 'H W N'])

  • probabilities (Float[Array, 'N'])

- `states` (Complex[Array, "H W N"])

The N different quantum states

- `probabilities` (Float[Array, "N"])

Occupation probabilities of the N states

probabilities: Float[Array, 'N']

Alias for field number 1

states: Complex[Array, 'H W N']

Alias for field number 0

tree_flatten()
classmethod tree_unflatten(aux_data, children)
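Because `MixedQuantumStates` is a `NamedTuple`, JAX can treat it as a pytree and pass it straight through transformations such as `jax.jit`. The sketch below does not import ptyrodactyl; it uses a hypothetical stand-in class, `MixedStatesSketch`, with the same two fields, and the `renormalize` helper is likewise illustrative rather than part of the library.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MixedStatesSketch(NamedTuple):
    """Hypothetical stand-in with the same field layout as MixedQuantumStates."""

    states: jax.Array         # complex, shape (H, W, N)
    probabilities: jax.Array  # real, shape (N,)


@jax.jit
def renormalize(mixed: MixedStatesSketch) -> MixedStatesSketch:
    # NamedTuples are pytrees, so jit accepts and returns them directly.
    total = jnp.sum(mixed.probabilities)
    return mixed._replace(probabilities=mixed.probabilities / total)


mixed = MixedStatesSketch(
    states=jnp.ones((4, 4, 3), dtype=jnp.complex64),
    probabilities=jnp.array([2.0, 1.0, 1.0]),
)
normalized = renormalize(mixed)
leaves = jax.tree_util.tree_leaves(normalized)  # [states, probabilities]
```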
class ptyrodactyl.electrons.classes.MixedStateParams(num_modes, mode_weights)

Bases: NamedTuple

Description

PyTree structure for the parameters of a mixed probe state: the number of modes and the weight of each mode.

Parameters:
  • num_modes (Int[Array, ''])

  • mode_weights (Float[Array, 'M'])

- `num_modes` (Int[Array, ""])

Number of modes

- `mode_weights` (Float[Array, "M"])

Weights for each mode

mode_weights: Float[Array, 'M']

Alias for field number 1

num_modes: Int[Array, '']

Alias for field number 0

tree_flatten()
classmethod tree_unflatten(aux_data, children)
ptyrodactyl.electrons.classes.NamedTuple(typename, fields=None, /, **kwargs)

Typed version of namedtuple.

Usage:

class Employee(NamedTuple):
    name: str
    id: int

This is equivalent to:

Employee = collections.namedtuple('Employee', ['name', 'id'])

The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. (The field names are also in the _fields attribute, which is part of the namedtuple API.) An alternative equivalent functional syntax is also accepted:

Employee = NamedTuple('Employee', [('name', str), ('id', int)])
class ptyrodactyl.electrons.classes.ProbeState(modes, weights)

Bases: NamedTuple

Description

PyTree structure for an electron probe state.

Parameters:
  • modes (Complex[Array, 'H W M'])

  • weights (Float[Array, 'M'])

- `modes` (Complex[Array, "H W M"])

Probe modes; M is the number of modes

- `weights` (Float[Array, "M"])

Mode occupation numbers

modes: Complex[Array, 'H W M']

Alias for field number 0

tree_flatten()
classmethod tree_unflatten(aux_data, children)
weights: Float[Array, 'M']

Alias for field number 1

ptyrodactyl.electrons.classes.register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface.

Args:

cls: a type to register as a pytree

Returns:

The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_node_class to be used as a decorator.

See also:
Examples:

Here we’ll define a custom container that will be compatible with jax.jit() and other JAX transformations:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)

Parameters:

cls (Typ)

Return type:

Typ

ptyrodactyl.electrons.forward.Tuple

alias of tuple

ptyrodactyl.electrons.forward.aberration(fourier_coord, lambda_angstrom, defocus=Array(0., dtype=float64, weak_type=True), c3=Array(0., dtype=float64, weak_type=True), c5=Array(0., dtype=float64, weak_type=True))

Description

Calculates the aberration function of the electron probe from the Fourier coordinates.

Parameters

  • fourier_coord (Float[Array, "H W"]):

    The Fourier coordinates

  • lambda_angstrom (Num[Array, ""]):

    The electron wavelength in angstroms

  • defocus (Float[Array, ""]):

    The defocus value in angstroms. Default is 0.0

  • c3 (Float[Array, ""]):

    The third-order spherical aberration coefficient (C3) in angstroms. Default is 0.0

  • c5 (Float[Array, ""]):

    The fifth-order spherical aberration coefficient (C5) in angstroms. Default is 0.0

Returns

  • chi_probe (Float[Array, "H W"]):

    The calculated aberration function

Flow

  • Calculate the phase shift

  • Calculate the chi value

  • Calculate the chi probe value

Parameters:
  • fourier_coord (Float[Array, 'H W'])

  • lambda_angstrom (Num[Array, ''])

  • defocus (Float[Array, ''] | None)

  • c3 (Float[Array, ''] | None)

  • c5 (Float[Array, ''] | None)

Return type:

Float[Array, 'H W']
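The flow above can be sketched with the usual axially symmetric aberration function, written here in Kirkland's convention: χ(k) = πλk²(−Δf + ½C3λ²k² + ⅓C5λ⁴k⁴). This is an illustrative assumption, not ptyrodactyl's implementation; in particular, the sign convention for defocus varies between references, and the hypothetical `aberration_sketch` takes the spatial-frequency magnitude k (in 1/Å) directly rather than the library's `fourier_coord` input.

```python
import jax.numpy as jnp


def aberration_sketch(k, wavelength, defocus=0.0, c3=0.0, c5=0.0):
    """Aberration phase chi(k) in radians (sketch, Kirkland-style signs assumed).

    k: spatial-frequency magnitude in 1/angstrom (any shape)
    wavelength, defocus, c3, c5: all in angstroms
    """
    lam_k = wavelength * k  # scattering angle in radians (small-angle approximation)
    return jnp.pi * wavelength * k**2 * (
        -defocus + 0.5 * c3 * lam_k**2 + (1.0 / 3.0) * c5 * lam_k**4
    )


k = jnp.array([0.0, 0.5, 1.0])
chi_defocus = aberration_sketch(k, wavelength=0.0197, defocus=100.0)
chi_c3 = aberration_sketch(k, wavelength=0.0197, c3=1.0e7)  # C3 = 1 mm = 1e7 angstrom
```

The function is elementwise, so passing an H×W grid of |k| values returns an H×W phase map.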

ptyrodactyl.electrons.forward.cbed(pot_slice, beam, slice_thickness, voltage_kV, calib_ang)

Description

Computes the convergent beam electron diffraction (CBED) pattern for single or multiple slices and single or multiple beam modes, by propagating one or more beam modes through one or more potential slices.

Parameters

  • pot_slice (Complex[Array, "H W #S"]):

    The potential slice(s). H and W are the height and width; S is the number of slices (optional).

  • beam (Complex[Array, "H W #M"]):

    The electron beam mode(s). M is the number of modes (optional).

  • slice_thickness (Float[Array, ""]):

    The thickness of each slice in angstroms.

  • voltage_kV (Float[Array, "#v"]):

    The accelerating voltage(s) in kilovolts.

  • calib_ang (Float[Array, ""]):

    The real-space pixel size (calibration) in angstroms.

Returns

  • cbed_pattern (Float[Array, "H W"]):

    The calculated CBED pattern.

Flow

  • Ensure 3D arrays even for single slice/mode

  • Calculate the transmission function for a single slice

  • Initialize the convolution state

  • Scan over all slices

  • Compute the Fourier transform

  • Compute the intensity for each mode

  • Sum the intensities across all modes

Parameters:
  • pot_slice (Complex[Array, 'H W #S'])

  • beam (Complex[Array, 'H W #M'])

  • slice_thickness (Float[Array, ''])

  • voltage_kV (Float[Array, '#v'])

  • calib_ang (Float[Array, ''])

Return type:

Float[Array, 'H W']
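A compact sketch of the multislice loop in the flow above. It assumes each slice is applied as "transmit, then propagate" and that the per-slice transmission functions and the Fresnel propagator have already been computed. `multislice_cbed_sketch` is hypothetical, and the library's exact slice ordering and normalization are not confirmed by this reference.

```python
import jax
import jax.numpy as jnp


def multislice_cbed_sketch(transmissions, propagator, modes):
    """Incoherent CBED intensity from a multislice calculation (sketch).

    transmissions: complex (H, W, S), one transmission function per slice
    propagator:    complex (H, W), Fresnel propagator in unshifted FFT order
    modes:         complex (H, W, M), probe modes at the entrance surface
    Returns a real (H, W) pattern with the zero frequency at the centre.
    """

    def step(wave, trans):
        wave = wave * trans[..., None]  # transmission, broadcast over modes
        wave_k = jnp.fft.fft2(wave, axes=(0, 1)) * propagator[..., None]
        return jnp.fft.ifft2(wave_k, axes=(0, 1)), None  # propagate to next slice

    # lax.scan iterates over the leading axis, so move the slice axis to the front
    exit_wave, _ = jax.lax.scan(step, modes, jnp.moveaxis(transmissions, -1, 0))
    far_field = jnp.fft.fftshift(jnp.fft.fft2(exit_wave, axes=(0, 1)), axes=(0, 1))
    return jnp.sum(jnp.abs(far_field) ** 2, axis=-1)  # sum intensities over modes


h = w = 8
plane_wave = jnp.ones((h, w, 1), dtype=jnp.complex64)
empty_slices = jnp.ones((h, w, 3), dtype=jnp.complex64)
free_space = jnp.ones((h, w), dtype=jnp.complex64)
pattern = multislice_cbed_sketch(empty_slices, free_space, plane_wave)
```

With empty slices and a plane wave, all intensity lands in the central (zero-frequency) pixel, which is a quick sanity check of the FFT bookkeeping.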

ptyrodactyl.electrons.forward.cbed_no_slice(pot_slice, beam, slice_transmission)

Description

Calculates the CBED pattern for single or multiple slices and single or multiple beam modes.

This function computes the convergent beam electron diffraction (CBED) pattern by propagating one or more beam modes through one or more potential slices. This version takes a pre-calculated transmission function for going from one slice to the next, which is useful when the CBED pattern is calculated many times with the same transmission function, for example in 4D-STEM.

Parameters

  • pot_slice (Complex[Array, "H W *S"]):

    The potential slice(s). H and W are height and width, S is the number of slices (optional).

  • beam (Complex[Array, "H W *M"]):

    The electron beam mode(s). M is the number of modes (optional).

  • slice_transmission (Complex[Array, "*"]):

    The pre-calculated transmission function for going from one slice to the next.

Returns

  • cbed_pattern (Float[Array, "H W"]):

    The calculated CBED pattern.

Flow

  • Ensure 3D arrays even for single slice/mode

  • Initialize the convolution state

  • Scan over all slices

  • Compute the Fourier transform

  • Compute the intensity for each mode

  • Sum the intensities across all modes

Parameters:
  • pot_slice (Complex[Array, 'H W *S'])

  • beam (Complex[Array, 'H W *M'])

  • slice_transmission (Complex[Array, '*'])

Return type:

Float[Array, 'H W']

ptyrodactyl.electrons.forward.fourier_calib(real_space_calib, sizebeam)

Description

Generates the Fourier-space calibration (reciprocal-space pixel size) for the beam.

Parameters

  • real_space_calib (float | Float[Array, "*"]):

    The real-space pixel size in angstroms

  • sizebeam (Int[Array, "2"]):

    The size of the beam in pixels

Returns

  • inverse_space_calib (Float[Array, "2"]):

    The Fourier-space calibration in inverse angstroms (1/Å)

Flow

  • Calculate the field of view in real space

  • Calculate the inverse space calibration

Parameters:
  • real_space_calib (float | Float[Array, '*'])

  • sizebeam (Int[Array, '2'])

Return type:

Float[Array, '2']
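Numerically, the Fourier calibration along each axis is the reciprocal of the real-space field of view (pixel count times pixel size). A minimal sketch of that relation; `fourier_calib_sketch` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def fourier_calib_sketch(real_space_calib, sizebeam):
    """Reciprocal-space pixel size in 1/angstrom for each axis (sketch)."""
    field_of_view = jnp.asarray(sizebeam, dtype=jnp.float32) * real_space_calib  # angstroms
    return 1.0 / field_of_view


calib_k = fourier_calib_sketch(0.1, jnp.array([256, 128]))
```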

ptyrodactyl.electrons.forward.fourier_coords(calibration, image_size)

Description

Returns the Fourier-space coordinates for an image of the given size and calibration.

Parameters

  • calibration (float):

    The real-space pixel size in angstroms

  • image_size (Int[Array, "2"]):

    The size of the image in pixels

Returns

  • A NamedTuple with the following fields:
    • array (Any[Array, "* *"]):

      The array values

    • calib_y (float):

      Calibration along the first axis

    • calib_x (float):

      Calibration along the second axis

Flow

  • Calculate the real space field of view in y and x

  • Generate the inverse space array y and x

  • Shift the inverse space array y and x

  • Create meshgrid of shifted inverse space arrays

  • Calculate the inverse array

  • Calculate the calibration in y and x

  • Return the calibrated array

Parameters:
  • calibration (float)

  • image_size (Int[Array, '2'])

Return type:

NamedTuple
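A sketch of the flow above using `jnp.fft.fftfreq`, assuming that the "inverse array" is the spatial-frequency magnitude |k| on a centred grid. The `FourierCoordsSketch` tuple and `fourier_coords_sketch` are hypothetical; only the field names mirror the documented return value.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FourierCoordsSketch(NamedTuple):
    array: jax.Array  # |k| in 1/angstrom, zero frequency at the centre
    calib_y: float    # reciprocal-space pixel size along axis 0
    calib_x: float    # reciprocal-space pixel size along axis 1


def fourier_coords_sketch(calibration, image_size):
    ny, nx = int(image_size[0]), int(image_size[1])
    # fftfreq with d=calibration gives frequencies in 1/angstrom; fftshift centres them
    ky = jnp.fft.fftshift(jnp.fft.fftfreq(ny, d=calibration))
    kx = jnp.fft.fftshift(jnp.fft.fftfreq(nx, d=calibration))
    ky_grid, kx_grid = jnp.meshgrid(ky, kx, indexing="ij")
    k_mag = jnp.sqrt(ky_grid**2 + kx_grid**2)
    # the frequency spacing along each axis is 1 / field of view
    return FourierCoordsSketch(k_mag, 1.0 / (ny * calibration), 1.0 / (nx * calibration))


coords = fourier_coords_sketch(0.2, (64, 32))
```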

ptyrodactyl.electrons.forward.initialize_random_modes(shape, num_modes, dtype=<class 'jax.numpy.complex128'>)

Initialize random orthogonal modes.

Parameters:
Return type:

Complex[Array, 'H W M']
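One common way to obtain random orthogonal modes is a QR decomposition of a random complex matrix whose columns are the flattened modes. The sketch below assumes that approach (the library may use a different one) and takes an explicit PRNG key; `random_orthogonal_modes_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def random_orthogonal_modes_sketch(key, shape, num_modes):
    """Return (H, W, M) complex modes that are orthonormal when flattened (sketch)."""
    h, w = shape
    key_re, key_im = jax.random.split(key)
    a = jax.random.normal(key_re, (h * w, num_modes)) + 1j * jax.random.normal(
        key_im, (h * w, num_modes)
    )
    q, _ = jnp.linalg.qr(a)  # reduced QR: q has orthonormal columns, shape (h*w, M)
    return q.reshape(h, w, num_modes)


modes = random_orthogonal_modes_sketch(jax.random.PRNGKey(0), (16, 16), 3)
flat = modes.reshape(-1, 3)
gram = flat.conj().T @ flat  # should be close to the 3x3 identity
```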

ptyrodactyl.electrons.forward.jaxtyped(fn=sentinel, *, typechecker=sentinel)

Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.

Example:

```python
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                      ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]

# Type-check a dataclass
from dataclasses import dataclass

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]
```

Arguments:

  • fn: The function or dataclass to decorate. In practice if you want to use dataclasses with JAX, then [equinox.Module](https://docs.kidger.site/equinox/api/module/module/) is our recommended approach:

    ```python
    import equinox as eqx

    @jaxtyped(typechecker=typechecker)
    class MyModule(eqx.Module):
        ...
    ```

  • typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.

    ```python
    @typechecker
    def f(x: int):
        pass

    f("a string is not an integer")  # this line should raise an exception
    ```

    Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. Can also be set as typechecker=None to skip automatic runtime type-checking, but still support manual isinstance checks inside the function body:

    ```python
    @jaxtyped(typechecker=None)
    def f(x):
        assert isinstance(x, Float[Array, "batch channel"])
    ```

Returns:

If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.

If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.

Old syntax:

jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:

```python
@jaxtyped
@typechecker
def f(x):
    ...
```

This is still supported, but will now raise a warning recommending the jaxtyped(typechecker=typechecker) syntax discussed above. (The new syntax produces easier-to-debug error messages: under the hood, it more carefully manipulates the typechecker so as to determine where a type-check error arises.)

Notes for advanced users:

Dynamic contexts:

Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns then this inner context is popped from the stack, and the previous context is returned to.

isinstance:

Binding of a value against a name is done with an isinstance check, for example isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the run-time typechecker typechecker that is supplied as an argument.

This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether – i.e. typechecker=None – and only do your own manual isinstance checks.)

Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).

Decoupling contexts from function calls:

If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.

Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":

```python
with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])
```

This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it's mostly only useful when working at the global scope.

ptyrodactyl.electrons.forward.make_probe(aperture, voltage, image_size, calibration_pm, defocus=0, c3=0, c5=0)

Description

Calculates an electron probe from the aperture size and the estimated Fourier coordinates, optionally including aberrations: defocus and the third- and fifth-order spherical aberration coefficients C3 and C5.

Parameters

  • aperture (Union[float, int]):

    The aperture size in milliradians

  • voltage (Union[float, int]):

    The microscope accelerating voltage in kilovolts

  • image_size (Int[Array, "2"]):

    The size of the beam in pixels

  • calibration_pm (float):

    The real-space pixel size in picometers

  • defocus (float):

    The defocus value in angstroms. Default is 0

  • c3 (float):

    The third-order spherical aberration coefficient (C3) in angstroms. Default is 0

  • c5 (float):

    The fifth-order spherical aberration coefficient (C5) in angstroms. Default is 0

Returns

  • probe_real_space (Complex[Array, "H W"]):

    The calculated electron probe in real space

Flow

  • Convert the aperture to radians

  • Calculate the wavelength in angstroms

  • Calculate the maximum L value

  • Calculate the field of view in x and y

  • Generate the inverse space array y and x

  • Shift the inverse space array y and x

  • Create meshgrid of shifted inverse space arrays

  • Calculate the inverse array

  • Calculate the calibration in y and x

  • Calculate the probe in real space

Parameters:
Return type:

Complex[Array, 'H W']
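A simplified sketch of the probe calculation: an aperture that passes spatial frequencies up to α/λ, an aberration phase (defocus only here), and an inverse FFT back to real space. It assumes the pixel size is given in angstroms (divide `calibration_pm` by 100 to convert), takes the wavelength directly instead of the voltage, and normalizes the probe to unit total intensity. None of these choices is confirmed as ptyrodactyl's behaviour, and `make_probe_sketch` is hypothetical.

```python
import jax.numpy as jnp


def make_probe_sketch(aperture_mrad, wavelength_ang, image_size, calib_ang, defocus=0.0):
    ny, nx = image_size
    ky = jnp.fft.fftfreq(ny, d=calib_ang)  # 1/angstrom, unshifted FFT order
    kx = jnp.fft.fftfreq(nx, d=calib_ang)
    ky_grid, kx_grid = jnp.meshgrid(ky, kx, indexing="ij")
    k = jnp.sqrt(ky_grid**2 + kx_grid**2)
    k_max = aperture_mrad * 1e-3 / wavelength_ang  # aperture cutoff in 1/angstrom
    aperture = (k <= k_max).astype(jnp.complex64)
    chi = -jnp.pi * wavelength_ang * defocus * k**2  # defocus term only (sign convention assumed)
    probe = jnp.fft.fftshift(jnp.fft.ifft2(aperture * jnp.exp(-1j * chi)))  # centre the probe
    return probe / jnp.sqrt(jnp.sum(jnp.abs(probe) ** 2))


probe = make_probe_sketch(20.0, 0.0197, (64, 64), 0.1)
```

For zero defocus the aperture is real and non-negative, so the probe amplitude peaks exactly at the centre of the frame.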

ptyrodactyl.electrons.forward.normalize_mode_weights(weights)

Normalize mode weights to sum to 1.

Parameters:

weights (Float[Array, 'M'])

Return type:

Float[Array, 'M']

class ptyrodactyl.electrons.forward.partial

Bases: object

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

args

tuple of arguments to future partial calls

func

function object to use in future partial calls

keywords

dictionary of keyword arguments to future partial calls

ptyrodactyl.electrons.forward.propagation_func(imsize_y, imsize_x, thickness_ang, voltage_kV, calib_ang)

Description

Calculates the complex propagation function that applies the phase shift acquired by the exit wave as it travels from one slice to the next in the multislice algorithm.

Parameters

  • imsize_y (int):

    Size of the propagator along the y axis, in pixels

  • imsize_x (int):

    Size of the propagator along the x axis, in pixels

  • thickness_ang (scalar_number):

    Distance between the slices in angstroms

  • voltage_kV (scalar_number):

    Accelerating voltage in kilovolts

  • calib_ang (scalar_number):

    Calibration (pixel size) in angstroms

Returns

  • prop (Complex[Array, "H W"]):

    The propagation function, of shape (imsize_y, imsize_x)

Flow

  • Generate frequency arrays directly using fftfreq

  • Create 2D meshgrid of frequencies

  • Calculate squared sum of frequencies

  • Calculate wavelength

  • Compute the propagation function

Parameters:
  • imsize_y (int)

  • imsize_x (int)

  • thickness_ang (Num[Array, ''])

  • voltage_kV (Num[Array, ''])

  • calib_ang (Float[Array, ''])

Return type:

Complex[Array, 'H W']
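For reference, the Fresnel propagator commonly used in multislice codes (Kirkland's convention) is P(k) = exp(−iπλΔz k²). The sketch below builds it on an unshifted `fftfreq` grid, matching the flow above; it takes the wavelength in angstroms directly rather than `voltage_kV`, and `propagation_func_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def propagation_func_sketch(imsize_y, imsize_x, thickness_ang, wavelength_ang, calib_ang):
    """Fresnel propagator exp(-i*pi*lambda*dz*k^2) on an unshifted FFT grid (sketch)."""
    ky = jnp.fft.fftfreq(imsize_y, d=calib_ang)  # 1/angstrom
    kx = jnp.fft.fftfreq(imsize_x, d=calib_ang)
    ky_grid, kx_grid = jnp.meshgrid(ky, kx, indexing="ij")
    k_squared = ky_grid**2 + kx_grid**2
    return jnp.exp(-1j * jnp.pi * wavelength_ang * thickness_ang * k_squared)


prop = propagation_func_sketch(32, 48, 2.0, 0.0197, 0.1)
```

Because the propagator acts in Fourier space, a wave is advanced by one slice as `jnp.fft.ifft2(jnp.fft.fft2(wave) * prop)`.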

ptyrodactyl.electrons.forward.shift_beam_fourier(beam, pos, calib_ang)

Description

Shifts the beam to new position(s) using Fourier shifting.

Parameters

  • beam (Complex[Array, "H W M"]):

    The electron beam modes.

  • pos (Float[Array, "... 2"]):

    The (y, x) position(s) to shift to, in pixels. Can be a single position [2] or multiple positions [..., 2].

  • calib_ang (Float[Array, "*"]):

    The calibration in angstroms.

Returns

  • shifted_beams (Complex[Array, "... H W M"]):

    The shifted beam(s) for all position(s) and mode(s).

Flow

  • Convert positions from real space to Fourier space

  • Create phase ramps in Fourier space for all positions

  • Apply shifts to each mode for all positions

Parameters:
  • beam (Complex[Array, 'H W M'])

  • pos (Float[Array, '... 2'])

  • calib_ang (Float[Array, '*'])

Return type:

Complex[Array, '... H W M']
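The flow relies on the Fourier shift theorem: translating a wave by (dy, dx) pixels multiplies its Fourier transform by exp(−2πi(k_y·dy + k_x·dx)), with k in cycles per pixel. The sketch below works in pixels only (it omits `calib_ang`, so how the library uses that calibration is not reproduced) and broadcasts over any leading position axes; `shift_beam_fourier_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def shift_beam_fourier_sketch(beam, pos):
    """Shift (H, W, M) beam modes by (y, x) pixel offsets of shape (..., 2) (sketch)."""
    h, w = beam.shape[:2]
    ky = jnp.fft.fftfreq(h)[:, None]  # cycles per pixel, shape (H, 1)
    kx = jnp.fft.fftfreq(w)[None, :]  # shape (1, W)
    pos = jnp.asarray(pos, dtype=jnp.float32)
    dy = pos[..., 0][..., None, None]  # (..., 1, 1)
    dx = pos[..., 1][..., None, None]
    ramp = jnp.exp(-2j * jnp.pi * (ky * dy + kx * dx))  # (..., H, W) phase ramp
    beam_k = jnp.fft.fft2(beam, axes=(0, 1))
    return jnp.fft.ifft2(beam_k * ramp[..., None], axes=(-3, -2))  # (..., H, W, M)


key_re, key_im = jax.random.split(jax.random.PRNGKey(1))
beam = jax.random.normal(key_re, (8, 8, 2)) + 1j * jax.random.normal(key_im, (8, 8, 2))
single = shift_beam_fourier_sketch(beam, [2.0, 3.0])
batch = shift_beam_fourier_sketch(beam, [[2.0, 3.0], [0.0, -1.0]])
```

For integer offsets the result matches a circular `jnp.roll`; fractional offsets give sub-pixel shifts with the same code.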

ptyrodactyl.electrons.forward.stem_4d(pot_slice, beam, pos_list, slice_thickness, voltage_kV, calib_ang)

Description

Calculates the 4D-STEM pattern for multiple probe positions with sharding.

Parameters

  • pot_slice (Complex[Array, "H W S"]):

    The potential slices.

  • beam (Complex[Array, "H W M"]):

    The electron beam modes.

  • pos_list (Float[Array, "P 2"]):

    List of (y, x) probe positions in pixels.

  • slice_thickness (float):

    The thickness of each slice in angstroms.

  • voltage_kV (float):

    The accelerating voltage in kilovolts.

  • calib_ang (float):

    The calibration in angstroms.

Returns

  • stem_pattern (Float[Array, "P H W"]):

    The calculated 4D-STEM pattern.

Flow

  • Calculate the transmission function once

  • Shift the beam to all positions

  • Calculate CBED patterns for all positions

Parameters:
  • pot_slice (Complex[Array, 'H W S'])

  • beam (Complex[Array, 'H W M'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (float)

  • voltage_kV (float)

  • calib_ang (float)

Return type:

Float[Array, 'P H W']
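The key idea in the flow is to compute everything that does not depend on the probe position once, then map the per-position calculation over `pos_list`. The sketch below shows that pattern with `jax.vmap` for a single slice and a single mode; it is a simplification (no propagation, no sharding), and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def stem_4d_sketch(transmission, probe, positions):
    """(P, H, W) diffraction patterns for a single-slice specimen (sketch)."""
    h, w = probe.shape
    ky = jnp.fft.fftfreq(h)[:, None]
    kx = jnp.fft.fftfreq(w)[None, :]
    probe_k = jnp.fft.fft2(probe)  # position-independent, computed once

    def one_position(pos):
        ramp = jnp.exp(-2j * jnp.pi * (ky * pos[0] + kx * pos[1]))  # Fourier shift
        shifted_probe = jnp.fft.ifft2(probe_k * ramp)
        exit_wave = shifted_probe * transmission
        return jnp.abs(jnp.fft.fftshift(jnp.fft.fft2(exit_wave))) ** 2

    return jax.vmap(one_position)(positions)


key_re, key_im = jax.random.split(jax.random.PRNGKey(2))
probe = jax.random.normal(key_re, (8, 8)) + 1j * jax.random.normal(key_im, (8, 8))
vacuum = jnp.ones((8, 8), dtype=jnp.complex64)
positions = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 5.0]])
patterns = stem_4d_sketch(vacuum, probe, positions)
```

In vacuum every position gives the same pattern, because a shift only changes the phase of the probe in Fourier space.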

ptyrodactyl.electrons.forward.stem_4d_mixed_state(pot_slices, modes, mode_weights, pos_list, slice_thickness, voltage_kV, calib_ang)

Calculates the 4D-STEM pattern for multiple probe modes and multiple slices, treating the probe as a mixed state.

Parameters:
  • pot_slices (Complex[Array, 'H W S'])

  • modes (Complex[Array, 'H W M'])

  • mode_weights (Float[Array, 'M'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, ''])

  • voltage_kV (Float[Array, ''])

  • calib_ang (Float[Array, ''])

Return type:

Float[Array, 'P H W']
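In a mixed-state description the modes add incoherently: the recorded intensity is the weighted sum of each mode's diffraction intensity, I = Σₘ wₘ |F[ψₘ]|². A minimal sketch of that final step, assuming the exit-wave modes are already available; `mixed_state_pattern_sketch` is hypothetical, and how ptyrodactyl normalizes the weights is not specified here.

```python
import jax.numpy as jnp


def mixed_state_pattern_sketch(exit_modes, mode_weights):
    """Weighted incoherent sum over modes: (H, W, M), (M,) -> (H, W) (sketch)."""
    far_field = jnp.fft.fftshift(jnp.fft.fft2(exit_modes, axes=(0, 1)), axes=(0, 1))
    return jnp.einsum("m,hwm->hw", mode_weights, jnp.abs(far_field) ** 2)


mode = jnp.exp(1j * jnp.linspace(0.0, 3.0, 64)).reshape(8, 8)
two_copies = jnp.stack([mode, mode], axis=-1)
mixed = mixed_state_pattern_sketch(two_copies, jnp.array([0.25, 0.75]))
single = mixed_state_pattern_sketch(mode[..., None], jnp.array([1.0]))
```

Two identical modes with weights summing to 1 reproduce the single-mode pattern, as expected for an incoherent sum.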

ptyrodactyl.electrons.forward.stem_4d_multi(pot_slices, beam, pos_list, slice_thickness, voltage_kV, calib_ang)

Calculates the 4D-STEM pattern for multiple slices.

This function propagates the beam through multiple slices before calculating the final diffraction pattern.

Parameters:
  • pot_slices (Complex[Array, 'H W S'])

  • beam (Complex[Array, 'H W M'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, ''])

  • voltage_kV (Float[Array, ''])

  • calib_ang (Float[Array, ''])

Return type:

Float[Array, 'P H W']

ptyrodactyl.electrons.forward.transmission_func(pot_slice, voltage_kV)

Description

Calculates the complex transmission function from a single potential slice at a given electron accelerating voltage.

Because this function is written for JAX, it assumes clean input and does not check for negative or NaN values; those checks belong in your preprocessing steps, not in the function itself.

Parameters

  • pot_slice (Float[Array, "#a #b"]):

    Potential slice in Kirkland units

  • voltage_kV (scalar_number):

    Microscope accelerating voltage in kilovolts

Returns

  • trans (Complex[Array, "#a #b"]):

    The transmission function of a single crystal slice

Flow

  • Calculate the electron energy in electronvolts

  • Calculate the wavelength in angstroms

  • Calculate the Einstein energy (the electron rest energy m0c²)

  • Calculate the sigma value, which is the constant for the phase shift

  • Calculate the transmission function as a complex exponential

Parameters:
  • pot_slice (Float[Array, '#a #b'])

  • voltage_kV (Num[Array, ''])

Return type:

Complex[Array, '']
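The flow follows the standard phase-grating approximation. With the relativistic wavelength λ, the accelerating voltage V in volts, and the electron rest energy m₀c² ≈ 511 keV, the interaction constant is σ = (2π/(λV))·(m₀c² + eV)/(2m₀c² + eV), and the transmission function is t = exp(iσv), where v is the projected potential. The sketch below implements that textbook formula (as given, for example, by Kirkland) and assumes v is in V·Å; it is not the library's code, it takes the voltage as a Python float, and `interaction_constant` and `transmission_func_sketch` are hypothetical names.

```python
import math

import jax.numpy as jnp

HC_KEV_ANG = 12.3984198  # h*c in keV*angstrom
REST_KEV = 510.99895     # electron rest energy m0*c^2 in keV


def interaction_constant(voltage_kv):
    """sigma in rad / (V * angstrom) for an accelerating voltage in kV (sketch)."""
    # relativistic wavelength: lambda = hc / sqrt(E * (E + 2 m0 c^2))
    wavelength = HC_KEV_ANG / math.sqrt(voltage_kv * (voltage_kv + 2.0 * REST_KEV))
    ratio = (REST_KEV + voltage_kv) / (2.0 * REST_KEV + voltage_kv)
    return 2.0 * math.pi / (wavelength * voltage_kv * 1e3) * ratio


def transmission_func_sketch(pot_slice, voltage_kv):
    """Phase-grating transmission exp(i * sigma * v) for a projected potential in V*angstrom."""
    return jnp.exp(1j * interaction_constant(voltage_kv) * pot_slice)


sigma_200 = interaction_constant(200.0)
trans = transmission_func_sketch(jnp.linspace(0.0, 50.0, 16).reshape(4, 4), 200.0)
```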

ptyrodactyl.electrons.forward.typechecker(obj=None, *, conf=BeartypeConf())

Decorate the passed beartypeable (i.e., pure-Python callable or class) with optimal type-checking dynamically generated unique to that beartypeable under the passed beartype configuration.

This decorator supports two distinct (albeit equally efficient) modes of operation:

  • Decoration mode. The caller activates this mode by passing this decorator a type-checkable object via the obj parameter; this decorator then creates and returns a new callable wrapping that object with optimal type-checking. Specifically:

    • If this object is a callable, this decorator creates and returns a new runtime type-checker (i.e., pure-Python function validating all parameters and returns of all calls to that callable against all PEP-compliant type hints annotating those parameters and returns). The type-checker returned by this decorator is:

      • Optimized uniquely for the passed callable.

      • Guaranteed to run in O(1) constant-time with negligible constant factors.

      • Type-check effectively instantaneously.

      • Add effectively no runtime overhead to the passed callable.

    • If the passed object is a class, this decorator iteratively applies itself to all annotated methods of this class by dynamically wrapping each such method with a runtime type-checker (as described previously).

  • Configuration mode. The caller activates this mode by passing this decorator a beartype configuration via the conf parameter; this decorator then creates and returns a new beartype decorator enabling that configuration. That decorator may then be called (in decoration mode) to create and return a new callable wrapping the passed type-checkable object with optimal type-checking configured by that configuration.

If optimizations are enabled by the active Python interpreter (e.g., due to option -O passed to this interpreter), this decorator silently reduces to a noop.

Parameters

obj : Optional[BeartypeableT]

Beartypeable (i.e., pure-Python callable or class) to be decorated. Defaults to None, in which case this decorator is in configuration rather than decoration mode. In configuration mode, this decorator creates and returns an efficiently cached private decorator that generically applies the passed beartype configuration to any beartypeable object passed to that decorator. Look… It just works.

conf : BeartypeConf, optional

Beartype configuration (i.e., self-caching dataclass encapsulating all settings configuring type-checking for the passed object). Defaults to BeartypeConf(), the default O(1) constant-time configuration.

Returns

BeartypeReturn

Either:

  • If in decoration mode (i.e., obj is not None while conf is None) and:

    • If obj is a callable, a new callable wrapping that callable with dynamically generated type-checking.

    • If obj is a class, this existing class embellished with dynamically generated type-checking.

  • If in configuration mode (i.e., obj is None while conf is not None), a new beartype decorator enabling this configuration.

Raises:
  • BeartypeConfException – If the passed configuration is not actually a configuration (i.e., instance of the BeartypeConf class).

  • BeartypeDecorHintException

    If any annotation on this callable is neither:

    • A PEP-compliant type (i.e., instance or class complying with a PEP supported by beartype), including:

      • PEP 484 types (i.e., instance or class declared by the stdlib typing module).

    • A PEP-noncompliant type (i.e., instance or class complying with beartype-specific semantics rather than a PEP), including:

      • Fully-qualified forward references (i.e., strings specified as fully-qualified classnames).

      • Tuple unions (i.e., tuples containing one or more classes and/or forward references).

  • BeartypePep563Exception – If PEP 563 is active for this callable and evaluating a postponed annotation (i.e., annotation whose value is a string) on this callable raises an exception (e.g., due to that annotation referring to local state no longer accessible from this deferred evaluation).

  • BeartypeDecorParamNameException – If the name of any parameter declared on this callable is prefixed by the reserved substring __beartype_.

  • BeartypeDecorWrappeeException

    If this callable is either:

    • Uncallable.

    • A class, which beartype currently fails to support.

    • A C-based callable (e.g., builtin, third-party C extension).

  • BeartypeDecorWrapperException – If this decorator erroneously generates a syntactically invalid wrapper function. This should never happen, but here we are, so this probably happened. Please submit an upstream issue with our issue tracker if you ever see this. (Thanks and abstruse apologies!)

Parameters:
  • obj (BeartypeableT | None)

  • conf (BeartypeConf)

Return type:

BeartypeableT | Callable[[BeartypeableT], BeartypeableT]

ptyrodactyl.electrons.forward.wavelength_ang(voltage_kV)

Description

Calculates the relativistic electron wavelength in angstroms based on the microscope accelerating voltage.

Because this function is written for JAX, it assumes clean input and does not check for negative or NaN values; those checks belong in your preprocessing steps, not in the function itself.

Parameters

  • voltage_kV (num_type | Float[Array, "#a"]):

    The microscope accelerating voltage in kilovolts

Returns

  • in_angstroms (Float[Array, "*"]):

    The electron wavelength in angstroms

Flow

  • Calculate the electron wavelength in meters

  • Convert the wavelength to angstroms

Parameters:

voltage_kV (Num[Array, '#a'])

Return type:

Float[Array, '#a']
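The relativistic wavelength follows from λ = h / √(2m₀eV(1 + eV/(2m₀c²))). A minimal sketch that combines the physical constants in Python double precision first (products such as 2m₀eV are around 1e-44 in SI units, below float32's normal range), then applies them to a JAX array; `wavelength_ang_sketch` is hypothetical.

```python
import math

import jax.numpy as jnp

# CODATA 2018 constants (SI units)
H = 6.62607015e-34        # Planck constant, J s
M0 = 9.1093837015e-31     # electron rest mass, kg
Q_E = 1.602176634e-19     # elementary charge, C
C = 299792458.0           # speed of light, m/s

# Combine constants in float64 before touching float32 arrays.
PREFACTOR_ANG = H / math.sqrt(2.0 * M0 * Q_E) * 1e10  # ~12.26 angstrom * V^0.5
RELATIVISTIC = Q_E / (2.0 * M0 * C**2)                # ~0.978e-6 per volt


def wavelength_ang_sketch(voltage_kv):
    """Relativistic electron wavelength in angstroms for voltage(s) in kV (sketch)."""
    volts = jnp.asarray(voltage_kv, dtype=jnp.float32) * 1e3
    return PREFACTOR_ANG / jnp.sqrt(volts * (1.0 + RELATIVISTIC * volts))


lam = wavelength_ang_sketch(jnp.array([100.0, 200.0, 300.0]))
```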

class ptyrodactyl.electrons.inverse.Callable

Bases: object

ptyrodactyl.electrons.inverse.Dict

alias of dict

ptyrodactyl.electrons.inverse.get_optimizer(optimizer_name)

Parameters:

optimizer_name (str)

Return type:

Optimizer

ptyrodactyl.electrons.inverse.initialize_mixed_states(base_state, num_states, energy_spread=0.5, random_seed=42)

Initialize mixed states for partial temporal coherence.

Args:

  • base_state: Base state (e.g., probe)

  • num_states: Number of states in the mixture

  • energy_spread: Energy spread in eV (FWHM)

  • random_seed: Random seed for initialization

Returns:

MixedQuantumStates with initialized states and probabilities

Parameters:
  • base_state (Complex[Array, 'H W'])

  • num_states (int)

  • energy_spread (float)

  • random_seed (int)

Return type:

MixedQuantumStates

ptyrodactyl.electrons.inverse.initialize_probe_modes(base_probe, num_modes, random_seed=42)

Initialize multiple probe modes from a base probe.

Args:

  • base_probe: Base probe function

  • num_modes: Number of modes to generate

  • random_seed: Random seed for initialization

Returns:

ProbeState with initialized modes and weights

Parameters:
  • base_probe (Complex[Array, 'H W'])

  • num_modes (int)

  • random_seed (int)

Return type:

ProbeState

ptyrodactyl.electrons.inverse.multi_mode_ptychography(experimental_4dstem, initial_pot_slices, initial_modes, mode_weights, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam', scheduler_fn=None)

Multi-mode ptychographic reconstruction with optional mixed state.

Args:

  • experimental_4dstem: Experimental 4D-STEM data

  • initial_pot_slices: Initial guess for potential slices

  • initial_modes: Initial guess for probe modes

  • mode_weights: Initial weights for each mode

  • pos_list: List of probe positions

  • slice_thickness: Thickness of each slice

  • voltage_kV: Accelerating voltage

  • calib_ang: Calibration in angstroms

  • num_iterations: Number of optimization iterations

  • learning_rate: Initial learning rate

  • loss_type: Type of loss function

  • optimizer_name: Name of optimizer to use

  • scheduler_fn: Optional learning rate scheduler

Returns:

Tuple of optimized potential slices, probe modes, and mode weights

Parameters:
  • experimental_4dstem (Float[Array, 'P H W'])

  • initial_pot_slices (Complex[Array, 'H W S'])

  • initial_modes (Complex[Array, 'H W M'])

  • mode_weights (Float[Array, 'M'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, ''])

  • voltage_kV (Float[Array, ''])

  • calib_ang (Float[Array, ''])

  • num_iterations (int)

  • learning_rate (float)

  • loss_type (str)

  • optimizer_name (str)

  • scheduler_fn (Callable[[LRSchedulerState], tuple[float, LRSchedulerState]] | None)

Return type:

tuple[Complex[Array, 'H W S'], Complex[Array, 'H W M'], Float[Array, 'M']]

ptyrodactyl.electrons.inverse.multi_slice_ptychography(experimental_4dstem, initial_pot_slices, initial_beam, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam', scheduler_fn=None)

Multi-slice ptychographic reconstruction.

Args:

  • experimental_4dstem: Experimental 4D-STEM data

  • initial_pot_slices: Initial guess for potential slices

  • initial_beam: Initial guess for electron beam

  • pos_list: List of probe positions

  • slice_thickness: Thickness of each slice

  • voltage_kV: Accelerating voltage

  • calib_ang: Calibration in angstroms

  • num_iterations: Number of optimization iterations

  • learning_rate: Initial learning rate

  • loss_type: Type of loss function

  • optimizer_name: Name of optimizer to use

  • scheduler_fn: Optional learning rate scheduler

Returns:

Tuple of optimized potential slices and beam

Parameters:
  • experimental_4dstem (Float[Array, 'P H W'])

  • initial_pot_slices (Complex[Array, 'H W S'])

  • initial_beam (Complex[Array, 'H W'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, ''])

  • voltage_kV (Float[Array, ''])

  • calib_ang (Float[Array, ''])

  • num_iterations (int)

  • learning_rate (float)

  • loss_type (str)

  • optimizer_name (str)

  • scheduler_fn (Callable[[LRSchedulerState], tuple[float, LRSchedulerState]] | None)

Return type:

tuple[Complex[Array, 'H W S'], Complex[Array, 'H W']]
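Multi-slice reconstruction optimizes through a forward model that alternates transmission through each slice with Fresnel propagation over slice_thickness. The sketch below shows such a forward pass under stated assumptions, none of which are documented above:

  • each slice is already a complex transmission function, since the library's conversion from potential to transmission is not described here;

  • the wavelength comes from the standard relativistic formula evaluated at voltage_kV;

  • calib_ang is the pixel size in angstroms;

  • the wave is propagated after every slice.

The helper names are hypothetical.

```python
import jax.numpy as jnp


# Hypothetical helper: relativistic electron wavelength.
def electron_wavelength_ang(voltage_kV):
    """lambda = h c / sqrt(E (2 m0 c^2 + E)), with E = e V in keV; result in angstroms."""
    hc_keV_ang = 12.398419843320026   # h * c in keV * angstrom
    m0c2_keV = 510.99895000           # electron rest energy in keV
    E = jnp.asarray(voltage_kV, dtype=jnp.float32)
    return hc_keV_ang / jnp.sqrt(E * (2.0 * m0c2_keV + E))


# Hypothetical helper: transmit-then-propagate multislice forward pass.
def multislice_exit_wave(probe, transmissions, slice_thickness, voltage_kV, calib_ang):
    """probe: (H, W) complex; transmissions: (H, W, S) complex (assumed)."""
    H, W = probe.shape
    lam = electron_wavelength_ang(voltage_kV)
    # Spatial frequencies in 1/angstrom, assuming calib_ang is the pixel size
    ky = jnp.fft.fftfreq(H, d=calib_ang)
    kx = jnp.fft.fftfreq(W, d=calib_ang)
    k2 = ky[:, None] ** 2 + kx[None, :] ** 2
    # Paraxial Fresnel propagator over one slice thickness (unit modulus)
    propagator = jnp.exp(-1j * jnp.pi * lam * slice_thickness * k2)
    psi = probe
    for s in range(transmissions.shape[-1]):
        psi = psi * transmissions[..., s]                    # transmit
        psi = jnp.fft.ifft2(jnp.fft.fft2(psi) * propagator)  # propagate
    return psi


# Toy 16 x 16 grid with three pure-phase slices: total intensity is conserved
yy, xx = jnp.meshgrid(jnp.arange(16.0), jnp.arange(16.0), indexing="ij")
probe = jnp.exp(-((yy - 8.0) ** 2 + (xx - 8.0) ** 2) / 8.0).astype(jnp.complex64)
trans = jnp.exp(1j * jnp.linspace(0.0, 2.0, 16 * 16 * 3).reshape(16, 16, 3))
out = multislice_exit_wave(probe, trans, 2.0, 300.0, 0.2)
lam_300 = electron_wavelength_ang(300.0)  # roughly 0.0197 angstrom
```

A diffraction pattern would then be |fft2(out)|², compared against experimental_4dstem for each position in pos_list.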

ptyrodactyl.electrons.inverse.single_slice_poscorrected(experimental_4dstem, initial_pot_slice, initial_beam, initial_pos_list, slice_thickness, voltage_kV, calib_ang, devices, num_iterations=1000, learning_rate=0.001, pos_learning_rate=0.1, loss_type='mse', optimizer_name='adam')[source]

Create and run an optimization routine for 4D-STEM reconstruction with position correction.

Args:

  • experimental_4dstem (Float[Array, "P H W"]):

    Experimental 4D-STEM data.

  • initial_pot_slice (Complex[Array, "H W"]):

    Initial guess for potential slice.

  • initial_beam (Complex[Array, "H W"]):

    Initial guess for electron beam.

  • initial_pos_list (Float[Array, "P 2"]):

    Initial list of probe positions.

  • slice_thickness (Float[Array, "*"]):

    Thickness of each slice.

  • voltage_kV (Float[Array, "*"]):

    Accelerating voltage.

  • calib_ang (Float[Array, "*"]):

    Calibration in angstroms.

  • devices (jax.Array):

    Array of devices for sharding.

  • num_iterations (int):

    Number of optimization iterations.

  • learning_rate (float):

    Learning rate for potential slice and beam optimization.

  • pos_learning_rate (float):

    Learning rate for position optimization.

  • loss_type (str):

    Type of loss function to use.

  • optimizer_name (str):

    Name of optimizer to use.

Returns:

  • Tuple[Complex[Array, "H W"], Complex[Array, "H W"], Float[Array, "P 2"]]:

    Optimized potential slice, beam, and corrected positions.

Parameters:
  • experimental_4dstem (Float[Array, 'P H W'])

  • initial_pot_slice (Complex[Array, 'H W'])

  • initial_beam (Complex[Array, 'H W'])

  • initial_pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, '*'])

  • voltage_kV (Float[Array, '*'])

  • calib_ang (Float[Array, '*'])

  • devices (Array)

  • num_iterations (int)

  • learning_rate (float)

  • pos_learning_rate (float)

  • loss_type (str)

  • optimizer_name (str)

Return type:

tuple[Complex[Array, 'H W'], Complex[Array, 'H W'], Float[Array, 'P 2']]
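The documentation above does not say how single_slice_poscorrected places the probe at each corrected position. One common differentiable approach, shown here only as an assumption, is a Fourier-domain sub-pixel shift. A phase ramp applied in reciprocal space shifts the probe by a fractional number of pixels, so gradients can flow back to the positions. The sketch works in pixel units; if positions are stored in angstroms, they would first be divided by calib_ang. The names fourier_shift and shift_mismatch are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: sub-pixel shift via the Fourier shift theorem.
def fourier_shift(image, shift):
    """Shift a 2D array by (dy, dx) pixels; periodic and differentiable in shift."""
    H, W = image.shape
    ky = jnp.fft.fftfreq(H)[:, None]  # cycles per pixel along y
    kx = jnp.fft.fftfreq(W)[None, :]  # cycles per pixel along x
    ramp = jnp.exp(-2j * jnp.pi * (ky * shift[0] + kx * shift[1]))
    return jnp.fft.ifft2(jnp.fft.fft2(image) * ramp)


# Hypothetical helper: real-valued mismatch, so jax.grad applies.
def shift_mismatch(shift, image, target):
    return jnp.sum(jnp.abs(fourier_shift(image, shift) - target) ** 2)


img = jnp.arange(64.0).reshape(8, 8).astype(jnp.complex64)
# An integer shift reproduces a circular roll
rolled = fourier_shift(img, jnp.array([2.0, 3.0]))
expected = jnp.roll(img, shift=(2, 3), axis=(0, 1))
# The gradient with respect to a fractional shift is well defined
grad_shift = jax.grad(shift_mismatch)(jnp.array([0.5, 0.0]), img, expected)
```

The separate pos_learning_rate suggests the positions are updated at a different step size from the object and beam. In this sketch, that would correspond to scaling grad_shift independently.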

ptyrodactyl.electrons.inverse.single_slice_ptychography(experimental_4dstem, initial_pot_slice, initial_beam, pos_list, slice_thickness, voltage_kV, calib_ang, num_iterations=1000, learning_rate=0.001, loss_type='mse', optimizer_name='adam')[source]

Create and run an optimization routine for 4D-STEM reconstruction.

Args:

  • experimental_4dstem (Float[Array, "P H W"]):

    Experimental 4D-STEM data.

  • initial_pot_slice (Complex[Array, "H W"]):

    Initial guess for potential slice.

  • initial_beam (Complex[Array, "H W"]):

    Initial guess for electron beam.

  • pos_list (Float[Array, "P 2"]):

    List of probe positions.

  • slice_thickness (Float[Array, "*"]):

    Thickness of each slice.

  • voltage_kV (Float[Array, "*"]):

    Accelerating voltage.

  • calib_ang (Float[Array, "*"]):

    Calibration in angstroms.

  • num_iterations (int):

    Number of optimization iterations.

  • learning_rate (float):

    Learning rate for optimization.

  • loss_type (str):

    Type of loss function to use.

  • optimizer_name (str):

    Name of optimizer to use.

Returns:

  • Tuple[Complex[Array, "H W"], Complex[Array, "H W"]]:

    Optimized potential slice and beam.

Parameters:
  • experimental_4dstem (Float[Array, 'P H W'])

  • initial_pot_slice (Complex[Array, 'H W'])

  • initial_beam (Complex[Array, 'H W'])

  • pos_list (Float[Array, 'P 2'])

  • slice_thickness (Float[Array, '*'])

  • voltage_kV (Float[Array, '*'])

  • calib_ang (Float[Array, '*'])

  • num_iterations (int)

  • learning_rate (float)

  • loss_type (str)

  • optimizer_name (str)

Return type:

tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]

Light Microscopy Module

class ptyrodactyl.optics.epie.Complex(*args, **kwargs)

Bases: AbstractDtype

dtypes: Literal[_any_dtype] | list[str | Pattern] = ('complex64', 'complex128')
ptyrodactyl.optics.epie.Tuple

alias of tuple

ptyrodactyl.optics.epie.jaxtyped(fn=sentinel, *, typechecker=sentinel)[source]

Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.

Example:

# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                      ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]

# Type-check a dataclass
from dataclasses import dataclass

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]

Arguments:

  • fn: The function or dataclass to decorate. In practice, if you want to use dataclasses with JAX, then equinox.Module (https://docs.kidger.site/equinox/api/module/module/) is our recommended approach:

    import equinox as eqx

    @jaxtyped(typechecker=typechecker)
    class MyModule(eqx.Module):
        ...

  • typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.

    @typechecker
    def f(x: int):
        pass

    f("a string is not an integer")  # this line should raise an exception

    Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. Can also be set as typechecker=None to skip automatic runtime type-checking, but still support manual isinstance checks inside the function body:

    @jaxtyped(typechecker=None)
    def f(x):
        assert isinstance(x, Float[Array, "batch channel"])

Returns:

If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.

If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.

Old syntax:

jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:

@jaxtyped
@typechecker
def f(x: int):
    ...

This is still supported, but will now raise a warning recommending the jaxtyped(typechecker=typechecker) syntax discussed above. The new syntax produces easier-to-debug error messages: under the hood, it more carefully manipulates the typechecker so as to determine where a type-check error arises.

Notes for advanced users:

Dynamic contexts:

Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns, this inner context is popped from the stack, and the previous context is returned to.

isinstance:

Binding of a value against a name is done with an isinstance check; for example, isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the run-time typechecker supplied as the typechecker argument.

This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether, i.e. typechecker=None, and only do your own manual isinstance checks.)

Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).

Decoupling contexts from function calls:

If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.

Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":

with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])

This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it’s mostly only useful when working at the global scope.

ptyrodactyl.optics.epie.register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface.

Args:

cls: a type to register as a pytree

Returns:

The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_node_class to be used as a decorator.

Examples:

Here we’ll define a custom container that will be compatible with jax.jit() and other JAX transformations:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
Parameters:

cls (Typ)

Return type:

Typ

class ptyrodactyl.optics.lenses.Complex(*args, **kwargs)

Bases: AbstractDtype

dtypes: Literal[_any_dtype] | list[str | Pattern] = ('complex64', 'complex128')
class ptyrodactyl.optics.lenses.LensParams(focal_length, diameter, n, center_thickness, R1, R2)[source]

Bases: NamedTuple

Description

PyTree structure for lens parameters

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • R1 (Float[Array, ''])

  • R2 (Float[Array, ''])

- `focal_length` (Float[Array, ""])

Focal length of the lens in meters

- `diameter` (Float[Array, ""])

Diameter of the lens in meters

- `n` (Float[Array, ""])

Refractive index of the lens material

- `center_thickness` (Float[Array, ""])

Thickness at the center of the lens in meters

- `R1` (Float[Array, ""])

Radius of curvature of the first surface in meters (positive for convex)

- `R2` (Float[Array, ""])

Radius of curvature of the second surface in meters (positive for convex)

Notes

This class is registered as a PyTree node, making it compatible with JAX transformations like jit, grad, and vmap. The auxiliary data in tree_flatten is None as all relevant data is stored in JAX arrays.

R1: Float[Array, '']

Alias for field number 4

R2: Float[Array, '']

Alias for field number 5

center_thickness: Float[Array, '']

Alias for field number 3

diameter: Float[Array, '']

Alias for field number 1

focal_length: Float[Array, '']

Alias for field number 0

n: Float[Array, '']

Alias for field number 2

tree_flatten()[source]
classmethod tree_unflatten(aux_data, children)[source]
ptyrodactyl.optics.lenses.Tuple

alias of tuple

ptyrodactyl.optics.lenses.create_lens_phase(X, Y, params, wavelength)[source]

Description

Create the phase profile and transmission mask for a lens.

Parameters

  • X (Float[Array, "H W"]):

    X coordinates grid

  • Y (Float[Array, "H W"]):

    Y coordinates grid

  • params (LensParams):

    Lens parameters

  • wavelength (Float[Array, ""]):

    Wavelength of light

Returns

  • phase_profile (Float[Array, "H W"]):

    Phase profile of the lens

  • transmission (Float[Array, "H W"]):

    Transmission mask of the lens

Flow

  • Calculate radial coordinates

  • Calculate thickness profile

  • Calculate phase profile

  • Create transmission mask

  • Return phase and transmission

Parameters:
  • X (Float[Array, 'H W'])

  • Y (Float[Array, 'H W'])

  • params (LensParams)

  • wavelength (Float[Array, ''])

Return type:

tuple[Float[Array, 'H W'], Float[Array, 'H W']]
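The "Calculate phase profile" and "Create transmission mask" steps above can be sketched as follows. This assumes the usual thin-element relation phase = (2π/λ)(n − 1)·t(r), where t is the local glass thickness, the surrounding medium is n = 1, and the aperture is a disk of the lens diameter. To keep the example short, the sketch takes a precomputed thickness array instead of a LensParams, so lens_phase_and_mask is a hypothetical helper, not create_lens_phase itself.

```python
import jax.numpy as jnp


# Hypothetical helper; assumes phase = k * (n - 1) * thickness in a medium with n = 1.
def lens_phase_and_mask(X, Y, thickness, n, diameter, wavelength):
    r = jnp.sqrt(X**2 + Y**2)                   # radial coordinate
    k = 2.0 * jnp.pi / wavelength               # vacuum wavenumber
    phase_profile = k * (n - 1.0) * thickness   # extra optical path, in radians
    transmission = (r <= diameter / 2.0).astype(jnp.float32)  # circular aperture
    return phase_profile, transmission


# 5 x 5 grid spanning +/- 2 mm, uniform 1 um of glass, 2 mm diameter aperture
coords = jnp.linspace(-2e-3, 2e-3, 5)
X, Y = jnp.meshgrid(coords, coords, indexing="xy")
phase, mask = lens_phase_and_mask(X, Y, jnp.full((5, 5), 1e-6), 1.5, 2e-3, 500e-9)
```

Here 1 µm of n = 1.5 glass at 500 nm gives (2π / 5e-7)(0.5)(1e-6) = 2π rad of phase delay.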

ptyrodactyl.optics.lenses.double_concave_lens(focal_length, diameter, n, center_thickness, R_ratio=Array(1., dtype=float64, weak_type=True))[source]

Description

Create parameters for a double concave lens.

Parameters

  • focal_length (Float[Array, ""]):

    Desired focal length

  • diameter (Float[Array, ""]):

    Lens diameter

  • n (Float[Array, ""]):

    Refractive index

  • center_thickness (Float[Array, ""]):

    Center thickness

  • R_ratio (Optional[Float[Array, ""]]):

    Ratio R2/R1. Default is 1.0 (symmetric lens)

Returns

  • params (LensParams):

    Lens parameters

Flow

  • Calculate R1 using lensmaker’s equation

  • Calculate R2 using R_ratio

  • Create and return LensParams

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • R_ratio (Float[Array, ''] | None)

Return type:

LensParams

ptyrodactyl.optics.lenses.double_convex_lens(focal_length, diameter, n, center_thickness, R_ratio=Array(1., dtype=float64, weak_type=True))[source]

Description

Create parameters for a double convex lens.

Parameters

  • focal_length (Float[Array, ""]):

    Desired focal length

  • diameter (Float[Array, ""]):

    Lens diameter

  • n (Float[Array, ""]):

    Refractive index

  • center_thickness (Float[Array, ""]):

    Center thickness

  • R_ratio (Optional[Float[Array, ""]]):

    Ratio R2/R1. Default is 1.0 (symmetric lens)

Returns

  • params (LensParams):

    Lens parameters

Flow

  • Calculate R1 using lensmaker’s equation

  • Calculate R2 using R_ratio

  • Create and return LensParams

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • R_ratio (Float[Array, ''] | None)

Return type:

LensParams
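The sketch below follows the "Calculate R1 using lensmaker's equation, then R2 using R_ratio" flow. It uses the thin-lens form with the convex-positive convention of LensParams, 1/f = (n − 1)(1/R1 + 1/R2). Setting R2 = R_ratio · R1 gives R1 = f (n − 1)(1 + 1/R_ratio).

Two assumptions go beyond the documentation:

  • center_thickness is ignored, which a thick-lens treatment would not do;

  • a double concave lens is treated as the same solve with a negative focal length, which makes both radii negative.

biconvex_radii is a hypothetical name.

```python
import jax.numpy as jnp


# Hypothetical helper; thin-lens, convex-positive convention (assumption).
def biconvex_radii(focal_length, n, R_ratio=1.0):
    """Radii (R1, R2) for a thin double convex lens with R2 = R_ratio * R1."""
    f = jnp.asarray(focal_length, dtype=jnp.float32)
    n = jnp.asarray(n, dtype=jnp.float32)
    ratio = jnp.asarray(R_ratio, dtype=jnp.float32)
    # Solve 1/f = (n - 1) * (1/R1 + 1/(ratio * R1)) for R1
    R1 = f * (n - 1.0) * (1.0 + 1.0 / ratio)
    R2 = ratio * R1
    return R1, R2


R1, R2 = biconvex_radii(0.1, 1.5)         # symmetric: R1 = R2 = 2 f (n - 1) = 0.1
R1c, R2c = biconvex_radii(-0.1, 1.5)      # negative f: both radii negative (concave)
R1a, R2a = biconvex_radii(0.1, 1.5, 2.0)  # asymmetric: R1 = 0.075, R2 = 0.15
```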

ptyrodactyl.optics.lenses.lens_focal_length(n, R1, R2)[source]

Description

Calculate the focal length of a lens using the lensmaker’s equation.

Parameters

  • n (Float[Array, ""]):

    Refractive index of the lens material

  • R1 (Float[Array, ""]):

    Radius of curvature of the first surface (positive for convex)

  • R2 (Float[Array, ""]):

    Radius of curvature of the second surface (positive for convex)

Returns

  • f (Float[Array, ""]):

    Focal length of the lens

Flow

  • Apply the lensmaker’s equation

  • Return the calculated focal length

Parameters:
  • n (Float[Array, ''])

  • R1 (Float[Array, ''])

  • R2 (Float[Array, ''])

Return type:

Float[Array, '']
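lens_focal_length is described as applying the lensmaker's equation. Below is a minimal sketch of the thin-lens form. It assumes the convention stated above that a convex surface has a positive radius, so a biconvex lens has R1 > 0 and R2 > 0 and 1/f = (n − 1)(1/R1 + 1/R2). A flat surface is R = ∞. The thick-lens correction involving center_thickness is omitted, and thin_lens_focal_length is a hypothetical name rather than the library function.

```python
import jax.numpy as jnp


# Hypothetical helper; thin lens, "positive for convex" on both surfaces (assumption).
def thin_lens_focal_length(n, R1, R2):
    n, R1, R2 = (jnp.asarray(v, dtype=jnp.float32) for v in (n, R1, R2))
    inv_f = (n - 1.0) * (1.0 / R1 + 1.0 / R2)  # optical power; 1/inf = 0 for flat sides
    return 1.0 / inv_f


# Symmetric biconvex lens, n = 1.5, |R| = 10 cm: f = 0.1
f_biconvex = thin_lens_focal_length(1.5, 0.1, 0.1)
# Plano-convex lens (flat second side): f = 0.2
f_plano = thin_lens_focal_length(1.5, 0.1, jnp.inf)
```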

ptyrodactyl.optics.lenses.lens_thickness_profile(r, R1, R2, center_thickness, diameter)[source]

Description

Calculate the thickness profile of a lens.

Parameters

  • r (Float[Array, "H W"]):

    Radial distance from the optical axis

  • R1 (Float[Array, ""]):

    Radius of curvature of the first surface

  • R2 (Float[Array, ""]):

    Radius of curvature of the second surface

  • center_thickness (Float[Array, ""]):

    Thickness at the center of the lens

  • diameter (Float[Array, ""]):

    Diameter of the lens

Returns

  • thickness (Float[Array, "H W"]):

    Thickness profile of the lens

Flow

  • Calculate surface sag for both surfaces

  • Combine sags with center thickness

  • Apply aperture mask

  • Return thickness profile

Parameters:
  • r (Float[Array, 'H W'])

  • R1 (Float[Array, ''])

  • R2 (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • diameter (Float[Array, ''])

Return type:

Float[Array, 'H W']
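The Flow above (sag of each surface, combine with center thickness, apply aperture) can be sketched with spherical surfaces. The sag is s(r) = sign(R)(|R| − √(R² − r²)). Under the convex-positive convention assumed here, convex sags are subtracted from center_thickness and concave sags add material. A flat surface (R = ∞) contributes zero. The helper names are hypothetical, and the library may handle flat surfaces or edge cases differently.

```python
import jax.numpy as jnp


# Hypothetical helper: spherical-surface sag (convex positive, concave negative).
def spherical_sag(r, R):
    flat = jnp.isinf(R)
    R_safe = jnp.where(flat, 1.0, R)              # avoid inf - inf for flat surfaces
    under_root = jnp.clip(R_safe**2 - r**2, 0.0)  # clip where r exceeds |R|
    sag = jnp.sign(R_safe) * (jnp.abs(R_safe) - jnp.sqrt(under_root))
    return jnp.where(flat, 0.0, sag)


# Hypothetical helper mirroring the documented Flow.
def thickness_profile(r, R1, R2, center_thickness, diameter):
    t = center_thickness - spherical_sag(r, R1) - spherical_sag(r, R2)
    t = jnp.clip(t, 0.0)                            # no negative thickness
    return jnp.where(r <= diameter / 2.0, t, 0.0)   # aperture mask


r = jnp.array([[0.0, 0.01, 0.03]])
t = thickness_profile(r, jnp.array(0.1), jnp.array(0.1), 5e-3, 0.04)
```

At r = 0.01 each surface has sagged by 0.1 − √0.0099 ≈ 5.01e-4, so the thickness is about 3.9975e-3. The point r = 0.03 lies outside the 0.02 radius and is masked to zero.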

ptyrodactyl.optics.lenses.meniscus_lens(focal_length, diameter, n, center_thickness, R_ratio, convex_first=Array(True, dtype=bool))[source]

Description

Create parameters for a meniscus (concavo-convex) lens. For a meniscus lens, one surface is convex (positive R) and one is concave (negative R).

Parameters

  • focal_length (Float[Array, ""]):

    Desired focal length in meters

  • diameter (Float[Array, ""]):

    Lens diameter in meters

  • n (Float[Array, ""]):

    Refractive index of lens material

  • center_thickness (Float[Array, ""]):

    Center thickness in meters

  • R_ratio (Float[Array, ""]):

    Absolute ratio of R2/R1

  • convex_first (Bool[Array, ""]):

    If True, first surface is convex (default: True)

Returns

  • params (LensParams):

    Lens parameters

Flow

  • Calculate magnitude of R1 using lensmaker’s equation

  • Calculate R2 magnitude using R_ratio

  • Assign correct signs based on convex_first

  • Create and return LensParams

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • R_ratio (Float[Array, ''])

  • convex_first (Bool[Array, ''] | None)

Return type:

LensParams
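The meniscus Flow (magnitude of R1 from the lensmaker's equation, R2 from R_ratio, then signs from convex_first) can be sketched under the same thin-lens, convex-positive assumption. If the convex surface has radius Rc > 0 and the concave one −R_ratio·Rc, then 1/f = (n − 1)(1 − 1/R_ratio)/Rc, which requires R_ratio ≠ 1.

This sketch makes several choices of its own:

  • it interprets R_ratio as |R_concave| / |R_convex|;

  • when convex_first is False, it simply swaps the two surfaces, so flipping the lens keeps the same focal length;

  • it uses a Python bool, whereas the documented parameter is a JAX boolean array (inside jit one would use jnp.where).

How the library interprets R_ratio when convex_first is False is not documented. meniscus_radii is hypothetical.

```python
import jax.numpy as jnp


# Hypothetical helper; R_ratio taken as |R_concave| / |R_convex| (assumption).
def meniscus_radii(focal_length, n, R_ratio, convex_first=True):
    f = jnp.asarray(focal_length, dtype=jnp.float32)
    n = jnp.asarray(n, dtype=jnp.float32)
    ratio = jnp.asarray(R_ratio, dtype=jnp.float32)
    # Solve 1/f = (n - 1) * (1/Rc - 1/(ratio * Rc)) for the convex radius Rc
    R_convex = f * (n - 1.0) * (1.0 - 1.0 / ratio)
    R_concave = -ratio * R_convex   # concave surface: negative radius
    if convex_first:
        return R_convex, R_concave
    return R_concave, R_convex      # flipped lens, same optical power


R1, R2 = meniscus_radii(0.1, 1.5, 2.0)                        # (0.025, -0.05)
R1f, R2f = meniscus_radii(0.1, 1.5, 2.0, convex_first=False)  # (-0.05, 0.025)
```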

ptyrodactyl.optics.lenses.plano_concave_lens(focal_length, diameter, n, center_thickness, concave_first=Array(True, dtype=bool))[source]

Description

Create parameters for a plano-concave lens.

Parameters

  • focal_length (Float[Array, ""]):

    Desired focal length

  • diameter (Float[Array, ""]):

    Lens diameter

  • n (Float[Array, ""]):

    Refractive index

  • center_thickness (Float[Array, ""]):

    Center thickness

  • concave_first (Optional[Bool[Array, ""]]):

    If True, first surface is concave (default: True)

Returns

  • params (LensParams):

    Lens parameters

Flow

  • Calculate R for curved surface

  • Set other R to infinity (flat surface)

  • Create and return LensParams

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • concave_first (Bool[Array, ''] | None)

Return type:

LensParams

ptyrodactyl.optics.lenses.plano_convex_lens(focal_length, diameter, n, center_thickness, convex_first=Array(True, dtype=bool))[source]

Description

Create parameters for a plano-convex lens.

Parameters

  • focal_length (Float[Array, ""]):

    Desired focal length

  • diameter (Float[Array, ""]):

    Lens diameter

  • n (Float[Array, ""]):

    Refractive index

  • center_thickness (Float[Array, ""]):

    Center thickness

  • convex_first (Optional[Bool[Array, ""]]):

    If True, first surface is convex (default: True)

Returns

  • params (LensParams):

    Lens parameters

Flow

  • Calculate R for curved surface

  • Set other R to infinity (flat surface)

  • Create and return LensParams

Parameters:
  • focal_length (Float[Array, ''])

  • diameter (Float[Array, ''])

  • n (Float[Array, ''])

  • center_thickness (Float[Array, ''])

  • convex_first (Bool[Array, ''] | None)

Return type:

LensParams
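For the plano lenses, the Flow is "calculate R for the curved surface, set the other R to infinity". With one flat surface, the thin-lens equation reduces to 1/f = (n − 1)/R, so |R| = f(n − 1).

The sketch assumes focal_length is a positive magnitude and that the concave case yields a diverging lens (negative radius, focal length −f). The documentation does not say whether plano_concave_lens instead expects a negative focal_length. plano_radii and its curved_first and concave flags are hypothetical.

```python
import jax.numpy as jnp


# Hypothetical helper; thin-lens, convex-positive convention (assumption).
def plano_radii(focal_length, n, curved_first=True, concave=False):
    f = jnp.asarray(focal_length, dtype=jnp.float32)
    n = jnp.asarray(n, dtype=jnp.float32)
    R = f * (n - 1.0)            # from 1/f = (n - 1) / R
    if concave:
        R = -R                   # concave surface: negative radius, diverging lens
    flat = jnp.asarray(jnp.inf, dtype=jnp.float32)  # flat surface
    return (R, flat) if curved_first else (flat, R)


R1, R2 = plano_radii(0.2, 1.5)                  # plano-convex: (0.1, inf)
R1v, R2v = plano_radii(0.2, 1.5, concave=True)  # plano-concave: (-0.1, inf)
```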

ptyrodactyl.optics.lenses.propagate_through_lens(field, phase_profile, transmission)[source]

Description

Propagate a field through a lens.

Parameters

  • field (Complex[Array, "H W"]):

    Input complex field

  • phase_profile (Float[Array, "H W"]):

    Phase profile of the lens

  • transmission (Float[Array, "H W"]):

    Transmission mask of the lens

Returns

  • output_field (Complex[Array, "H W"]):

    Field after passing through the lens

Flow

  • Apply transmission mask

  • Add phase profile

  • Return modified field

Parameters:
  • field (Complex[Array, 'H W'])

  • phase_profile (Float[Array, 'H W'])

  • transmission (Float[Array, 'H W'])

Return type:

Complex[Array, 'H W']
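A common thin-element reading of the Flow above ("apply transmission mask, add phase profile") is a pointwise multiplication by transmission · exp(i · phase_profile). No free-space propagation is involved. The sketch assumes the exp(+iφ) sign convention, and the library may use the opposite sign. apply_lens is a hypothetical name.

```python
import jax.numpy as jnp


# Hypothetical helper; exp(+i * phase) sign convention is an assumption.
def apply_lens(field, phase_profile, transmission):
    return field * transmission * jnp.exp(1j * phase_profile)


field = jnp.ones((4, 4), dtype=jnp.complex64)
phase = jnp.full((4, 4), jnp.pi / 2)
mask = jnp.zeros((4, 4)).at[1:3, 1:3].set(1.0)  # 2 x 2 open aperture
out = apply_lens(field, phase, mask)            # i inside the aperture, 0 outside
```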

Tools Module

class ptyrodactyl.tools.loss_functions.Any(*args, **kwargs)[source]

Bases: object

Special type indicating an unconstrained type.

  • Any is compatible with every type.

  • Any assumed to have all methods.

  • All values assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.

class ptyrodactyl.tools.loss_functions.Complex(*args, **kwargs)

Bases: AbstractDtype

dtypes: Literal[_any_dtype] | list[str | Pattern] = ('complex64', 'complex128')
class ptyrodactyl.tools.loss_functions.PyTree(*args, **kwargs)

Bases: object

Represents a PyTree.

Annotations of the following sorts are supported:

a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]

These correspond to:

  1. A plain PyTree can be used as an annotation, in which case PyTree is simply a suggestively-named alternative to Any. (By definition, all types are PyTrees; see https://jax.readthedocs.io/en/latest/pytrees.html.)

  2. PyTree[LeafType] denotes a PyTree all of whose leaves match LeafType. For example, PyTree[int] or PyTree[Union[str, Float32[Array, "b c"]]].

  3. A structure name can also be passed. In this case jax.tree_util.tree_structure(...) will be called, and bound to the structure name. This can be used to mark that multiple PyTrees all have the same structure:

    def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
        ...

    Structures are bound to names in the same way as array shape annotations, i.e. within the thread-local dynamic context of a jaxtyping.jaxtyped decorator.

  4. A composite structure can be declared. In this case the variable must have a PyTree structure equal to the composition of multiple previously-bound PyTree structures. For example:

    def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
        ...

    x = (1, 2)
    y = {"key": 3}
    z = {"key": (4, 5)}  # structure is the composition of the structures of y and x
    f(x, y, z)

    When performing runtime type-checking, all the individual pieces must have already been bound to structures, otherwise the composite structure check will throw an error.

  5. A structure can begin with ... (as in "... T") to denote that the lower levels of the PyTree must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.

  6. A structure can end with ... (as in "T ...") to denote that the PyTree must be a prefix of the declared structure, but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.

ptyrodactyl.tools.loss_functions.create_loss_function(forward_function, experimental_data, loss_type='mae')[source]

Create a JIT-compatible loss function for comparing model output with experimental data.

This function returns a new function that computes the loss between the output of a forward model and experimental data. The returned function is JIT-compatible and can be used with various optimization algorithms.

Args:

  • forward_function (Callable[..., Array]):

    The forward model function (e.g., stem_4d).

  • experimental_data (Array):

    The experimental data to compare against.

  • loss_type (Literal["mae", "mse", "rmse"]):

    The type of loss to use. Options are "mae" (Mean Absolute Error), "mse" (Mean Squared Error), or "rmse" (Root Mean Squared Error). Default is "mae".

Loss Functions:

  • mae_loss:

    Mean Absolute Error loss function.

  • mse_loss:

    Mean Squared Error loss function.

  • rmse_loss:

    Root Mean Squared Error loss function.

Returns:

  • loss_fn (Callable[[PyTree, ...], Float[Array, ""]]):

    A JIT-compatible function that computes the loss given the model parameters and any additional arguments required by the forward function.

Parameters:
Return type:

Callable[[…], Float[Array, ’’]]
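The pattern described for create_loss_function (close over the experimental data, pick a metric, return a JIT-compatible function of the parameters) can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not ptyrodactyl's implementation: `make_loss_sketch` is a hypothetical name, the forward function is assumed to take the parameters as its first positional argument, and `jnp.abs` is used inside the metrics so they remain real-valued if the model output is complex.

```python
import jax
import jax.numpy as jnp


def make_loss_sketch(forward_function, experimental_data, loss_type="mae"):
    """Hypothetical re-creation of the create_loss_function pattern.

    Assumes forward_function(params, *args) returns an array with the same
    shape as experimental_data.
    """

    def mae(pred):
        return jnp.mean(jnp.abs(pred - experimental_data))

    def mse(pred):
        # jnp.abs keeps the metric real-valued even for complex predictions.
        return jnp.mean(jnp.abs(pred - experimental_data) ** 2)

    def rmse(pred):
        return jnp.sqrt(mse(pred))

    # An unknown loss_type raises KeyError here, before any tracing happens.
    metric = {"mae": mae, "mse": mse, "rmse": rmse}[loss_type]

    @jax.jit
    def loss_fn(params, *args):
        return metric(forward_function(params, *args))

    return loss_fn


# Usage with a toy forward model: prediction = params * inputs.
inputs = jnp.array([1.0, 2.0, 3.0])
data = jnp.array([2.0, 4.0, 6.0])
mse_loss = make_loss_sketch(lambda p, x: p * x, data, loss_type="mse")
print(mse_loss(2.0, inputs))            # exact fit, so the loss is zero
print(jax.grad(mse_loss)(1.0, inputs))  # gradient with respect to params
```

Because the returned function is jitted and differentiable in its first argument, it can be handed directly to jax.grad or to a gradient-based optimizer loop.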

class ptyrodactyl.tools.optimizers.Any(*args, **kwargs)[source]

Bases: object

Special type indicating an unconstrained type.

  • Any is compatible with every type.

  • Any is assumed to have all methods.

  • All values are assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.

class ptyrodactyl.tools.optimizers.Callable

Bases: object

class ptyrodactyl.tools.optimizers.LRSchedulerState(step, learning_rate, initial_lr)[source]

Bases: NamedTuple

State maintained by learning rate schedulers.

Parameters:
initial_lr: float

Alias for field number 2

learning_rate: float

Alias for field number 1

step: int

Alias for field number 0

class ptyrodactyl.tools.optimizers.Optimizer(init, update)[source]

Bases: NamedTuple

Parameters:
init: Callable

Alias for field number 0

update: Callable

Alias for field number 1

class ptyrodactyl.tools.optimizers.OptimizerState(m, v, step)[source]

Bases: NamedTuple

Parameters:
m: Array

Alias for field number 0

step: Array

Alias for field number 2

v: Array

Alias for field number 1

class ptyrodactyl.tools.optimizers.Sequence

Bases: Reversible, Collection

All the operations on a read-only sequence.

Concrete subclasses must override __new__ or __init__, __getitem__, and __len__.

count(value) -> integer -- return number of occurrences of value
index(value[, start[, stop]]) -> integer -- return first index of value.

Raises ValueError if the value is not present.

Supporting start and stop arguments is optional, but recommended.

ptyrodactyl.tools.optimizers.Tuple

alias of tuple

ptyrodactyl.tools.optimizers.adagrad_update(params, grads, state, learning_rate=0.01, eps=1e-08)[source]
Parameters:
Return type:

tuple[Complex[Array, '...'], OptimizerState]

ptyrodactyl.tools.optimizers.adam_update(params, grads, state, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08)[source]
Parameters:
Return type:

tuple[Complex[Array, '...'], OptimizerState]

ptyrodactyl.tools.optimizers.complex_adagrad(params, grads, state, learning_rate=0.01, eps=1e-08)[source]

Complex-valued Adagrad optimizer based on Wirtinger derivatives.

This function performs one step of the Adagrad optimization algorithm for complex-valued parameters.

Args:

  • params (Complex[Array, "..."]):

    Current complex-valued parameters.

  • grads (Complex[Array, "..."]):

    Complex-valued gradients.

  • state (Complex[Array, "..."]):

    Optimizer state (accumulated squared gradients).

  • learning_rate (float):

    Learning rate (default: 0.01).

  • eps (float):

    Small value to avoid division by zero (default: 1e-8).

Returns:

  • new_params (Complex[Array, "..."]):

    Updated parameters.

  • new_state (Complex[Array, "..."]):

    Updated optimizer state.

Parameters:
Return type:

tuple[Complex[Array, '...'], Complex[Array, '...']]
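A minimal sketch of one complex Adagrad step, assuming a common formulation rather than the library's exact code: the accumulator holds the real squared magnitudes |g|^2, and each complex gradient element is divided by the square root of its accumulator. `complex_adagrad_sketch` is a hypothetical name. The gradient passed in is assumed to be the ascent direction dL/dx + i dL/dy (twice the conjugate Wirtinger derivative dL/dz*); because `jax.grad` of a real loss with respect to a complex input returns dL/dx - i dL/dy, the usage loop conjugates it first.

```python
import jax
import jax.numpy as jnp


def complex_adagrad_sketch(params, grads, state, learning_rate=0.01, eps=1e-8):
    """Hypothetical Adagrad step for complex arrays (not ptyrodactyl's code).

    Assumes `grads` is the ascent direction dL/dx + i dL/dy and that
    `state` is a real array accumulating |grads|**2.
    """
    new_state = state + jnp.abs(grads) ** 2
    new_params = params - learning_rate * grads / (jnp.sqrt(new_state) + eps)
    return new_params, new_state


# Usage: minimise L(z) = sum |z - target|^2 over complex z.
target = jnp.array([1.0 + 2.0j, -0.5j])


def loss(z):
    d = z - target
    return jnp.sum(d.real ** 2 + d.imag ** 2)


z = jnp.zeros(2, dtype=jnp.complex64)
state = jnp.zeros(2, dtype=jnp.float32)
for _ in range(200):
    # jax.grad gives dL/dx - i dL/dy here; conjugating yields the ascent direction.
    g = jnp.conj(jax.grad(loss)(z))
    z, state = complex_adagrad_sketch(z, g, state, learning_rate=0.5)
```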

ptyrodactyl.tools.optimizers.complex_adam(params, grads, state, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Complex-valued Adam optimizer based on Wirtinger derivatives.

This function performs one step of the Adam optimization algorithm for complex-valued parameters.

Args:

  • params (Complex[Array, "..."]):

    Current complex-valued parameters.

  • grads (Complex[Array, "..."]):

    Complex-valued gradients.

  • state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):

    Optimizer state (first moment, second moment, timestep).

  • learning_rate (float):

    Learning rate (default: 0.001).

  • beta1 (float):

    Exponential decay rate for first moment estimates (default: 0.9).

  • beta2 (float):

    Exponential decay rate for second moment estimates (default: 0.999).

  • eps (float):

    Small value to avoid division by zero (default: 1e-8).

Returns:

  • new_params (Complex[Array, "..."]):

    Updated parameters.

  • new_state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):

    Updated optimizer state.

Parameters:
Return type:

tuple[Complex[Array, '...'], tuple[Complex[Array, '...'], Complex[Array, '...'], int]]
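A hedged sketch of one complex Adam step, assuming the usual Adam recipe carried over to complex arrays: the first moment keeps the complex gradient, the second moment tracks |g|^2, and both get the standard bias correction. `complex_adam_sketch` is a hypothetical name, and one deliberate difference from the annotation above is that the second moment is stored as a real array here, whereas the documented state type lists it as complex. The gradient is assumed to be the ascent direction dL/dx + i dL/dy, obtained in the usage loop by conjugating the output of `jax.grad` (which returns dL/dx - i dL/dy for a real loss of a complex input).

```python
import jax
import jax.numpy as jnp


def complex_adam_sketch(params, grads, state, learning_rate=1e-3,
                        beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical Adam step for complex arrays (not ptyrodactyl's code)."""
    m, v, t = state
    t = t + 1
    # First moment keeps the complex direction of the gradient.
    m = beta1 * m + (1 - beta1) * grads
    # Second moment tracks the real squared magnitude |g|^2.
    v = beta2 * v + (1 - beta2) * jnp.abs(grads) ** 2
    # Bias correction for the zero-initialised moments.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    new_params = params - learning_rate * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, (m, v, t)


# Usage: minimise L(z) = sum |z - target|^2 over complex z.
target = jnp.array([1.0 + 2.0j, -0.5j])


def loss(z):
    d = z - target
    return jnp.sum(d.real ** 2 + d.imag ** 2)


z = jnp.zeros(2, dtype=jnp.complex64)
state = (jnp.zeros_like(z), jnp.zeros(2, dtype=jnp.float32), 0)
for _ in range(300):
    # Conjugate jax.grad's output (dL/dx - i dL/dy) to get the ascent direction.
    g = jnp.conj(jax.grad(loss)(z))
    z, state = complex_adam_sketch(z, g, state, learning_rate=0.1)
```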

ptyrodactyl.tools.optimizers.complex_rmsprop(params, grads, state, learning_rate=0.001, decay_rate=0.9, eps=1e-08)[source]

Complex-valued RMSprop optimizer based on Wirtinger derivatives.

This function performs one step of the RMSprop optimization algorithm for complex-valued parameters.

Args:

  • params (Complex[Array, "..."]):

    Current complex-valued parameters.

  • grads (Complex[Array, "..."]):

    Complex-valued gradients.

  • state (Complex[Array, "..."]):

    Optimizer state (moving average of squared gradients).

  • learning_rate (float):

    Learning rate (default: 0.001).

  • decay_rate (float):

    Decay rate for moving average (default: 0.9).

  • eps (float):

    Small value to avoid division by zero (default: 1e-8).

Returns:

  • new_params (Complex[Array, "..."]):

    Updated parameters.

  • new_state (Complex[Array, "..."]):

    Updated optimizer state.

Parameters:
Return type:

tuple[Complex[Array, '...'], Complex[Array, '...']]
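A minimal sketch of one complex RMSprop step, assuming the standard RMSprop rule applied to complex arrays: the state is a real exponential moving average of |g|^2, and each complex gradient element is scaled by the inverse square root of that average. `complex_rmsprop_sketch` is a hypothetical name, not the library's code. The gradient is assumed to be the ascent direction dL/dx + i dL/dy; since `jax.grad` of a real loss with respect to a complex input returns dL/dx - i dL/dy, the usage loop conjugates it.

```python
import jax
import jax.numpy as jnp


def complex_rmsprop_sketch(params, grads, state, learning_rate=1e-3,
                           decay_rate=0.9, eps=1e-8):
    """Hypothetical RMSprop step for complex arrays (not ptyrodactyl's code).

    `state` is a real moving average of |grads|**2; `grads` is assumed to be
    the ascent direction dL/dx + i dL/dy.
    """
    new_state = decay_rate * state + (1 - decay_rate) * jnp.abs(grads) ** 2
    new_params = params - learning_rate * grads / (jnp.sqrt(new_state) + eps)
    return new_params, new_state


# Usage: minimise L(z) = sum |z - target|^2 over complex z.
target = jnp.array([1.0 + 2.0j, -0.5j])


def loss(z):
    d = z - target
    return jnp.sum(d.real ** 2 + d.imag ** 2)


z = jnp.zeros(2, dtype=jnp.complex64)
state = jnp.zeros(2, dtype=jnp.float32)
for _ in range(300):
    # Conjugate jax.grad's output (dL/dx - i dL/dy) to get the ascent direction.
    g = jnp.conj(jax.grad(loss)(z))
    z, state = complex_rmsprop_sketch(z, g, state, learning_rate=0.05)
```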

ptyrodactyl.tools.optimizers.create_cosine_scheduler(total_steps, final_lr_factor=0.01)[source]

Creates a cosine learning rate scheduler.

Args:

  • total_steps: Total number of optimization steps.

  • final_lr_factor: Final learning rate as a fraction of the initial learning rate.

Parameters:
  • total_steps (int)

  • final_lr_factor (float)

Return type:

Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
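A sketch of a cosine schedule with the documented call shape (state in, a learning rate and a new state out). The stand-in `LRStateSketch` only mirrors the field names of LRSchedulerState, and `make_cosine_schedule_sketch` is hypothetical. The decay formula (half a cosine period from initial_lr down to initial_lr * final_lr_factor, held constant after total_steps) is a common choice and an assumption here, not necessarily the one ptyrodactyl uses.

```python
from typing import NamedTuple

import jax.numpy as jnp


class LRStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as LRSchedulerState."""
    step: int
    learning_rate: float
    initial_lr: float


def make_cosine_schedule_sketch(total_steps, final_lr_factor=0.01):
    """Cosine decay from initial_lr to initial_lr * final_lr_factor."""

    def schedule(state):
        # Fraction of the schedule completed, capped at 1 after total_steps.
        progress = jnp.minimum(state.step, total_steps) / total_steps
        cosine = 0.5 * (1.0 + jnp.cos(jnp.pi * progress))
        lr = state.initial_lr * (final_lr_factor + (1.0 - final_lr_factor) * cosine)
        # Return the rate for this step and a state advanced by one step.
        return lr, LRStateSketch(state.step + 1, lr, state.initial_lr)

    return schedule


schedule = make_cosine_schedule_sketch(total_steps=100)
state = LRStateSketch(step=0, learning_rate=0.1, initial_lr=0.1)
rates = []
for _ in range(101):
    lr, state = schedule(state)
    rates.append(float(lr))
```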

ptyrodactyl.tools.optimizers.create_step_scheduler(step_size, gamma=0.1)[source]

Creates a step decay scheduler that reduces learning rate by gamma every step_size steps.

Args:

  • step_size: Number of steps between learning rate drops.

  • gamma: Multiplicative factor for learning rate decay.

Parameters:
Return type:

Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
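The step decay described above can be sketched as lr = initial_lr * gamma ** (step // step_size), using the same call shape as the documented return type. `LRStateSketch` is a hypothetical stand-in that only mirrors the fields of LRSchedulerState, and `make_step_schedule_sketch` is likewise hypothetical; the exact boundary convention (whether the first drop happens at step step_size or one step later) is an assumption.

```python
from typing import NamedTuple


class LRStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as LRSchedulerState."""
    step: int
    learning_rate: float
    initial_lr: float


def make_step_schedule_sketch(step_size, gamma=0.1):
    """Multiply the learning rate by gamma once every step_size steps."""

    def schedule(state):
        # Number of completed decay intervals so far.
        drops = state.step // step_size
        lr = state.initial_lr * gamma ** drops
        return lr, LRStateSketch(state.step + 1, lr, state.initial_lr)

    return schedule


schedule = make_step_schedule_sketch(step_size=10, gamma=0.5)
state = LRStateSketch(step=0, learning_rate=1.0, initial_lr=1.0)
rates = []
for _ in range(30):
    lr, state = schedule(state)
    rates.append(lr)
```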

ptyrodactyl.tools.optimizers.create_warmup_cosine_scheduler(total_steps, warmup_steps, final_lr_factor=0.01)[source]

Creates a scheduler with linear warmup followed by cosine decay.

Args:

  • total_steps: Total number of optimization steps.

  • warmup_steps: Number of warmup steps.

  • final_lr_factor: Final learning rate as a fraction of the initial learning rate.

Parameters:
  • total_steps (int)

  • warmup_steps (int)

  • final_lr_factor (float)

Return type:

Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
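A sketch of linear warmup followed by cosine decay, with the documented call shape. Assumptions, all labeled: `LRStateSketch` and `make_warmup_cosine_sketch` are hypothetical; the ramp starts at 0 at step 0 and reaches initial_lr at step warmup_steps (some implementations start at initial_lr / warmup_steps instead); the cosine then runs over the remaining total_steps - warmup_steps steps down to initial_lr * final_lr_factor.

```python
from typing import NamedTuple

import jax.numpy as jnp


class LRStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as LRSchedulerState."""
    step: int
    learning_rate: float
    initial_lr: float


def make_warmup_cosine_sketch(total_steps, warmup_steps, final_lr_factor=0.01):
    decay_steps = max(total_steps - warmup_steps, 1)

    def schedule(state):
        step = state.step
        # Linear ramp from 0 towards initial_lr during warmup.
        warm = state.initial_lr * step / max(warmup_steps, 1)
        # Cosine decay over the remaining steps, held flat after total_steps.
        progress = jnp.clip((step - warmup_steps) / decay_steps, 0.0, 1.0)
        cosine = 0.5 * (1.0 + jnp.cos(jnp.pi * progress))
        decayed = state.initial_lr * (final_lr_factor + (1.0 - final_lr_factor) * cosine)
        lr = jnp.where(step < warmup_steps, warm, decayed)
        return lr, LRStateSketch(step + 1, lr, state.initial_lr)

    return schedule


schedule = make_warmup_cosine_sketch(total_steps=110, warmup_steps=10)
state = LRStateSketch(step=0, learning_rate=0.0, initial_lr=1.0)
rates = []
for _ in range(111):
    lr, state = schedule(state)
    rates.append(float(lr))
```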

ptyrodactyl.tools.optimizers.init_adagrad(shape)[source]
Parameters:

shape (tuple)

Return type:

OptimizerState

ptyrodactyl.tools.optimizers.init_adam(shape)[source]
Parameters:

shape (tuple)

Return type:

OptimizerState

ptyrodactyl.tools.optimizers.init_rmsprop(shape)[source]
Parameters:

shape (tuple)

Return type:

OptimizerState

ptyrodactyl.tools.optimizers.init_scheduler_state(initial_lr)[source]

Initialize scheduler state with the given learning rate.

Parameters:

initial_lr (float)

Return type:

LRSchedulerState

ptyrodactyl.tools.optimizers.rmsprop_update(params, grads, state, learning_rate=0.001, decay_rate=0.9, eps=1e-08)[source]
Parameters:
Return type:

tuple[Complex[Array, '...'], OptimizerState]

ptyrodactyl.tools.optimizers.wirtinger_grad(f, argnums=0)[source]

Compute the Wirtinger gradient of a complex-valued function.

This function returns a new function that computes the Wirtinger gradient of the input function f with respect to the specified argument(s).

Args:

  • f (Callable[..., Float[Array, "..."]]):

    The function to differentiate. It takes complex-valued arguments and, per its annotation, returns a real-valued (Float) array.

  • argnums (Union[int, Sequence[int]]):

    Specifies which argument(s) to compute the gradient with respect to. Can be an int or a sequence of ints. Default is 0.

Returns:

  • grad_f (Callable[..., Union[Complex[Array, "..."], Tuple[Complex[Array, "..."], ...]]]):

    A function that computes the Wirtinger gradient of f with respect to the specified argument(s).

Parameters:
Return type:

Callable[[...], Complex[Array, '...'] | tuple[Complex[Array, '...'], ...]]
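For a real-valued f of z = x + iy, the conjugate Wirtinger derivative is df/dz* = (df/dx + i df/dy) / 2, which points in the direction of steepest ascent. The sketch below builds it from `jax.grad`, which for a real scalar output and complex input returns df/dx - i df/dy, by conjugating and halving. `wirtinger_grad_sketch` is a hypothetical name; the text above does not state whether ptyrodactyl's wirtinger_grad includes the factor 1/2 or which of the two Wirtinger derivatives it returns, so that convention is an assumption. For |z|^2 = z z*, the result should equal z.

```python
import jax
import jax.numpy as jnp


def wirtinger_grad_sketch(f, argnums=0):
    """Hypothetical conjugate Wirtinger derivative df/dz* of a real-valued f.

    With z = x + iy, df/dz* = (df/dx + i df/dy) / 2. For complex inputs and a
    real scalar output, jax.grad returns df/dx - i df/dy, so the result is
    conjugated and halved.
    """
    grad_f = jax.grad(f, argnums=argnums)

    def wrapped(*args):
        g = grad_f(*args)
        # tree_map covers both a single array and the tuple returned
        # when argnums is a sequence.
        return jax.tree_util.tree_map(lambda a: 0.5 * jnp.conj(a), g)

    return wrapped


def abs_sq(z):
    # |z|^2 written without jnp.abs so it is smooth everywhere.
    return jnp.real(z * jnp.conj(z))


# For |z|^2 = z z*, the conjugate Wirtinger derivative is z itself.
print(wirtinger_grad_sketch(abs_sq)(jnp.array(3.0 + 4.0j)))
```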

class ptyrodactyl.tools.parallel.Mesh(devices, axis_names, *, axis_types=None)[source]

Bases: ContextDecorator

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)

If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.

Args:

  • devices: A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names: A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

Examples:

>>> import jax
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
Parameters:
property abstract_mesh
axis_names: tuple[Any, ...]
property axis_sizes: tuple[int, ...]
axis_types: dict[AxisTypes, str | tuple[str, ...]]
property device_ids
devices: ndarray
property empty
property is_multi_process
property local_devices
property local_mesh
property shape
property shape_tuple
property size
class ptyrodactyl.tools.parallel.NamedSharding(*args, **kwargs)

Bases: Sharding

A NamedSharding expresses sharding using named axes.

A NamedSharding is a pair of a Mesh of devices and PartitionSpec which describes how to shard an array across that mesh.

A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.

A PartitionSpec is a tuple, whose elements can be a None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across x axis of the mesh, and the second dimension is sharded across y axis of the mesh.

The Distributed arrays and automatic parallelization (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how Mesh and PartitionSpec are used.

Args:

  • mesh: A jax.sharding.Mesh object.

  • spec: A jax.sharding.PartitionSpec object.

Examples:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device]

The set of devices in the Sharding that are addressable by the current process.

check_compatible_aval(aval_shape)[source]
Parameters:

aval_shape (tuple[int, ...])

Return type:

None

property device_set: set[Device]

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

property is_fully_addressable: bool

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None

Returns the memory kind of the sharding.

property mesh

(self) -> object

property num_devices: int

Number of devices that the sharding contains.

property spec

(self) -> object

with_memory_kind(kind)[source]

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

NamedSharding

with_spec(spec)[source]
Parameters:

spec (PartitionSpec | Sequence[Any])

Return type:

NamedSharding

class ptyrodactyl.tools.parallel.PartitionSpec(*partitions)[source]

Bases: tuple

Tuple describing how to partition an array across a mesh of devices.

Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.

This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.

UNCONSTRAINED = UNCONSTRAINED
ptyrodactyl.tools.parallel.shard_array(input_array, shard_axes, devices=None)[source]

Shards an array across specified axes and devices.

Args:

  • input_array (Array):

    The input array to be sharded.

  • shard_axes (Union[int, Sequence[int]]):

    The axis or axes to shard along. Use -1 (or a sequence of -1s) to avoid sharding along any axis.

  • devices (Sequence[jax.Device], optional):

    The devices to shard across. If None, uses all available devices.

Returns:

  • sharded_array (Array):

    The sharded array.

Parameters:
Return type:

Array
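A minimal sketch of sharding one axis of an array across a 1-D device mesh, built from the jax.sharding classes documented above (Mesh, NamedSharding, PartitionSpec) plus jax.device_put. It is not ptyrodactyl's implementation: `shard_along_axis_sketch` is a hypothetical name and handles only a single integer axis, and, following the shard_array docstring, -1 is treated as "do not shard" rather than as Python's last-axis index. The size of the sharded dimension must be divisible by the number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_along_axis_sketch(x, shard_axis, devices=None):
    """Hypothetical sketch: split one axis of `x` over a 1-D device mesh.

    shard_axis == -1 means "replicate" (no axis is split), mirroring the
    convention described for shard_array.
    """
    devices = jax.devices() if devices is None else list(devices)
    mesh = Mesh(np.array(devices), ("devices",))
    spec = [None] * x.ndim
    if shard_axis != -1:
        # x.shape[shard_axis] must be divisible by len(devices).
        spec[shard_axis] = "devices"
    return jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec)))


x = jnp.arange(16.0).reshape(8, 2)
sharded = shard_along_axis_sketch(x, shard_axis=0)
print(sharded.sharding)
```

On a single CPU device the result is trivially one shard; with several accelerators, each device holds a contiguous slice of axis 0.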