jax>=0.4.1
jaxlib>=0.4.1
simple-pytree==0.1.6

[dev]
pytest
pre-commit
pytest-cov
flax
