forked from mirrors/nixpkgs
python3Packages.torchvision: added cudaSupport option (#132917)
Co-authored-by: Sandro <sandro.jaeckel@gmail.com>
This commit is contained in:
parent
0d078fcdb2
commit
717538e908
|
@ -301,6 +301,11 @@ in buildPythonPackage rec {
|
|||
# Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
|
||||
requiredSystemFeatures = [ "big-parallel" ];
|
||||
|
||||
passthru = {
|
||||
inherit cudaSupport;
|
||||
cudaArchList = final_cudaArchList;
|
||||
};
|
||||
|
||||
meta = with lib; {
|
||||
description = "Open source, prototype-to-production deep learning platform";
|
||||
homepage = "https://pytorch.org/";
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{ lib
|
||||
, symlinkJoin
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, ninja
|
||||
|
@ -10,9 +11,18 @@
|
|||
, pillow
|
||||
, pytorch
|
||||
, pytest
|
||||
, cudatoolkit
|
||||
, cudnn
|
||||
, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
let
|
||||
cudatoolkit_joined = symlinkJoin {
|
||||
name = "${cudatoolkit.name}-unsplit";
|
||||
paths = [ cudatoolkit.out cudatoolkit.lib ];
|
||||
};
|
||||
cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList;
|
||||
in buildPythonPackage rec {
|
||||
pname = "torchvision";
|
||||
version = "0.10.0";
|
||||
|
||||
|
@ -23,15 +33,22 @@ buildPythonPackage rec {
|
|||
sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [ libpng ninja which ];
|
||||
nativeBuildInputs = [ libpng ninja which ]
|
||||
++ lib.optionals cudaSupport [ cudatoolkit_joined ];
|
||||
|
||||
TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
|
||||
TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
|
||||
|
||||
buildInputs = [ libjpeg_turbo libpng ];
|
||||
buildInputs = [ libjpeg_turbo libpng ]
|
||||
++ lib.optionals cudaSupport [ cudnn ];
|
||||
|
||||
propagatedBuildInputs = [ numpy pillow pytorch scipy ];
|
||||
|
||||
preBuild = lib.optionalString cudaSupport ''
|
||||
export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
|
||||
export FORCE_CUDA=1
|
||||
'';
|
||||
|
||||
# tries to download many datasets for tests
|
||||
doCheck = false;
|
||||
|
||||
|
@ -45,6 +62,7 @@ buildPythonPackage rec {
|
|||
description = "PyTorch vision library";
|
||||
homepage = "https://pytorch.org/";
|
||||
license = licenses.bsd3;
|
||||
platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
|
||||
maintainers = with maintainers; [ ericsagnes ];
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue