diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 287d2928a303..48c5cb3be6a1 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -4243,20 +4243,23 @@ in { jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { cudaSupport = pkgs.config.cudaSupport or false; - inherit (self.tensorflow) cudaPackages; + # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we + # pin to `cudaPackages_11_6` instead. + cudaPackages = pkgs.cudaPackages_11_6; }; jaxlib-build = callPackage ../development/python-modules/jaxlib { # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. cudaSupport = pkgs.config.cudaSupport or false; - inherit (self.tensorflow) cudaPackages; + # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we + # pin to `cudaPackages_11_6` instead. + cudaPackages = pkgs.cudaPackages_11_6; }; jaxlib = self.jaxlib-build; jaxlibWithCuda = self.jaxlib-build.override { cudaSupport = true; - }; jaxlibWithoutCuda = self.jaxlib-build.override {