# Source code for jaxtyping._decorator

# Copyright (c) 2022 Google LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import dataclasses
import functools as ft
import importlib.util
import inspect
import itertools as it
import sys
import warnings
import weakref
from collections.abc import Callable
from typing import (
    Any,
    get_args,
    get_origin,
    get_type_hints,
    NoReturn,
    overload,
    ParamSpec,
    TypeVar,
    Union,
)

from jaxtyping import AbstractArray

from ._config import config
from ._errors import AnnotationError, TypeCheckError
from ._storage import pop_shape_memo, push_shape_memo, shape_str


# Type variables used only by the `@overload` signatures of `jaxtyped` below.
# `_Params`/`_Return` carry a decorated callable's parameters and return type
# through the decorator; `_T` carries the decorated class.
_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
_T = TypeVar("_T")
# Not `TypeVar(..., type, Callable)` as else the output type of our first overload is
# just `type`, and not the particular class that is decorated.
_TypeOrCallable = TypeVar("_TypeOrCallable", bound=Union[type, Callable])


class _Sentinel:
    """Marker type whose instances print as `sentinel`.

    An instance is used as a default value meaning "argument not supplied", which
    keeps `None` available as a meaningful value.
    """

    def __repr__(self) -> str:
        label = "sentinel"
        return label


# Default for `jaxtyped`'s `fn` and `typechecker` arguments: means "not supplied".
_sentinel = _Sentinel()
# True until `jaxtyped` has registered this file with JAX's traceback exclusion
# list; `jaxtyped` sets it to False afterwards.
_tb_flag = True


def _apply_typechecker(typechecker, fn):
    """Calls `typechecker(fn)` in an isolated frame, returning the result.

    This avoids reference cycles that can otherwise occur if `typechecker` grabs
    the calling frame's locals.

    **Arguments:**

    - `typechecker`: a decorator, called with `fn` as its only argument.
    - `fn`: the function to decorate.

    **Returns:**

    Whatever `typechecker(fn)` returns.
    """
    # Deliberately nothing else in this frame: see the docstring.
    return typechecker(fn)


# Typing-only signature: `jaxtyped(typechecker=...)` with no `fn`, which returns a
# decorator that preserves the type of whatever it is applied to.
@overload
def jaxtyped(
    *,
    typechecker=_sentinel,
) -> Callable[[_TypeOrCallable], _TypeOrCallable]: ...


# Typing-only signature: decorating a class gives back that same class type.
@overload
def jaxtyped(fn: type[_T], *, typechecker=_sentinel) -> type[_T]: ...


# Typing-only signature: decorating a callable preserves its parameters and return
# type.
@overload
def jaxtyped(
    fn: Callable[_Params, _Return], *, typechecker=_sentinel
) -> Callable[_Params, _Return]: ...


