jax>=0.4.1
jaxlib>=0.4.1
optax
jaxutils>=0.0.6
jaxkern>=0.0.4
distrax>=0.1.2
tqdm>=4.0.0
ml-collections==0.1.0
jaxtyping>=0.0.2
jaxlinop>=0.0.3
deprecation

[cuda]
jax[cuda]

[dev]
black
isort
pylint
flake8
pytest
networkx
pytest-cov
