Source code for ptyrodactyl.tools.parallel

from typing import Sequence, Union

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jaxtyping import Array


def shard_array(
    input_array: Array,
    shard_axes: Union[int, Sequence[int]],
    devices: Union[Sequence[jax.Device], None] = None,
) -> Array:
    """
    Shards an array across specified axes and devices.

    A one-dimensional device mesh with the single axis name "devices" is
    built from `devices`, and every selected array axis is partitioned
    over that mesh axis.

    Args:
    - `input_array` (Array):
        The input array to be sharded.
    - `shard_axes` (Union[int, Sequence[int]]):
        The axis or axes to shard along.
        Use -1 or a sequence of -1s to not shard along any axis.
        Because -1 is reserved as the "do not shard" marker, it never
        refers to the last axis; other negative values index from the
        end (e.g. -2 is the second-to-last axis). Axes outside the
        range [-ndim, ndim) are ignored.
    - `devices` (Sequence[jax.Device], optional):
        The devices to shard across.
        If None, uses all available devices (`jax.devices()`).

    Returns:
    - `sharded_array` (Array):
        The sharded array.

    Notes:
    - NOTE(review): all selected axes are mapped to the same mesh axis
      ("devices"). JAX is expected to reject a PartitionSpec that uses
      one mesh axis for more than one array dimension, so passing
      several distinct valid axes likely fails - confirm before relying
      on multi-axis sharding here.
    """
    if devices is None:
        devices = jax.devices()

    # Ensure shard_axes is a sequence
    if isinstance(shard_axes, int):
        shard_axes = [shard_axes]

    # Create a 1-D mesh over the requested devices
    mesh = Mesh(devices, ("devices",))

    # Create PartitionSpec: None means "replicated along this dimension"
    ndim = input_array.ndim
    pspec = [None] * ndim
    for ax in shard_axes:
        # Bounds are checked on both sides so that an out-of-range negative
        # axis is skipped just like an out-of-range positive one, instead of
        # raising IndexError on the list assignment below.
        if ax != -1 and -ndim <= ax < ndim:
            pspec[ax] = "devices"

    # The NamedSharding carries the mesh itself, so device_put needs no
    # surrounding mesh context.
    sharding = NamedSharding(mesh, PartitionSpec(*pspec))
    return jax.device_put(input_array, sharding)