def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
    """Decorate a function with this to perform runtime type-checking of its arguments
    and return value. Decorate a dataclass to perform type-checking of its attributes.

    !!! Example

        ```python
        # Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
        from jaxtyping import Array, Float, jaxtyped

        # Use your favourite typechecker: usually one of the two lines below.
        from typeguard import typechecked as typechecker
        from beartype import beartype as typechecker

        # Type-check a function
        @jaxtyped(typechecker=typechecker)
        def batch_outer_product(x: Float[Array, "b c1"],
                                y: Float[Array, "b c2"]
                              ) -> Float[Array, "b c1 c2"]:
            return x[:, :, None] * y[:, None, :]

        # Type-check a dataclass
        from dataclasses import dataclass

        @jaxtyped(typechecker=typechecker)
        @dataclass
        class MyDataclass:
            x: int
            y: Float[Array, "b c"]
        ```

    **Arguments:**

    - `fn`: The function or dataclass to decorate. In practice if you want to use
        dataclasses with JAX, then
        [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) is our
        recommended approach:
        ```python
        import equinox as eqx

        @jaxtyped(typechecker=typechecker)
        class MyModule(eqx.Module):
            ...
        ```

    - `typechecker`: Keyword-only argument: the runtime type-checker to use. This should
        be a function decorator that will raise an exception if there is a type error,
        e.g.
        ```python
        @typechecker
        def f(x: int):
            pass

        f("a string is not an integer")  # this line should raise an exception
        ```
        Common choices are `typechecker=beartype.beartype` or
        `typechecker=typeguard.typechecked`. Can also be set as `typechecker=None` to
        skip automatic runtime type-checking, but still support manual `isinstance`
        checks inside the function body:
        ```python
        @jaxtyped(typechecker=None)
        def f(x):
            assert isinstance(x, Float[Array, "batch channel"])
        ```

    **Returns:**

    If `fn` is a function (including a `staticmethod`, `classmethod`, or `property`),
    then a wrapped function is returned.

    If `fn` is a dataclass, then `fn` is returned directly, and additionally its
    `__init__` method is wrapped and modified in-place.

    !!! Info "Old syntax"

        jaxtyping previously (before v0.2.24) recommended using this double-decorator
        syntax:
        ```python
        @jaxtyped
        @typechecker
        def f(...): ...
        ```
        This is still supported, but will now raise a warning recommending the
        `jaxtyped(typechecker=typechecker)` syntax discussed above. (Which will produce
        easier-to-debug error messages: under the hood, the new syntax more carefully
        manipulates the typechecker so as to determine where a type-check error arises.)

    ??? Info "Notes for advanced users"

        **Dynamic contexts:**

        Put precisely, the axis names in e.g. `Float[Array, "batch channels"]` and the
        structure names in e.g. `PyTree[int, "T"]` are all scoped to the thread-local
        dynamic context of a `jaxtyped`-wrapped function. If from within that function
        we then call another `jaxtyped`-wrapped function, then a new context is pushed
        to the stack. The axis sizes and PyTree structures of this inner function will
        then not be compared against the axis sizes and PyTree structures of the outer
        function. After the inner function returns then this inner context is popped
        from the stack, and the previous context is returned to.

        **isinstance:**

        Binding of a value against a name is done with an `isinstance` check, for
        example `isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"])` will bind
        `dim1=3` and `dim2=4`. In practice these `isinstance` checks are usually done by
        the run-time typechecker `typechecker` that is supplied as an argument.

        This can also be done manually: add `isinstance` checks inside a function body
        and they will contribute to the same collection of consistency checks as are
        performed by the typechecker on the arguments and return values. (Or you can
        forgo such a typechecker altogether -- i.e. `typechecker=None` -- and only do
        your own manual `isinstance` checks.)

        Only `isinstance` checks that pass will contribute to the store of values; those
        that fail will not. As such it is safe to write e.g.
        `assert not isinstance(x, Float32[Array, "foo"])`.

        **Decoupling contexts from function calls:**

        If you would like to call a new function *without* creating a new dynamic
        context (and using the same set of axis and structure values), then simply do
        not add a `jaxtyped` decorator to your inner function, whilst continuing to
        perform type-checking in whatever way you prefer.

        Conversely, if you would like a new dynamic context *without* calling a new
        function, then in addition to the usage discussed above, `jaxtyped` also
        supports being used as a context manager, by passing it the string `"context"`:
        ```python
        with jaxtyped("context"):
            assert isinstance(x, Float[Array, "batch channel"])
        ```
        This is equivalent to placing this code inside a new function wrapped in
        `jaxtyped(typechecker=None)`. Usage like this is very rare; it's mostly only
        useful when working at the global scope.
    """
    # If JAX is importable, register this file with JAX's traceback filtering. The
    # module-level flag makes sure the registration happens at most once.
    global _tb_flag
    if (
        _tb_flag
        and importlib.util.find_spec("jax") is not None
        and importlib.util.find_spec("jaxlib") is not None
        and importlib.util.find_spec("jax._src.traceback_util") is not None
    ):
        import jax._src.traceback_util as traceback_util

        traceback_util.register_exclusion(__file__)
        _tb_flag = False

    # First handle the `jaxtyped("context")` usage, which is a special case.
    if fn == "context":
        if typechecker is not _sentinel:
            raise ValueError(
                "Cannot use `jaxtyped` as a context with a typechecker. That is, "
                "`with jaxtyped('context', typechecker=...):`. is not allowed. In this "
                "case the type checker does not actually do anything, as there is no "
                "function to type-check."
            )
        return _JaxtypingContext()

    # Now check that a typechecker has been explicitly declared. (Or explicitly declared
    # as not being used, via `typechecker=None`.)
    # This is needed just for backward compatibility: an undeclared typechecker
    # corresponds to the old double-decorator syntax.
    if typechecker is _sentinel:
        # This branch will also catch the easy-to-make mistake of
        # ```python
        # @jaxtyped(typechecker)
        # def foo(...):
        # ```
        # which is a bug as `typechecker` is interpreted as the function to decorate!
        warnings.warn(
            "As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax\n"
            "```\n"
            "from jaxtyping import jaxtyped\n"
            "# Use your favourite typechecker: usually one of the two lines below.\n"
            "from typeguard import typechecked as typechecker\n"
            "from beartype import beartype as typechecker\n"
            "\n"
            "@jaxtyped(typechecker=typechecker)\n"
            "def foo(...):\n"
            "```\n"
            "and the old double-decorator syntax\n"
            "```\n"
            "@jaxtyped\n"
            "@typechecker\n"
            "def foo(...):\n"
            "```\n"
            "should no longer be used. (It will continue to work as it did before, but "
            "the new approach will produce more readable error messages.)\n"
            "In particular note that `typechecker` must be passed via keyword "
            "argument; the following is not valid:\n"
            "```\n"
            "@jaxtyped(typechecker)\n"
            "def foo(...):\n"
            "```\n",
            stacklevel=2,
        )
        typechecker = None

    if fn is _sentinel:
        # Called as `jaxtyped(typechecker=...)`: return the actual decorator.
        return ft.partial(jaxtyped, typechecker=typechecker)
    elif inspect.isclass(fn):
        if dataclasses.is_dataclass(fn) and typechecker is not None:
            # This does not check that the arguments passed to `__init__` match the
            # type annotations. There may be a custom user `__init__`, or a
            # dataclass-generated `__init__` used alongside
            # `equinox.field(converter=...)`
            init = fn.__init__

            @ft.wraps(init)
            def __init__(self, *args, **kwargs):
                __tracebackhide__ = True
                init(self, *args, **kwargs)
                # `fn.__init__` is late-binding to the `__init__` function that
                # we're in now. (Or to someone else's monkey-patch.) Either way,
                # this checks that we're in the "top-level" `__init__`, and not one
                # that is being called via `super()`. We don't want to trigger too
                # early, before all fields have been assigned.
                #
                # We're not checking `if self.__class__ is fn` because Equinox
                # replaces the class with a defrozen version of itself during
                # `__init__`, so the check wouldn't trigger.
                #
                # We're not doing this check by adding it to the end of the
                # metaclass `__call__`, because Python doesn't allow you
                # monkey-patch metaclasses.
                if self.__class__.__init__ is fn.__init__:
                    _check_dataclass_annotations(self, typechecker)

            fn.__init__ = __init__
        return fn
    # It'd be lovely if we could handle arbitrary descriptors, and not just the builtin
    # ones. Unfortunately that means returning a class instance with a __get__ method,
    # and that turns out to break loads of other things. See beartype issue #211 and
    # jaxtyping issue #71.
    elif isinstance(fn, classmethod):
        return classmethod(jaxtyped(fn.__func__, typechecker=typechecker))
    elif isinstance(fn, staticmethod):
        return staticmethod(jaxtyped(fn.__func__, typechecker=typechecker))
    elif isinstance(fn, property):
        if fn.fget is None:
            fget = None
        else:
            fget = jaxtyped(fn.fget, typechecker=typechecker)
        if fn.fset is None:
            fset = None
        else:
            fset = jaxtyped(fn.fset, typechecker=typechecker)
        if fn.fdel is None:
            fdel = None
        else:
            fdel = jaxtyped(fn.fdel, typechecker=typechecker)
        # Forward the docstring explicitly: `property` only falls back to the
        # getter's docstring, so a docstring supplied via `property(..., doc=...)`
        # would otherwise be dropped.
        return property(fget=fget, fset=fset, fdel=fdel, doc=fn.__doc__)
    else:
        if typechecker is None:
            # Probably being used in the old style as
            # ```
            # @jaxtyped
            # @typechecker
            # def foo(x: int): ...
            # ```
            # in which case make a best-effort attempt to add shape information for any
            # type errors.

            # we want to detect generators, and ignore return annotations on them,
            # to avoid issues with O(n) typechecking trying to typecheck yielded values
            wrp = fn
            while hasattr(wrp, "__wrapped__"):
                wrp = wrp.__wrapped__
            if inspect.isgeneratorfunction(wrp) or inspect.isasyncgenfunction(wrp):
                # recursively parse all the annotations, and mark all the jaxtyping
                # annotations as not needing instance checks, while still being
                # visible as original ones for the typechecker
                def modify_annotation(ann):
                    if inspect.isclass(ann) and issubclass(ann, AbstractArray):
                        ann.make_transparent()
                    for sub_ann in get_args(ann):
                        modify_annotation(sub_ann)

                # just to make sure: check that fn has valid return annotations
                if hasattr(fn, "__annotations__") and "return" in fn.__annotations__:
                    modify_annotation(fn.__annotations__["return"])

            signature = inspect.signature(fn)

            @ft.wraps(fn)
            def wrapped_fn(*args, **kwargs):  # pyright: ignore
                __tracebackhide__ = True
                bound = signature.bind(*args, **kwargs)
                bound.apply_defaults()
                memos = push_shape_memo(bound.arguments)
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    # add_note api is support from python 3.11+
                    if sys.version_info >= (3, 11) and _no_jaxtyping_note(e):
                        shape_info = shape_str(memos)
                        if shape_info != "":
                            msg = (
                                "The preceding error occurred within the scope of a "
                                "`jaxtyping.jaxtyped` function, and may be due to a "
                                "typecheck error. "
                            )
                            e.add_note(_jaxtyping_note_str(_spacer + msg + shape_info))
                    raise
                finally:
                    pop_shape_memo()

        else:
            # New-style
            # ```
            # @jaxtyped(typechecker=typechecker)
            # def foo(x: int): ...
            # ```
            # in which case we can do a better job reporting errors.
            full_signature = inspect.signature(fn)
            try:
                destring_annotations = get_type_hints(fn, include_extras=True)
            except NameError:
                # Best-effort attempt to destringify annotations.
                pass
            else:
                new_params = []
                for p_name, p_value in full_signature.parameters.items():
                    p_annotation = destring_annotations.get(p_name, p_value.annotation)
                    p_value = p_value.replace(annotation=p_annotation)
                    new_params.append(p_value)
                return_annotation = destring_annotations.get(
                    "return", full_signature.return_annotation
                )
                full_signature = full_signature.replace(
                    parameters=new_params, return_annotation=return_annotation
                )
            param_signature = full_signature.replace(return_annotation=Any)
            name = getattr(fn, "__name__", "<no name found>")
            qualname = getattr(fn, "__qualname__", "<no qualname found>")
            module = getattr(fn, "__module__", "<generated_by_jaxtyping>")

            # Use the same name so that typeguard warnings look correct.
            full_fn, output_name = _make_fn_with_signature(
                name, qualname, module, full_signature, output=True
            )
            param_fn = _make_fn_with_signature(
                name, qualname, module, param_signature, output=False
            )
            full_fn = _apply_typechecker(typechecker, full_fn)
            param_fn = _apply_typechecker(typechecker, param_fn)

            def wrapped_fn_impl(args, kwargs, bound, memos):
                __tracebackhide__ = True
                # First type-check just the parameters before the function is
                # called.
                try:
                    param_fn(*args, **kwargs)
                except AnnotationError:
                    raise
                except Exception:
                    # `_get_problem_arg` reports its result by raising a
                    # `TypeCheckError`; the final error is built inside the handler
                    # so that it can be chained from that exception.
                    try:
                        argmsg = _get_problem_arg(
                            param_signature,
                            args,
                            kwargs,
                            bound.arguments,
                            module,
                            typechecker,
                        )
                    except TypeCheckError as e:
                        argmsg = str(e)
                        try:
                            module_name = fn.__module__
                            qualname = fn.__qualname__
                        except AttributeError:
                            module_name = fn.__class__.__module__
                            qualname = fn.__class__.__qualname__
                        param_values = _pformat(bound.arguments, short_self=True)
                        param_hints = _remove_typing(param_signature)
                        msg = (
                            "Type-check error whilst checking the parameters of "
                            f"{module_name}.{qualname}.{argmsg}\n"
                            "----------------------\n"
                            f"Called with parameters: {param_values}\n"
                            f"Parameter annotations: {param_hints}.\n"
                            + shape_str(memos)
                        )
                        if config.jaxtyping_remove_typechecker_stack:
                            raise TypeCheckError(msg) from None
                        else:
                            raise TypeCheckError(msg) from e

                # Actually call the function.
                out = fn(*args, **kwargs)

                if full_signature.return_annotation is not inspect.Signature.empty:
                    # Now type-check the return value. We need to include the
                    # parameters in the type-checking here in case there are any
                    # type variables shared across the parameters and return.
                    #
                    # Incidentally this does mean that if `fn` mutates its arguments
                    # so that they no longer satisfy their type annotations, this
                    # will throw an error here. But that's like, super weird, so
                    # don't do that. An error in that scenario is probably still
                    # desirable.
                    #
                    # There is a small performance concern here when used in
                    # non-jit'd contexts, like PyTorch, due to the duplicate
                    # checking of the parameters. Unfortunately there doesn't seem
                    # to be a way around that, so c'est la vie.
                    kwargs[output_name] = out
                    try:
                        full_fn(*args, **kwargs)
                    except AnnotationError:
                        raise
                    except Exception as e:
                        try:
                            module_name = fn.__module__
                            qualname = fn.__qualname__
                        except AttributeError:
                            module_name = fn.__class__.__module__
                            qualname = fn.__class__.__qualname__
                        param_values = _pformat(bound.arguments, short_self=True)
                        return_value = _pformat(out, short_self=False)
                        param_hints = _remove_typing(param_signature)
                        return_hint = _remove_typing(full_signature.return_annotation)
                        if return_hint.startswith("<class '") and return_hint.endswith(
                            "'>"
                        ):
                            return_hint = return_hint[8:-2]
                        msg = (
                            "Type-check error whilst checking the return value "
                            f"of {module_name}.{qualname}.\n"
                            f"Actual value: {return_value}\n"
                            f"Expected type: {return_hint}.\n"
                            "----------------------\n"
                            f"Called with parameters: {param_values}\n"
                            f"Parameter annotations: {param_hints}.\n"
                            + shape_str(memos)
                        )
                        if config.jaxtyping_remove_typechecker_stack:
                            raise TypeCheckError(msg) from None
                        else:
                            raise TypeCheckError(msg) from e

                return out

            wrapped_fn_holder = []  # Avoids introducing a reference cycle.

            @ft.wraps(fn)
            def wrapped_fn(*args, **kwargs):
                __tracebackhide__ = True
                if (
                    config.jaxtyping_disable
                    or getattr(fn, "__no_type_check__", False)
                    or getattr(wrapped_fn_holder[0](), "__no_type_check__", False)
                ):
                    return fn(*args, **kwargs)

                # Raise bind-time errors before we do any shape analysis. (I.e. skip
                # the pointless jaxtyping information for a non-typechecking failure.)
                bound = param_signature.bind(*args, **kwargs)
                bound.apply_defaults()

                memos = push_shape_memo(bound.arguments)
                try:
                    # Put this in a separate frame to make debugging easier, without
                    # just always ending up on the `pop_shape_memo` line below.
                    return wrapped_fn_impl(args, kwargs, bound, memos)
                finally:
                    pop_shape_memo()

            wrapped_fn_holder.append(weakref.ref(wrapped_fn))

        return wrapped_fn
