diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix
index db1914f4ee7b..e82054c2885b 100644
--- a/pkgs/development/python-modules/pytorch/default.nix
+++ b/pkgs/development/python-modules/pytorch/default.nix
@@ -74,27 +74,35 @@ let
   # (allowing FBGEMM to be built in pytorch-1.1), and may future proof this
   # derivation.
   brokenArchs = [ "3.0" ]; # this variable is only used as documentation.
-  cuda9ArchList = [
-    "3.5"
-    "5.0"
-    "5.2"
-    "6.0"
-    "6.1"
-    "7.0"
-    "7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9
-              # which leads me to believe this is the final cuda-9-compatible architecture.
-  ];
-  cuda10ArchList = cuda9ArchList ++ [
-    "7.5"
-    "7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
-  ];
+
+  cudaCapabilities = rec {
+    cuda9 = [
+      "3.5"
+      "5.0"
+      "5.2"
+      "6.0"
+      "6.1"
+      "7.0"
+      "7.0+PTX" # I am getting a "undefined architecture compute_75" on cuda 9
+                # which leads me to believe this is the final cuda-9-compatible architecture.
+    ];
+
+    cuda10 = cuda9 ++ [
+      "7.5"
+      "7.5+PTX" # < most recent architecture as of cudatoolkit_10_0 and pytorch-1.2.0
+    ];
+
+    cuda11 = cuda10 ++ [
+      "8.0"
+      "8.0+PTX" # < CUDA toolkit 11.0
+      "8.6"
+      "8.6+PTX" # < CUDA toolkit 11.1
+    ];
+  };
   final_cudaArchList =
     if !cudaSupport || cudaArchList != null
     then cudaArchList
-    else
-      if lib.versions.major cudatoolkit.version == "9"
-      then cuda9ArchList
-      else cuda10ArchList; # the assert above removes any ambiguity here.
+    else cudaCapabilities."cuda${lib.versions.major cudatoolkit.version}";
 
   # Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
   # LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub