jax>=0.4.0
jaxlib>=0.4.0
numpy>=1.20.0
scipy>=1.7.0