class _JaxtypingContext:
    """Context manager returned by `jaxtyped("context")`.

    Pushes a fresh, empty shape memo on entry and pops it again on exit.
    """

    def __enter__(self):
        push_shape_memo({})

    def __exit__(self, exc_type, exc_value, exc_tb):
        pop_shape_memo()


def _check_dataclass_annotations(self, typechecker):
    """Creates and calls a function that checks the attributes of `self`

    `self` should be a dataclass instance. `typechecker` should be e.g.
    `beartype.beartype` or `typeguard.typechecked`.
    """
    parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
    values = {}
    for field in dataclasses.fields(self):
        annotation = field.type
        if isinstance(annotation, str):
            # Don't check stringified annotations. These are basically impossible to
            # resolve correctly, so just skip them.
            continue
        if get_origin(annotation) is type:
            args = get_args(annotation)
            if len(args) == 1 and isinstance(args[0], str):
                # We also special-case this one kind of partially-stringified type
                # annotation, so as to support Equinox <v0.11.1.
                # This was fixed in Equinox in
                # https://github.com/patrick-kidger/equinox/pull/543
                continue
        try:
            value = getattr(self, field.name)  # noqa: F841
        except AttributeError:
            continue  # allow uninitialised fields, which are allowed on dataclasses

        parameters.append(
            inspect.Parameter(
                field.name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=field.type,
            )
        )
        values[field.name] = value

    signature = inspect.Signature(parameters)
    f = _make_fn_with_signature(
        self.__class__.__name__,
        self.__class__.__qualname__,
        self.__class__.__module__,
        signature,
        output=False,
    )
    f = jaxtyped(f, typechecker=typechecker)
    f(self, **values)


