jax>=0.4.1
jaxlib>=0.4.1
jaxtyping
simple-pytree

[dev]
pytest
pre-commit
pytest-cov
