jax>=0.1.67
jaxlib>=0.1.47

[cuda]
jax[cuda]

[dev]
black
isort
pylint
flake8
pytest
pytest-cov
