from beartype.typing import Any, Callable, NamedTuple, Sequence, Tuple, Union
import jax
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float
[docs]
class LRSchedulerState(NamedTuple):
    """State maintained by learning rate schedulers.

    The scheduler functions built in this module read ``step`` and
    ``initial_lr`` from this tuple and return a fresh instance on every call,
    with ``step`` incremented by one and ``learning_rate`` set to the rate
    they just computed.
    """

    # Number of scheduler calls made so far.
    step: int
    # Learning rate produced by the most recent scheduler call.
    learning_rate: float
    # Base learning rate that the schedules scale.
    initial_lr: float
# Signature shared by the schedulers built in this module: takes the current
# scheduler state and returns ``(learning_rate, next_state)``.
SchedulerFn = Callable[[LRSchedulerState], tuple[float, LRSchedulerState]]
[docs]
def create_cosine_scheduler(
    total_steps: int,
    final_lr_factor: float = 0.01,
) -> SchedulerFn:
    """
    Creates a cosine learning rate scheduler.

    The learning rate starts at ``state.initial_lr`` and follows half a cosine
    period down to ``final_lr_factor * state.initial_lr`` at ``total_steps``.
    Progress is clamped to 1, so every later step stays at that final value.

    Args:
        total_steps: Total number of optimization steps. Must be positive.
        final_lr_factor: Final learning rate as a fraction of initial

    Returns:
        A jit-compiled function mapping an ``LRSchedulerState`` to
        ``(learning_rate, next_state)``, where ``next_state`` has ``step``
        incremented by one.

    Raises:
        ValueError: If ``total_steps`` is not positive.
    """
    # A non-positive horizon would make ``step / total_steps`` evaluate to
    # NaN (0 / 0) or a negative progress inside the jitted function, which
    # silently corrupts the learning rate instead of failing.
    if total_steps <= 0:
        raise ValueError(f"total_steps must be positive, got {total_steps}")

    @jax.jit
    def scheduler_fn(state: LRSchedulerState) -> tuple[float, LRSchedulerState]:
        # Fraction of the schedule completed, clamped to [.., 1].
        progress = jnp.minimum(state.step / total_steps, 1.0)
        # Goes from 1 (progress == 0) down to 0 (progress == 1).
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * progress))
        lr = state.initial_lr * (final_lr_factor + (1 - final_lr_factor) * cosine_decay)
        new_state = LRSchedulerState(
            step=state.step + 1, learning_rate=lr, initial_lr=state.initial_lr
        )
        return lr, new_state

    return scheduler_fn
[docs]
def create_step_scheduler(step_size: int, gamma: float = 0.1) -> SchedulerFn:
    """
    Creates a step decay scheduler that reduces learning rate by gamma every step_size steps.

    The rate returned for a state is
    ``state.initial_lr * gamma ** (state.step // step_size)``.

    Args:
        step_size: Number of steps between learning rate drops. Must be
            positive.
        gamma: Multiplicative factor for learning rate decay

    Returns:
        A jit-compiled function mapping an ``LRSchedulerState`` to
        ``(learning_rate, next_state)``, where ``next_state`` has ``step``
        incremented by one.

    Raises:
        ValueError: If ``step_size`` is not positive.
    """
    # ``step // step_size`` runs inside a jitted function, where an integer
    # division by zero does not raise; reject the bad interval up front.
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")

    @jax.jit
    def scheduler_fn(state: LRSchedulerState) -> tuple[float, LRSchedulerState]:
        # Number of completed decay intervals so far.
        num_drops = state.step // step_size
        lr = state.initial_lr * (gamma**num_drops)
        new_state = LRSchedulerState(
            step=state.step + 1, learning_rate=lr, initial_lr=state.initial_lr
        )
        return lr, new_state

    return scheduler_fn
