# PyTorch 0.3.1 — tensors and dynamic neural networks with optional CUDA support.
# When cudaSupport is enabled, cudatoolkit must be provided; cudnn is optional
# but requires cudatoolkit as well (enforced by the asserts below).
{ buildPythonPackage,
  cudaSupport ? false, cudatoolkit ? null, cudnn ? null,
  fetchFromGitHub, fetchpatch, lib, numpy, pyyaml, cffi, cmake,
  git, stdenv, linkFarm, symlinkJoin,
  utillinux, which }:

# cudnn depends on cudatoolkit; cudaSupport requires cudatoolkit.
assert cudnn == null || cudatoolkit != null;
assert !cudaSupport || cudatoolkit != null;

let
  # cudatoolkit ships split outputs; the build expects a single prefix,
  # so join .out and .lib into one tree.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
  # LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub
  # libcuda.so from cudatoolkit for running tests, so that we don't have
  # to recompile pytorch on every update to nvidia-x11 or the kernel.
  cudaStub = linkFarm "cuda-stub" [{
    name = "libcuda.so.1";
    path = "${cudatoolkit}/lib/stubs/libcuda.so";
  }];

  # Environment prefix for checkPhase; the \${...} is escaped so the shell
  # (not Nix) expands LD_LIBRARY_PATH. Note the trailing space before the
  # closing quote — checkPhase concatenates this directly with the shell.
  cudaStubEnv = lib.optionalString cudaSupport
    "LD_LIBRARY_PATH=${cudaStub}\${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} ";

in buildPythonPackage rec {
  version = "0.3.1";
  pname = "pytorch";
  name = "${pname}-${version}";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "pytorch";
    rev = "v${version}";
    fetchSubmodules = true;
    sha256 = "1k8fr97v5pf7rni5cr2pi21ixc3pdj3h3lkz28njbjbgkndh7mr3";
  };

  patches = [
    (fetchpatch {
      # make sure stdatomic.h is included when checking for ATOMIC_INT_LOCK_FREE
      # Fixes this test failure:
      # RuntimeError: refcounted file mapping not supported on your system at /tmp/nix-build-python3.6-pytorch-0.3.0.drv-0/source/torch/lib/TH/THAllocator.c:525
      url = "https://github.com/pytorch/pytorch/commit/502aaf39cf4a878f9e4f849e5f409573aa598aa9.patch";
      stripLen = 3;
      extraPrefix = "torch/lib/";
      sha256 = "1miz4lhy3razjwcmhxqa4xmlcmhm65lqyin1czqczj8g16d3f62f";
    })
  ];

  # The distributed test needs networking/multiprocess setup unavailable in
  # the sandbox, so replace its invocation with a no-op echo.
  postPatch = ''
    substituteInPlace test/run_test.sh --replace \
      "INIT_METHOD='file://'$TEMP_DIR'/shared_init_file' $PYCMD ./test_distributed.py" \
      "echo Skipped for Nix package"
  '';

  # Use the gcc matching cudatoolkit (nvcc is picky about host compilers),
  # and point the build at cudnn's headers when cudnn is available.
  preConfigure = lib.optionalString cudaSupport ''
    export CC=${cudatoolkit.cc}/bin/gcc
  '' + lib.optionalString (cudaSupport && cudnn != null) ''
    export CUDNN_INCLUDE_DIR=${cudnn}/include
  '';

  buildInputs = [
    cmake
    git
    numpy.blas
    utillinux
    which
  ] ++ lib.optionals cudaSupport [ cudatoolkit_joined cudnn ];

  propagatedBuildInputs = [
    cffi
    numpy
    pyyaml
  ];

  # cudaStubEnv is empty unless cudaSupport; it already ends with a space.
  checkPhase = ''
    ${cudaStubEnv}${stdenv.shell} test/run_test.sh
  '';

  meta = {
    description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration.";
    homepage = http://pytorch.org/;
    license = lib.licenses.bsd3;
    platforms = lib.platforms.linux;
    maintainers = with lib.maintainers; [ teh ];
  };
}