mirror of
https://github.com/NixOS/nixpkgs.git
synced 2024-09-11 15:08:33 +01:00
nixos/spark: add test
This commit is contained in:
parent
dd987c2dbe
commit
13839b0022
28
nixos/tests/spark/default.nix
Normal file
28
nixos/tests/spark/default.nix
Normal file
|
@ -0,0 +1,28 @@
|
|||
# NixOS VM test for Spark: boots one master and one worker node, checks that
# the master's web page is reachable from the worker, then submits the sample
# PySpark job to the master from the worker.
import ../make-test-python.nix ({ ... }: {
  name = "spark";

  nodes = {
    # Worker node: points its Spark worker service at the master node.
    worker = { ... }: {
      virtualisation.memorySize = 1024;
      services.spark.worker = {
        enable = true;
        master = "master:7077";
      };
    };

    # Master node: binds on all interfaces and opens the ports that the
    # worker connects to in the test script below (7077 and 8080).
    master = { ... }: {
      services.spark.master = {
        enable = true;
        bind = "0.0.0.0";
      };
      networking.firewall.allowedTCPPorts = [ 22 7077 8080 ];
    };
  };

  testScript = ''
    master.wait_for_unit("spark-master.service")
    # An active unit does not guarantee that its sockets are already
    # listening, so wait for the two ports used below before touching them.
    master.wait_for_open_port(7077)
    master.wait_for_open_port(8080)
    worker.wait_for_unit("spark-worker.service")
    worker.copy_from_host( "${./spark_sample.py}", "/spark_sample.py" )
    assert "<title>Spark Master at spark://" in worker.succeed("curl -sSfkL http://master:8080/")
    worker.succeed("spark-submit --master spark://master:7077 --executor-memory 512m --executor-cores 1 /spark_sample.py")
  '';
})
|
40
nixos/tests/spark/spark_sample.py
Normal file
40
nixos/tests/spark/spark_sample.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from pyspark.sql import Row, SparkSession
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.functions import udf
|
||||
from pyspark.sql.types import *
|
||||
from pyspark.sql.functions import explode
|
||||
|
||||
def explode_col(weight):
    """Split ``weight`` into chunks of 10.0 followed by any leftover remainder.

    Args:
        weight: Numeric weight to split.

    Returns:
        A list of floats: ``int(weight // 10)`` copies of ``10.0`` and, when
        ``weight % 10`` is non-zero, that remainder as the final element.
        A negative chunk count produces no ``10.0`` entries, because
        repeating a list a negative number of times gives an empty list.
    """
    # divmod yields the floor quotient and the remainder in a single step,
    # instead of evaluating ``weight % 10`` twice.
    whole_tens, remainder = divmod(weight, 10)
    chunks = [10.0] * int(whole_tens)
    if remainder != 0:
        # Coerce to float so every element has the same type as the 10.0
        # chunks even when an int weight is passed in; this matches the
        # float array type this function is registered with as a UDF.
        chunks.append(float(remainder))
    return chunks
|
||||
|
||||
# Sample job: build a tiny DataFrame, rescale its bias weights so they sum to
# a fixed constant, split each weight into chunks of at most 10 via a UDF,
# and check the resulting row count.
spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("feature_1", FloatType()),
    StructField("feature_2", FloatType()),
    StructField("bias_weight", FloatType()),
])

rows = [
    Row(0.1, 0.2, 10.32),
    Row(0.32, 1.43, 12.8),
    Row(1.28, 1.12, 0.23),
]

frame = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

# Rescale bias_weight so that the column sums to normalizing_constant.
normalizing_constant = 100
total_row = frame.select(F.sum('bias_weight')).collect()[0]
sum_bias_weight = total_row[0]
normalizing_factor = normalizing_constant / sum_bias_weight
frame = (
    frame
    .withColumn('normalized_bias_weight', frame.bias_weight * normalizing_factor)
    .drop('bias_weight')
    .withColumnRenamed('normalized_bias_weight', 'bias_weight')
)

# Replace each row by one row per chunk returned by explode_col.
split_weight = udf(lambda value: explode_col(value), ArrayType(FloatType()))
exploded = (
    frame
    .withColumn('explode_val', split_weight(frame.bias_weight))
    .withColumn("explode_val_1", explode("explode_val"))
    .drop("explode_val")
    .drop('bias_weight')
    .withColumnRenamed('explode_val_1', 'bias_weight')
)

exploded.show()

# The normalized weights are roughly 44.2, 54.8 and 0.99, which split into
# 5 + 6 + 1 chunks.
assert exploded.count() == 12
|
Loading…
Reference in a new issue