From 2ddeac059b7dc2e2fa9b231426fbf70ec3ebb3c0 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Fri, 4 Mar 2022 03:04:12 +0000 Subject: [PATCH] python3Packages.jaxlib-bin: add support for python 3.10 and cudnn >=8.2 --- .../development/python-modules/jaxlib/bin.nix | 79 ++++++++++++------- .../python-modules/jaxlib/prefetch.sh | 7 ++ 2 files changed, 58 insertions(+), 28 deletions(-) create mode 100755 pkgs/development/python-modules/jaxlib/prefetch.sh diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 3504c6bf3204..7e6b00429dfa 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -24,50 +24,73 @@ , flatbuffers , isPy39 , lib +, python , scipy , stdenv # Options: , cudaSupport ? config.cudaSupport or false }: -# Note that these values are tied to the specific version of the GPU wheel that -# we fetch. When updating, try to go for the latest possible versions that are -# still compatible with the cudatoolkit and cudnn versions available in nixpkgs. +# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are +# wheels for cudatoolkit <11.1, we don't support them. assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; let - device = if cudaSupport then "gpu" else "cpu"; + version = "0.3.0"; + + pythonVersion = python.pythonVersion; + + # Find new releases at https://storage.googleapis.com/jax-releases. When + # upgrading, you can get these hashes from prefetch.sh. + cpuSrcs = { + "3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; + hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; + }; + "3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; + }; + }; + + gpuSrcs = { + "3.9-805" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; + hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; + }; + "3.9-82" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; + hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; + }; + "3.10-805" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; + }; + "3.10-82" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; + }; + }; in buildPythonPackage rec { pname = "jaxlib"; - version = "0.3.0"; + inherit version; format = "wheel"; - # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting - # all of them is a pain, so we focus on 3.9, the current nixpkgs python3 - # version. - disabled = !isPy39; + # At the time of writing (2022-03-03), there are releases for <=3.10. + # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs + # python3 version, and 3.10. + disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); - # Find new releases at https://storage.googleapis.com/jax-releases. - src = { - cpu = fetchurl { - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01"; - }; - gpu = fetchurl { - # Note that there's also a release targeting cuDNN 8.2, but unfortunately - # we don't yet have that packaged at the time of writing (02/03/2022). - # Check pkgs/development/libraries/science/math/cudnn/default.nix for more - # details. - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8"; - - # This is what the cuDNN 8.2 download looks like for future reference: - # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; - # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb"; - }; - }.${device}; + src = + if !cudaSupport then cpuSrcs."${pythonVersion}" else + let + # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and + # 8.2. Try to use 8.2 whenever possible. + cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; + in + gpuSrcs."${pythonVersion}-${cudnnVersion}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. diff --git a/pkgs/development/python-modules/jaxlib/prefetch.sh b/pkgs/development/python-modules/jaxlib/prefetch.sh new file mode 100755 index 000000000000..31db6530639f --- /dev/null +++ b/pkgs/development/python-modules/jaxlib/prefetch.sh @@ -0,0 +1,7 @@ +version="$1" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"