[docs]
def create_warmup_cosine_scheduler(
    total_steps: int,
    warmup_steps: int,
    final_lr_factor: float = 0.01,
) -> SchedulerFn:
    """
    Creates a scheduler with linear warmup followed by cosine decay.

    While ``state.step < warmup_steps`` the rate ramps linearly from 0 toward
    ``state.initial_lr``. From ``warmup_steps`` onward it follows half a
    cosine period from ``state.initial_lr`` down to
    ``final_lr_factor * state.initial_lr`` at ``total_steps`` and stays there.

    Args:
        total_steps: Total number of optimization steps
        warmup_steps: Number of warmup steps
        final_lr_factor: Final learning rate as a fraction of initial

    Returns:
        A jit-compiled function mapping an ``LRSchedulerState`` to
        ``(learning_rate, next_state)``, where ``next_state`` has ``step``
        incremented by one.
    """
    # Clamp both denominators to at least 1. Without this, ``warmup_steps == 0``
    # evaluates 0 / 0 in the warmup branch, and ``total_steps == warmup_steps``
    # evaluates 0 / 0 in the decay branch at ``step == warmup_steps``, which
    # made the returned learning rate NaN for that step. Well-formed
    # configurations (warmup_steps >= 1, total_steps > warmup_steps) are
    # unaffected.
    warmup_span = max(warmup_steps, 1)
    decay_span = max(total_steps - warmup_steps, 1)

    @jax.jit
    def scheduler_fn(state: LRSchedulerState) -> tuple[float, LRSchedulerState]:
        # Linear warmup
        warmup_progress = jnp.minimum(state.step / warmup_span, 1.0)
        warmup_lr = state.initial_lr * warmup_progress
        # Cosine decay after warmup
        decay_progress = jnp.maximum(0.0, state.step - warmup_steps) / decay_span
        decay_progress = jnp.minimum(decay_progress, 1.0)
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * decay_progress))
        decay_lr = state.initial_lr * (
            final_lr_factor + (1 - final_lr_factor) * cosine_decay
        )
        # Choose between warmup and decay; both branches are always computed,
        # which is why neither may contain a NaN-producing division.
        lr = jnp.where(state.step < warmup_steps, warmup_lr, decay_lr)
        new_state = LRSchedulerState(
            step=state.step + 1, learning_rate=lr, initial_lr=state.initial_lr
        )
        return lr, new_state

    return scheduler_fn
[docs]
def init_scheduler_state(initial_lr: float) -> LRSchedulerState:
    """Build the starting scheduler state for a base learning rate.

    The step counter starts at zero, and both the current and the initial
    learning rate are set to ``initial_lr``.
    """
    starting_step = 0
    return LRSchedulerState(starting_step, initial_lr, initial_lr)
[docs]
class OptimizerState(NamedTuple):
    """State carried between calls to the ``*_update`` optimizer functions.

    ``m`` is optional because not every optimizer in this module keeps a first
    moment: ``init_rmsprop``, ``adagrad_update`` and ``rmsprop_update`` store
    ``None`` in that slot.
    """

    m: Union[Array, None]  # First moment estimate (None when the optimizer keeps none)
    v: Array  # Second moment estimate
    step: Array  # Step count
[docs]
class Optimizer(NamedTuple):
    """Pair of callables bundled under the names ``init`` and ``update``.

    NOTE(review): nothing in the visible part of this file constructs an
    ``Optimizer``. Presumably ``init`` is one of the ``init_*`` functions and
    ``update`` the matching ``*_update`` function — confirm against callers.
    """

    # Presumably creates the optimizer's initial state — TODO confirm.
    init: Callable
    # Presumably performs one optimization step — TODO confirm.
    update: Callable