def _make_fn_with_signature(
    name: str, qualname: str, module: str, signature: inspect.Signature, output: bool
):
    """Dynamically creates a function `fn` with name `name` and signature `signature`.

    If `output=True` then `fn` will consume an additional keyword-only argument (in
    addition to the provided signature), and will directly return this argument. In
    this case the returned value from `_make_fn_with_signature` is a 2-tuple of
    `(fn, name)`, where `fn` is the generated function, and `name` is the name of this
    extra argument.

    If `output=False` then `fn` will just have a single `pass` statement, and the
    returned value from `_make_fn_with_signature` will just be `fn`.

    `name` does not have to be a valid Python identifier (e.g. `"<lambda>"`): in that
    case the function is compiled under a generated identifier and its `__name__` is
    set to `name` afterwards.

    ---

    Note that this function operates by dynamically creating and eval'ing a string, not
    simply by assigning `__signature__` and `__annotations__`. The latter is enough for
    typeguard (at least v2), but does not work with beartype (at least v16).
    """
    import keyword  # Local import: only needed to validate `name` below.

    # Sort the parameters by kind, so that the argument string can be assembled in
    # the order that Python's grammar requires.
    pos = []
    pos_or_key = []
    varpos = []
    key = []
    varkey = []
    for p in signature.parameters.values():
        if p.kind == inspect.Parameter.POSITIONAL_ONLY:
            pos.append(p)
        elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            pos_or_key.append(p)
        elif p.kind == inspect.Parameter.VAR_POSITIONAL:
            varpos.append(p)
        elif p.kind == inspect.Parameter.KEYWORD_ONLY:
            key.append(p)
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            varkey.append(p)
        else:
            assert False

    param_names = frozenset(signature.parameters.keys())
    if output:
        output_name = _gensym(param_names, prefix="ret")
        outstr = "return " + output_name
        param_names = param_names | frozenset({output_name})
        key.append(inspect.Parameter(output_name, kind=inspect.Parameter.KEYWORD_ONLY))
    else:
        outstr = "pass"

    # `name` is spliced into a `def` statement, so it must be a usable identifier.
    # Lambdas (`<lambda>`) and callables without a `__name__` are not; fall back to a
    # generated identifier for those and restore `__name__` after the `exec`.
    if isinstance(name, str) and name.isidentifier() and not keyword.iskeyword(name):
        def_name = name
    else:
        def_name = _gensym(param_names, prefix="fn")

    # Annotations and defaults are arbitrary objects, so they are passed to the
    # generated code via generated global names rather than via their `repr`.
    scope = {def_name: None}
    name_to_annotation = {}
    name_to_default = {}
    param_triples = (
        (p.name, p.annotation, p.default) for p in signature.parameters.values()
    )
    if output:
        triples = it.chain(
            param_triples,
            [
                ("return", signature.return_annotation, inspect.Signature.empty),
                (output_name, Any, inspect.Signature.empty),
            ],
        )
    else:
        triples = it.chain(
            param_triples,
            [("return", signature.return_annotation, inspect.Signature.empty)],
        )
    for p_name, p_annotation, p_default in triples:
        annotation_name = _gensym(frozenset(scope.keys()) | param_names, prefix="T")
        name_to_annotation[p_name] = annotation_name
        if p_annotation is inspect.Signature.empty or isinstance(p_annotation, str):
            # If we have a stringified annotation here it's because the get_type_hints
            # lookup above failed. Typically this occurs when using a local variable as
            # the annotation. In this case we really have no idea what the annotation
            # refers to, so just set it to Any.
            # This does mean that we don't handle partially-stringified local
            # annotations, e.g. `type["Foo"]` for some local type `Foo`. Those will
            # probably just error out. Nothing better we can do about that
            # unfortunately.
            scope[annotation_name] = Any
        else:
            scope[annotation_name] = p_annotation
        default_name = _gensym(frozenset(scope.keys()) | param_names, prefix="default")
        name_to_default[p_name] = default_name
        scope[default_name] = p_default

    argstr_pieces = []
    if len(pos) > 0:
        for p in pos:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
        argstr_pieces.append("/")
    if len(pos_or_key) > 0:
        for p in pos_or_key:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
    if len(varpos) == 1:
        [p] = varpos
        argstr_pieces.append(
            "*" + _make_argpiece(p, name_to_annotation, name_to_default)
        )
    else:
        assert len(varpos) == 0
        # Without `*args`, a bare `*` is needed to introduce keyword-only arguments.
        if len(key) > 0:
            argstr_pieces.append("*")
    if len(key) > 0:
        for p in key:
            argstr_pieces.append(_make_argpiece(p, name_to_annotation, name_to_default))
    if len(varkey) == 1:
        [p] = varkey
        argstr_pieces.append(
            "**" + _make_argpiece(p, name_to_annotation, name_to_default)
        )
    else:
        assert len(varkey) == 0
    argstr = ", ".join(argstr_pieces)

    if signature.return_annotation is inspect.Signature.empty:
        retstr = ""
    else:
        retstr = f"-> {name_to_annotation['return']}"

    fnstr = f"def {def_name}({argstr}){retstr}:\n    {outstr}"
    exec(fnstr, scope)
    fn = scope[def_name]
    del scope[def_name]  # Avoids introducing a reference cycle.
    if def_name != name:
        fn.__name__ = str(name)
    fn.__module__ = module
    fn.__qualname__ = qualname
    assert fn is not None
    if output:
        return fn, output_name
    else:
        return fn


