jax>=0.2.20
jaxlib>=0.1.71
jax_dataclasses>=1.0.0
numpy
overrides!=4

[testing]
flax
hypothesis
hypothesis[numpy]
pytest
pytest-cov