[docs]
def wirtinger_grad(
    f: Callable[..., Float[Array, "..."]], argnums: Union[int, Sequence[int]] = 0
) -> Callable[..., Union[Complex[Array, "..."], Tuple[Complex[Array, "..."], ...]]]:
    """
    Compute the Wirtinger gradient of a function of complex arguments.

    This function returns a new function that computes the conjugate
    Wirtinger derivative of the input function f with respect to the
    specified argument(s). Writing a complex argument as ``z = x + 1j * y``,
    the returned value is

        df/d(conj z) = 0.5 * (df/dx + 1j * df/dy)

    For a real-valued f this points along the direction of steepest ascent,
    so ``params - learning_rate * grad`` is a descent step. For a holomorphic
    f it is zero. For an argument that is not complex there is no imaginary
    part to differentiate, and the ordinary derivative ``df/dx`` is returned.

    Args:
    - f (Callable[..., Float[Array, "..."]]):
        Scalar-valued function to differentiate. Its output may be real or
        complex; real and imaginary parts are differentiated separately.
    - argnums (Union[int, Sequence[int]]):
        Specifies which argument(s) to compute the gradient with respect to.
        Can be an int or a sequence of ints. Negative values count from the
        last argument. Default is 0.

    Returns:
    - grad_f (Callable[..., Union[Complex[Array, "..."], Tuple[Complex[Array, "..."], ...]]]):
        A function that computes the Wirtinger gradient of f with respect to
        the specified argument(s). It returns a single array when argnums is
        an int and a tuple of arrays (one per entry) otherwise.
    """

    def grad_f(
        *args: Any,
    ) -> Union[Complex[Array, "..."], Tuple[Complex[Array, "..."], ...]]:
        n = len(args)
        is_complex = tuple(bool(jnp.iscomplexobj(arg)) for arg in args)

        # Represent every argument by a real part and an imaginary part so
        # that jax.grad only ever differentiates real inputs. The layout is
        # (re_0, ..., re_{n-1}, im_0, ..., im_{n-1}).
        real_parts = tuple(
            jnp.real(arg) if cplx else arg for arg, cplx in zip(args, is_complex)
        )
        imag_parts = tuple(
            jnp.imag(arg) if cplx else jnp.zeros_like(arg)
            for arg, cplx in zip(args, is_complex)
        )
        split_args = real_parts + imag_parts

        def combine_complex(r, i):
            # Rebuild the original arguments; real arguments ignore their
            # (all-zero) imaginary slot.
            return tuple(
                rr + 1j * ii if cplx else rr
                for rr, ii, cplx in zip(r, i, is_complex)
            )

        def f_real(*split):
            return jnp.real(f(*combine_complex(split[:n], split[n:])))

        def f_imag(*split):
            return jnp.imag(f(*combine_complex(split[:n], split[n:])))

        single = isinstance(argnums, int)
        requested = (argnums,) if single else tuple(argnums)
        # Map negative indices onto the original argument list; left as-is
        # they would select imaginary slots of the doubled argument list.
        indices = tuple(idx + n if idx < 0 else idx for idx in requested)
        count = len(indices)

        # Differentiate with respect to BOTH the real slot and the imaginary
        # slot of every requested argument. Taking only the real slot drops
        # df/dy, so the imaginary part of a parameter would never be updated.
        slots = indices + tuple(idx + n for idx in indices)
        d_re = jax.grad(f_real, argnums=slots)(*split_args)
        d_im = jax.grad(f_imag, argnums=slots)(*split_args)

        results = []
        for pos, idx in enumerate(indices):
            df_dx = d_re[pos] + 1j * d_im[pos]
            if is_complex[idx]:
                df_dy = d_re[pos + count] + 1j * d_im[pos + count]
                results.append(0.5 * (df_dx + 1j * df_dy))
            else:
                results.append(df_dx)

        if single:
            return results[0]
        return tuple(results)

    return grad_f