def _gensym(names: frozenset[str], prefix: str) -> str:
    """Returns the first of `prefix0`, `prefix1`, ... that is not in `names`."""
    assert prefix.isidentifier()
    output_index = 0
    output_name = prefix + str(output_index)
    while output_name in names:
        output_index += 1
        output_name = prefix + str(output_index)
    assert output_name.isidentifier()
    return output_name


def _make_argpiece(p, name_to_annotation, name_to_default):
    """Formats one parameter as `name: annotation` or `name: annotation = default`.

    The annotation and default are written as the generated global names recorded
    in `name_to_annotation` and `name_to_default`.
    """
    if p.default is inspect.Signature.empty:
        return f"{p.name}: {name_to_annotation[p.name]}"
    else:
        return f"{p.name}: {name_to_annotation[p.name]} = {name_to_default[p.name]}"


def _get_problem_arg(
    param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker
) -> NoReturn:
    """Determines which argument was likely to be the problematic one responsible for
    raising a type-check error.

    It returns the result by raising an Exception: you should grab this and extract
    the string out of it. We do this, rather than just returning the value, to aid
    debugging by making it possible to walk down the stack to the issue.
    """
    # No performance concerns, as this is only used when we're about to raise an error
    # anyway.
    for keep_name in param_signature.parameters.keys():
        # Rebuild the signature with every annotation stripped except for that of
        # `keep_name`, so that a failure can only be due to that one parameter.
        new_parameters = []
        keep_annotation = sentinel = object()
        for p_name, p in param_signature.parameters.items():
            if p_name == keep_name:
                new_parameters.append(
                    inspect.Parameter(
                        p.name, p.kind, default=p.default, annotation=p.annotation
                    )
                )
                assert keep_annotation is sentinel
                keep_annotation = _remove_typing(p.annotation)
            else:
                new_parameters.append(
                    inspect.Parameter(p.name, p.kind, default=p.default)
                )
        assert keep_annotation is not sentinel
        new_signature = inspect.Signature(new_parameters)
        fn = _make_fn_with_signature(
            "check_single_arg", "check_single_arg", module, new_signature, output=False
        )
        fn = _apply_typechecker(
            typechecker, fn
        )  # but no `jaxtyped`; keep the same environment.
        try:
            fn(*args, **kwargs)
        except Exception as e:
            keep_value = _pformat(arguments[keep_name], short_self=False)
            raise TypeCheckError(
                f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
                f"Actual value: {keep_value}\n"
                f"Expected type: {keep_annotation}."
            ) from e
    else:
        # Could not localise the problem to a single argument -- probably due to
        # e.g. a mismatched typevar, which each individual argument is okay with.
        raise TypeCheckError("")


