jax>=0.4.6
chex>=0.1.7
jaxtyping>=0.2.14

[dev]
pytest>=7.1.2
twine>=4.0.2
wheel>=0.37.1
dm_env>=1.5