[docs]
def complex_adam(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: Tuple[Complex[Array, "..."], Complex[Array, "..."], int],
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[
    Complex[Array, "..."], Tuple[Complex[Array, "..."], Complex[Array, "..."], int]
]:
    """
    Perform one Adam step on complex-valued parameters.

    The first moment tracks the complex gradient itself, while the second
    moment tracks its squared magnitude ``|grads| ** 2``. Both are
    bias-corrected with the incremented timestep before the update is formed.

    Args:
    - params (Complex[Array, "..."]):
        Current complex-valued parameters.
    - grads (Complex[Array, "..."]):
        Complex-valued gradients.
    - state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):
        Optimizer state as (first moment, second moment, timestep).
    - learning_rate (float):
        Step size (default: 0.001).
    - beta1 (float):
        Decay rate of the first-moment average (default: 0.9).
    - beta2 (float):
        Decay rate of the second-moment average (default: 0.999).
    - eps (float):
        Added to the denominator to avoid division by zero (default: 1e-8).

    Returns:
    - new_params (Complex[Array, "..."]):
        Parameters after the step.
    - new_state (Tuple[Complex[Array, "..."], Complex[Array, "..."], int]):
        Updated (first moment, second moment, timestep).
    """
    first_moment, second_moment, timestep = state
    timestep += 1

    # Exponential moving averages of the gradient and its squared magnitude.
    first_moment = beta1 * first_moment + (1 - beta1) * grads
    second_moment = beta2 * second_moment + (1 - beta2) * jnp.abs(grads) ** 2

    # Bias correction compensates for the zero-initialised averages.
    first_unbiased = first_moment / (1 - beta1**timestep)
    second_unbiased = second_moment / (1 - beta2**timestep)

    denominator = jnp.sqrt(second_unbiased) + eps
    stepped_params = params - learning_rate * first_unbiased / denominator
    return stepped_params, (first_moment, second_moment, timestep)
[docs]
def complex_adagrad(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: Complex[Array, "..."],
    learning_rate: float = 0.01,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, "..."], Complex[Array, "..."]]:
    """
    Perform one Adagrad step on complex-valued parameters.

    The state accumulates the squared gradient magnitude ``|grads| ** 2``,
    and the gradient is divided by the square root of that running sum.

    Args:
    - params (Complex[Array, "..."]):
        Current complex-valued parameters.
    - grads (Complex[Array, "..."]):
        Complex-valued gradients.
    - state (Complex[Array, "..."]):
        Running sum of squared gradient magnitudes.
    - learning_rate (float):
        Step size (default: 0.01).
    - eps (float):
        Added to the denominator to avoid division by zero (default: 1e-8).

    Returns:
    - new_params (Complex[Array, "..."]): Parameters after the step.
    - new_state (Complex[Array, "..."]): Updated running sum.
    """
    # Fold this step's squared magnitude into the running sum.
    squared_sum = state + jnp.abs(grads) ** 2
    # Scale the gradient by the accumulated magnitude.
    scale = jnp.sqrt(squared_sum) + eps
    return params - learning_rate * grads / scale, squared_sum
[docs]
def complex_rmsprop(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: Complex[Array, "..."],
    learning_rate: float = 0.001,
    decay_rate: float = 0.9,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, "..."], Complex[Array, "..."]]:
    """
    Perform one RMSprop step on complex-valued parameters.

    The state is an exponential moving average of the squared gradient
    magnitude ``|grads| ** 2``, and the gradient is divided by the square
    root of that average.

    Args:
    - params (Complex[Array, "..."]):
        Current complex-valued parameters.
    - grads (Complex[Array, "..."]):
        Complex-valued gradients.
    - state (Complex[Array, "..."]):
        Moving average of squared gradient magnitudes.
    - learning_rate (float):
        Step size (default: 0.001).
    - decay_rate (float):
        Decay rate of the moving average (default: 0.9).
    - eps (float):
        Added to the denominator to avoid division by zero (default: 1e-8).

    Returns:
    - new_params (Complex[Array, "..."]): Parameters after the step.
    - new_state (Complex[Array, "..."]): Updated moving average.
    """
    # Blend the previous average with this step's squared magnitude.
    squared_avg = decay_rate * state + (1 - decay_rate) * jnp.abs(grads) ** 2
    # Scale the gradient by the root of the moving average.
    scale = jnp.sqrt(squared_avg) + eps
    return params - learning_rate * grads / scale, squared_avg
[docs]
def init_adam(shape: tuple) -> OptimizerState:
    """Create the initial Adam state for parameters of the given shape.

    Both moment estimates start as zero arrays of ``shape`` and the step
    counter starts as a zero scalar array.
    """
    first_moment = jnp.zeros(shape)
    second_moment = jnp.zeros(shape)
    return OptimizerState(m=first_moment, v=second_moment, step=jnp.array(0))
[docs]
def init_adagrad(shape: tuple) -> OptimizerState:
    """Create the initial Adagrad state for parameters of the given shape.

    Both array slots start as zero arrays of ``shape`` and the step counter
    starts as a zero scalar array. ``adagrad_update`` only reads the ``v``
    and ``step`` slots.
    """
    unused_first_moment = jnp.zeros(shape)
    squared_sum = jnp.zeros(shape)
    return OptimizerState(m=unused_first_moment, v=squared_sum, step=jnp.array(0))
