import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import NamedTuple, Optional, Tuple
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, Bool, Complex, Float, jaxtyped
import ptyrodactyl.optics as pto
jax.config.update("jax_enable_x64", True)
[docs]
@register_pytree_node_class
class LensParams(NamedTuple):
"""
Description
-----------
PyTree structure for lens parameters
Attributes
----------
- `focal_length` (Float[Array, ""]):
Focal length of the lens in meters
- `diameter` (Float[Array, ""]):
Diameter of the lens in meters
- `n` (Float[Array, ""]):
Refractive index of the lens material
- `center_thickness` (Float[Array, ""]):
Thickness at the center of the lens in meters
- `R1` (Float[Array, ""]):
Radius of curvature of the first surface in meters (positive for convex)
- `R2` (Float[Array, ""]):
Radius of curvature of the second surface in meters (positive for convex)
Notes
-----
This class is registered as a PyTree node, making it compatible with JAX transformations
like jit, grad, and vmap. The auxiliary data in tree_flatten is None as all relevant
data is stored in JAX arrays.
"""
focal_length: Float[Array, ""]
diameter: Float[Array, ""]
n: Float[Array, ""]
center_thickness: Float[Array, ""]
R1: Float[Array, ""]
R2: Float[Array, ""]
[docs]
def tree_flatten(self):
# Return a tuple of arrays (the children) and None (the auxiliary data)
return (
(
self.focal_length,
self.diameter,
self.n,
self.center_thickness,
self.R1,
self.R2,
),
None,
)
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children):
# Reconstruct the NamedTuple from flattened data
return cls(*children)
[docs]
@jaxtyped(typechecker=beartype)
def lens_thickness_profile(
r: Float[Array, "H W"],
R1: Float[Array, ""],
R2: Float[Array, ""],
center_thickness: Float[Array, ""],
diameter: Float[Array, ""],
) -> Float[Array, "H W"]:
"""
Description
-----------
Calculate the thickness profile of a lens.
Parameters
----------
- `r` (Float[Array, "H W"]):
Radial distance from the optical axis
- `R1` (Float[Array, ""]):
Radius of curvature of the first surface
- `R2` (Float[Array, ""]):
Radius of curvature of the second surface
- `center_thickness` (Float[Array, ""]):
Thickness at the center of the lens
- `diameter` (Float[Array, ""]):
Diameter of the lens
Returns
-------
- `thickness` (Float[Array, "H W"]):
Thickness profile of the lens
Flow
----
- Calculate surface sag for both surfaces
- Combine sags with center thickness
- Apply aperture mask
- Return thickness profile
"""
# Calculate surface sags
sag1: Float[Array, "H W"] = jnp.where(
r <= diameter / 2, R1 - jnp.sqrt(jnp.maximum(R1**2 - r**2, 0.0)), 0.0
)
sag2: Float[Array, "H W"] = jnp.where(
r <= diameter / 2, R2 - jnp.sqrt(jnp.maximum(R2**2 - r**2, 0.0)), 0.0
)
# Calculate total thickness profile
thickness: Float[Array, "H W"] = jnp.where(
r <= diameter / 2, center_thickness + sag1 - sag2, 0.0
)
return thickness
[docs]
@jaxtyped(typechecker=beartype)
def lens_focal_length(
n: Float[Array, ""],
R1: Float[Array, ""],
R2: Float[Array, ""],
) -> Float[Array, ""]:
"""
Description
-----------
Calculate the focal length of a lens using the lensmaker's equation.
Parameters
----------
- `n` (Float[Array, ""]):
Refractive index of the lens material
- `R1` (Float[Array, ""]):
Radius of curvature of the first surface (positive for convex)
- `R2` (Float[Array, ""]):
Radius of curvature of the second surface (positive for convex)
Returns
-------
- `f` (Float[Array, ""]):
Focal length of the lens
Flow
----
- Apply the lensmaker's equation
- Return the calculated focal length
"""
f: Float[Array, ""] = 1.0 / ((n - 1.0) * (1.0 / R1 - 1.0 / R2))
return f
[docs]
@jaxtyped(typechecker=beartype)
def create_lens_phase(
X: Float[Array, "H W"],
Y: Float[Array, "H W"],
params: LensParams,
wavelength: Float[Array, ""],
) -> Tuple[Float[Array, "H W"], Float[Array, "H W"]]:
"""
Description
-----------
Create the phase profile and transmission mask for a lens.
Parameters
----------
- `X` (Float[Array, "H W"]):
X coordinates grid
- `Y` (Float[Array, "H W"]):
Y coordinates grid
- `params` (LensParams):
Lens parameters
- `wavelength` (Float[Array, ""]):
Wavelength of light
Returns
-------
- `phase_profile` (Float[Array, "H W"]):
Phase profile of the lens
- `transmission` (Float[Array, "H W"]):
Transmission mask of the lens
Flow
----
- Calculate radial coordinates
- Calculate thickness profile
- Calculate phase profile
- Create transmission mask
- Return phase and transmission
"""
# Calculate radial coordinates
r: Float[Array, "H W"] = jnp.sqrt(X**2 + Y**2)
# Calculate thickness profile
thickness: Float[Array, "H W"] = pto.calculate_thickness_profile(
r, params.R1, params.R2, params.center_thickness, params.diameter
)
# Calculate phase profile
k: Float[Array, ""] = 2 * jnp.pi / wavelength
phase_profile: Float[Array, "H W"] = k * (params.n - 1) * thickness
# Create transmission mask
transmission: Float[Array, "H W"] = (r <= params.diameter / 2).astype(float)
return (phase_profile, transmission)
[docs]
@jaxtyped(typechecker=beartype)
def propagate_through_lens(
field: Complex[Array, "H W"],
phase_profile: Float[Array, "H W"],
transmission: Float[Array, "H W"],
) -> Complex[Array, "H W"]:
"""
Description
-----------
Propagate a field through a lens.
Parameters
----------
- `field` (Complex[Array, "H W"]):
Input complex field
- `phase_profile` (Float[Array, "H W"]):
Phase profile of the lens
- `transmission` (Float[Array, "H W"]):
Transmission mask of the lens
Returns
-------
- `output_field` (Complex[Array, "H W"]):
Field after passing through the lens
Flow
----
- Apply transmission mask
- Add phase profile
- Return modified field
"""
output_field: Complex[Array, "H W"] = pto.add_phase_screen(
field * transmission, phase_profile
)
return output_field
[docs]
@jaxtyped(typechecker=beartype)
def double_convex_lens(
focal_length: Float[Array, ""],
diameter: Float[Array, ""],
n: Float[Array, ""],
center_thickness: Float[Array, ""],
R_ratio: Optional[Float[Array, ""]] = jnp.array(1.0),
) -> LensParams:
"""
Description
-----------
Create parameters for a double convex lens.
Parameters
----------
- `focal_length` (Float[Array, ""]):
Desired focal length
- `diameter` (Float[Array, ""]):
Lens diameter
- `n` (Float[Array, ""]):
Refractive index
- `center_thickness` (Float[Array, ""]):
Center thickness
- `R_ratio` (Optional[Float[Array, ""]]):
Ratio of R2/R1.
default is 1.0 for symmetric lens
Returns
-------
- `params` (LensParams):
Lens parameters
Flow
----
- Calculate R1 using lensmaker's equation
- Calculate R2 using R_ratio
- Create and return LensParams
"""
# For a double convex lens, both R1 and R2 are positive
R1: Float[Array, ""] = focal_length * (n - 1) * (1 + R_ratio) / 2
R2: Float[Array, ""] = R1 * R_ratio
return LensParams(
focal_length=focal_length,
diameter=diameter,
n=n,
center_thickness=center_thickness,
R1=R1,
R2=R2,
)
[docs]
def double_concave_lens(
focal_length: Float[Array, ""],
diameter: Float[Array, ""],
n: Float[Array, ""],
center_thickness: Float[Array, ""],
R_ratio: Optional[Float[Array, ""]] = jnp.array(1.0),
) -> LensParams:
"""
Description
-----------
Create parameters for a double concave lens.
Parameters
----------
- `focal_length` (Float[Array, ""]):
Desired focal length
- `diameter` (Float[Array, ""]):
Lens diameter
- `n` (Float[Array, ""]):
Refractive index
- `center_thickness` (Float[Array, ""]):
Center thickness
- `R_ratio` (Optional[Float[Array, ""]]):
Ratio of R2/R1.
default is 1.0 for symmetric lens
Returns
-------
- `params` (LensParams):
Lens parameters
Flow
----
- Calculate R1 using lensmaker's equation
- Calculate R2 using R_ratio
- Create and return LensParams
"""
# For a double concave lens, both R1 and R2 are negative
R1: Float[Array, ""] = focal_length * (n - 1) * (1 + R_ratio) / 2
R2: Float[Array, ""] = R1 * R_ratio
return LensParams(
focal_length=focal_length,
diameter=diameter,
n=n,
center_thickness=center_thickness,
R1=-abs(R1), # Ensure negative
R2=-abs(R2), # Ensure negative
)
[docs]
@jaxtyped(typechecker=beartype)
def plano_convex_lens(
focal_length: Float[Array, ""],
diameter: Float[Array, ""],
n: Float[Array, ""],
center_thickness: Float[Array, ""],
convex_first: Optional[Bool[Array, ""]] = jnp.array(True),
) -> LensParams:
"""
Description
-----------
Create parameters for a plano-convex lens.
Parameters
----------
- `focal_length` (Float[Array, ""]):
Desired focal length
- `diameter` (Float[Array, ""]):
Lens diameter
- `n` (Float[Array, ""]):
Refractive index
- `center_thickness` (Float[Array, ""]):
Center thickness
- `R_ratio` (Optional[Float[Array, ""]]):
Ratio of R2/R1.
default is 1.0 for symmetric lens
- `convex_first` (Optional[Bool[Array, ""]]):
If True, first surface is convex.
Default: True
Returns
-------
- `params` (LensParams):
Lens parameters
Flow
----
- Calculate R for curved surface
- Set other R to infinity (flat surface)
- Create and return LensParams
"""
R: Float[Array, ""] = focal_length * (n - 1)
# Assign R to first or second surface based on convex_first
R1: Float[Array, ""] = jnp.where(convex_first, R, jnp.inf)
R2: Float[Array, ""] = jnp.where(convex_first, jnp.inf, R)
return LensParams(
focal_length=focal_length,
diameter=diameter,
n=n,
center_thickness=center_thickness,
R1=R1,
R2=R2,
)
[docs]
@jaxtyped(typechecker=beartype)
def plano_concave_lens(
focal_length: Float[Array, ""],
diameter: Float[Array, ""],
n: Float[Array, ""],
center_thickness: Float[Array, ""],
concave_first: Optional[Bool[Array, ""]] = jnp.array(True),
) -> LensParams:
"""
Description
-----------
Create parameters for a plano-concave lens.
Parameters
----------
- `focal_length` (Float[Array, ""]):
Desired focal length
- `diameter` (Float[Array, ""]):
Lens diameter
- `n` (Float[Array, ""]):
Refractive index
- `center_thickness` (Float[Array, ""]):
Center thickness
- `R_ratio` (Optional[Float[Array, ""]]):
Ratio of R2/R1.
default is 1.0 for symmetric lens
- `concave_first` (Optional[Bool[Array, ""]]):
If True, first surface is concave (default: True)
Returns
-------
- `params` (LensParams):
Lens parameters
Flow
----
- Calculate R for curved surface
- Set other R to infinity (flat surface)
- Create and return LensParams
"""
R: Float[Array, ""] = -abs(focal_length * (n - 1))
# Assign R to first or second surface based on concave_first
R1: Float[Array, ""] = jnp.where(concave_first, R, jnp.inf)
R2: Float[Array, ""] = jnp.where(concave_first, jnp.inf, R)
return LensParams(
focal_length=focal_length,
diameter=diameter,
n=n,
center_thickness=center_thickness,
R1=R1,
R2=R2,
)
[docs]
@jaxtyped(typechecker=beartype)
def meniscus_lens(
focal_length: Float[Array, ""],
diameter: Float[Array, ""],
n: Float[Array, ""],
center_thickness: Float[Array, ""],
R_ratio: Float[Array, ""],
convex_first: Optional[Bool[Array, ""]] = jnp.array(True),
) -> LensParams:
"""
Description
-----------
Create parameters for a meniscus (concavo-convex) lens.
For a meniscus lens, one surface is convex (positive R)
and one is concave (negative R).
Parameters
----------
- `focal_length` (Float[Array, ""]):
Desired focal length in meters
- `diameter` (Float[Array, ""]):
Lens diameter in meters
- `n` (Float[Array, ""]):
Refractive index of lens material
- `center_thickness` (Float[Array, ""]):
Center thickness in meters
- `R_ratio` (Float[Array, ""]):
Absolute ratio of R2/R1
- `convex_first` (Bool[Array, ""]):
If True, first surface is convex (default: True)
Returns
-------
- `params` (LensParams):
Lens parameters
Flow
----
- Calculate magnitude of R1 using lensmaker's equation
- Calculate R2 magnitude using R_ratio
- Assign correct signs based on convex_first
- Create and return LensParams
"""
# Calculate absolute values of radii
# Using lensmaker's equation: 1/f = (n-1)(1/R1 - 1/R2)
R1_mag: Float[Array, ""] = (
focal_length * (n - 1) * (1 - R_ratio) / (1 if convex_first else -1)
)
R2_mag: Float[Array, ""] = abs(R1_mag * R_ratio)
# Assign signs based on which surface is convex
R1: Float[Array, ""] = jnp.where(
convex_first,
abs(R1_mag), # Convex first surface (positive)
-abs(R1_mag), # Concave first surface (negative)
)
R2: Float[Array, ""] = jnp.where(
convex_first,
-abs(R2_mag), # Concave second surface (negative)
abs(R2_mag), # Convex second surface (positive)
)
return LensParams(
focal_length=focal_length,
diameter=diameter,
n=n,
center_thickness=center_thickness,
R1=R1,
R2=R2,
)