def _remove_typing(x):
    """Returns `str(x)` with `jaxtyping.` and `typing.` module prefixes stripped.

    Only prefixes directly following a space, `[` or `'` are removed.
    """
    x = str(x)
    x = x.replace(" jaxtyping.", " ")
    x = x.replace("[jaxtyping.", "[")
    x = x.replace("'jaxtyping.", "'")
    x = x.replace(" typing.", " ")
    x = x.replace("[typing.", "[")
    x = x.replace("'typing.", "'")
    return x


def _pformat(x, short_self: bool):
    """Pretty-formats `x` for inclusion in an error message.

    If `short_self=True` then `x` is expected to support `x["self"]` (a missing key
    is tolerated), and that value is abbreviated in the output. If formatting
    itself raises then a placeholder string is returned instead.
    """
    # No performance concerns from delayed imports -- this is only used when we're about
    # to raise an error anyway.
    try:
        # If we can, use `eqx.tree_pformat`, which wraps `wadler_lindig.pformat` with
        # understanding of a few other JAX-specific things.
        import equinox as eqx

        pformat = eqx.tree_pformat
        if short_self:
            try:
                self = x["self"]
            except KeyError:
                pass
            else:
                is_self = lambda y: y is self
                pformat = ft.partial(pformat, truncate_leaf=is_self)
    except Exception:
        # Failing that fall back to `wadler_lindig.pformat` directly.
        import wadler_lindig

        pformat = wadler_lindig.pformat
        if short_self:
            try:
                self = x["self"]
            except KeyError:
                pass
            else:

                def custom(obj):
                    if obj is self:
                        return wadler_lindig.TextDoc(f"{type(obj).__name__}(...)")

                pformat = ft.partial(pformat, custom=custom)
    try:
        return pformat(x)
    except Exception:
        return f"<Exception raised when pretty-formatting object of type {type(x)}.>"


class _jaxtyping_note_str(str):
    """Used with `_no_jaxtyping_note` to flag that a note came from jaxtyping."""


def _no_jaxtyping_note(e: Exception) -> bool:
    """Returns True if none of the exception's notes are from jaxtyping.

    An exception with no `__notes__` attribute counts as having no such note.
    """
    notes = getattr(e, "__notes__", ())
    return not any(isinstance(note, _jaxtyping_note_str) for note in notes)


_spacer = "--------------------\n"