[docs]
def init_rmsprop(shape: tuple) -> OptimizerState:
    """Create the initial RMSprop state for parameters of the given shape.

    RMSprop keeps no first moment, so ``m`` is ``None``. The moving average
    ``v`` starts as a zero array of ``shape`` and the step counter starts as
    a zero scalar array.
    """
    squared_avg = jnp.zeros(shape)
    return OptimizerState(m=None, v=squared_avg, step=jnp.array(0))
[docs]
def adam_update(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: OptimizerState,
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> tuple[Complex[Array, "..."], OptimizerState]:
    """Perform one Adam step using an ``OptimizerState``.

    Args:
        params: Current complex-valued parameters.
        grads: Complex-valued gradients.
        state: Optimizer state holding the first moment ``m``, the second
            moment ``v`` and the step count.
        learning_rate: Step size.
        beta1: Decay rate of the first-moment average.
        beta2: Decay rate of the second-moment average.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        The updated parameters and a new ``OptimizerState`` with both
        moments refreshed and the step count incremented by one.
    """
    first_moment, second_moment, step_count = state
    step_count += 1

    # Exponential moving averages of the gradient and its squared magnitude.
    first_moment = beta1 * first_moment + (1 - beta1) * grads
    second_moment = beta2 * second_moment + (1 - beta2) * jnp.abs(grads) ** 2

    # Bias correction compensates for the zero-initialised averages.
    first_unbiased = first_moment / (1 - beta1**step_count)
    second_unbiased = second_moment / (1 - beta2**step_count)

    denominator = jnp.sqrt(second_unbiased) + eps
    stepped_params = params - learning_rate * first_unbiased / denominator
    return stepped_params, OptimizerState(first_moment, second_moment, step_count)
[docs]
def adagrad_update(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: OptimizerState,
    learning_rate: float = 0.01,
    eps: float = 1e-8,
) -> tuple[Complex[Array, "..."], OptimizerState]:
    """Perform one Adagrad step using an ``OptimizerState``.

    Only ``state.v`` (running sum of squared gradient magnitudes) and
    ``state.step`` are used by the algorithm. ``state.m`` is passed through
    unchanged so the returned state has the same pytree structure as the
    incoming one. ``init_adagrad`` fills ``m`` with a zeros array, and
    replacing it with ``None`` would change the structure between steps,
    which ``jax.lax.scan`` / ``fori_loop`` carries reject and which forces a
    jit retrace.

    Args:
        params: Current complex-valued parameters.
        grads: Complex-valued gradients.
        state: Optimizer state; ``m`` is carried along untouched.
        learning_rate: Step size.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        The updated parameters and a new ``OptimizerState`` with ``v``
        accumulated and the step count incremented by one.
    """
    passthrough_m, squared_sum, step = state
    # Out-of-place arithmetic: ``+=`` would mutate a caller-owned NumPy array.
    step = step + 1
    squared_sum = squared_sum + jnp.abs(grads) ** 2
    update = learning_rate * grads / (jnp.sqrt(squared_sum) + eps)
    new_params = params - update
    return new_params, OptimizerState(passthrough_m, squared_sum, step)
[docs]
def rmsprop_update(
    params: Complex[Array, "..."],
    grads: Complex[Array, "..."],
    state: OptimizerState,
    learning_rate: float = 0.001,
    decay_rate: float = 0.9,
    eps: float = 1e-8,
) -> tuple[Complex[Array, "..."], OptimizerState]:
    """Perform one RMSprop step using an ``OptimizerState``.

    Args:
        params: Current complex-valued parameters.
        grads: Complex-valued gradients.
        state: Optimizer state; only ``v`` (moving average of squared
            gradient magnitudes) and ``step`` are read.
        learning_rate: Step size.
        decay_rate: Decay rate of the moving average.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        The updated parameters and a new ``OptimizerState`` whose ``m`` is
        ``None``, whose ``v`` is the refreshed moving average and whose step
        count is incremented by one.
    """
    _, squared_avg, step_count = state
    step_count += 1
    # Blend the previous average with this step's squared magnitude.
    squared_avg = decay_rate * squared_avg + (1 - decay_rate) * jnp.abs(grads) ** 2
    scale = jnp.sqrt(squared_avg) + eps
    stepped_params = params - learning_rate * grads / scale
    return stepped_params, OptimizerState(None, squared_avg, step_count)