numpy>=1.20

[jax]
jax>=0.3.4
jaxlib>=0.3.2
chex>=0.1
