3
0
Fork 0
forked from mirrors/nixpkgs

python310Packages.dm-haiku: 0.0.6 -> 0.0.7

This commit is contained in:
Jonas Heinrich 2022-07-13 09:50:12 +02:00 committed by Yt
parent 1b0ff5bba2
commit 3c7534f21a
2 changed files with 89 additions and 38 deletions

View file

@ -1,66 +1,49 @@
{ buildPythonPackage
, chex
, cloudpickle
, dill
, dm-tree
, fetchFromGitHub
, jaxlib
, jmp
, callPackage
, lib
, pytest-xdist
, pytestCheckHook
, jmp
, tabulate
, tensorflow
, jaxlib
}:
buildPythonPackage rec {
pname = "dm-haiku";
version = "0.0.6";
version = "0.0.7";
src = fetchFromGitHub {
owner = "deepmind";
repo = pname;
rev = "v${version}";
hash = "sha256-qvKMeGPiWXvvyV+GZdTWdsC6Wp08AmP8nDtWk7sZtqM=";
hash = "sha256-Qa3g3vOPZJt/wBjjuZHAcFUz/gwN/yvirV/8V9CnIko=";
};
propagatedBuildInputs = [
jmp
tabulate
outputs = [
"out"
"testsout"
];
checkInputs = [
chex
cloudpickle
dill
dm-tree
propagatedBuildInputs = [
jaxlib
pytest-xdist
pytestCheckHook
tensorflow
jmp
tabulate
];
pythonImportsCheck = [
"haiku"
];
disabledTestPaths = [
# These tests require `bsuite` which isn't packaged in `nixpkgs`.
"examples/impala_lite_test.py"
"examples/impala/actor_test.py"
"examples/impala/learner_test.py"
# This test breaks on multiple cases with TF-related errors,
# likely that's the reason the upstream uses TF-nightly for tests?
# `nixpkgs` doesn't have the corresponding TF version packaged.
"haiku/_src/integration/jax2tf_test.py"
# `TypeError: lax.conv_general_dilated requires arguments to have the same dtypes, got float32, float16`.
"haiku/_src/integration/numpy_inputs_test.py"
];
postInstall = ''
mkdir $testsout
cp -R examples $testsout/examples
'';
disabledTests = [
# See https://github.com/deepmind/dm-haiku/issues/366.
"test_jit_Recurrent"
];
# check in passthru.tests.pytest to escape infinite recursion with bsuite
doCheck = false;
passthru.tests = {
pytest = callPackage ./tests.nix { };
};
meta = with lib; {
description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet.";

View file

@ -0,0 +1,68 @@
{ stdenv
, buildPythonPackage
, dm-haiku
, chex
, cloudpickle
, dill
, dm-tree
, jaxlib
, pytest-xdist
, pytestCheckHook
, tensorflow
, bsuite
, frozendict
, dm-env
, scikitimage
, rlax
, distrax
, tensorflow-probability
, optax }:
buildPythonPackage rec {
pname = "dm-haiku-tests";
inherit (dm-haiku) version;
src = dm-haiku.testsout;
dontBuild = true;
dontInstall = true;
checkInputs = [
bsuite
chex
cloudpickle
dill
distrax
dm-env
dm-haiku
dm-tree
frozendict
jaxlib
pytest-xdist
pytestCheckHook
optax
rlax
scikitimage
tensorflow
tensorflow-probability
];
disabledTests = [
# See https://github.com/deepmind/dm-haiku/issues/366.
"test_jit_Recurrent"
# Assertion errors
"test_connect_conv_padding_function_same0"
"test_connect_conv_padding_function_valid0"
"test_connect_conv_padding_function_same1"
"test_connect_conv_padding_function_same2"
"test_connect_conv_padding_function_valid1"
"test_connect_conv_padding_function_valid2"
"test_invalid_axis_ListString"
"test_invalid_axis_String"
"test_simple_case"
"test_simple_case_with_scale"
"test_slice_axis"
"test_zero_inputs"
];
}