ml_dtypes>=0.1.0
numpy>=1.21
opt_einsum
scipy>=1.7

[:python_version < "3.10"]
importlib_metadata>=4.6

[australis]
protobuf<4,>=3.13

[ci]
jaxlib==0.4.11

[cpu]
jaxlib==0.4.12

[cuda]
jaxlib==0.4.12+cuda11.cudnn86

[cuda11_cudnn82]
jaxlib==0.4.12+cuda11.cudnn82

[cuda11_cudnn86]
jaxlib==0.4.12+cuda11.cudnn86

[cuda11_local]
jaxlib==0.4.12+cuda11.cudnn86

[cuda11_pip]
jaxlib==0.4.12+cuda11.cudnn86
nvidia-cublas-cu11>=11.11
nvidia-cuda-cupti-cu11>=11.8
nvidia-cuda-nvcc-cu11>=11.8
nvidia-cuda-runtime-cu11>=11.8
nvidia-cudnn-cu11>=8.8
nvidia-cufft-cu11>=10.9
nvidia-cusolver-cu11>=11.4
nvidia-cusparse-cu11>=11.7

[cuda12_local]
jaxlib==0.4.12+cuda12.cudnn88

[cuda12_pip]
jaxlib==0.4.12+cuda12.cudnn88
nvidia-cublas-cu12
nvidia-cuda-cupti-cu12
nvidia-cuda-nvcc-cu12
nvidia-cuda-runtime-cu12
nvidia-cudnn-cu12>=8.9
nvidia-cufft-cu12
nvidia-cusolver-cu12
nvidia-cusparse-cu12

[minimum-jaxlib]
jaxlib==0.4.11

[tpu]
jaxlib==0.4.12
libtpu-nightly==0.1.dev20230608
