Kafka: ZK+Kafka+Spark Streaming cluster setup (15): Writing Spark UDF, UDAF, and Aggregator functions
Spark SQL ships with a rich set of built-in functions, but real business logic is often more complex than the built-ins can express, so Spark SQL also lets you plug in your own functions.
UDF: an ordinary row-level function that takes one or more arguments and returns a single value, e.g. length(), isnull().
UDAF: an aggregate function that takes a group of values and returns one aggregated result, e.g. max(), avg(), sum().
Writing a Spark UDF
The first example is written in the pre-Spark-2.0 style (SQLContext, which is deprecated on Spark 2.x, hence the @SuppressWarnings in the code) and shows a UDF with a single input parameter and a single output value.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        JavaRDD<String> javaRDD = javaSparkContext.parallelize(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Pick the UDF interface by the number of input arguments: UDF1, UDF2, ... up to UDF22.
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();

        javaSparkContext.stop();
    }
}
Output:
+---+--------+------+
| id| name|length|
+---+--------+------+
| 1|zhangsan| 8|
| 2| lisi| 4|
| 3| wangwu| 6|
| 4| zhaoliu| 7|
+---+--------+------+
The example above is a UDF with a single input and a single output. The next example switches to the Spark 2.0 API (SparkSession) and adds a UDF with three inputs and one output.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Pick the UDF interface by the number of input arguments: UDF1, UDF2, ... up to UDF22.
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();
        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}
Person.java
package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable {
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public static Person parsePerson(String line) {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}
Note that a UDF only has to be registered once on the session; after that it can be referenced by any number of queries.
Output:
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1-zhangsan|
| 2| lisi| 4| 2-lisi|
| 3| wangwu| 6| 3-wangwu|
| 4| zhaoliu| 7| 4-zhaoliu|
+---+--------+------+----------+
==========================================
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1+zhangsan|
| 2| lisi| 4| 2+lisi|
| 3| wangwu| 6| 3+wangwu|
| 4| zhaoliu| 7| 4+zhaoliu|
+---+--------+------+----------+
If you work through the two examples above, you should have a solid grasp of how UDFs are written, registered, and called.
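For reference, on Spark 2.x with Java 8 the same registrations can be written more compactly as lambdas. This is only an alternative spelling of the registrations already shown above, not a different mechanism; the casts to UDF1/UDF3 tell the compiler which functional interface to use, and the holder class name below is made up for the sketch.
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;

public class LambdaUdfRegistration {
    // Same strLength/strConcat UDFs as above, written as Java 8 lambdas.
    public static void register(SparkSession sparkSession) {
        sparkSession.udf().register("strLength",
                (UDF1<String, Integer>) s -> s.length(), DataTypes.IntegerType);
        sparkSession.udf().register("strConcat",
                (UDF3<String, String, String, String>) (sep, a, b) -> a + sep + b, DataTypes.StringType);
    }
}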
Writing a Spark UDAF
A custom aggregate function extends UserDefinedAggregateFunction. For reference, here is the definition of that abstract class from the Spark source:
package org.apache.spark.sql.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
  // Schema (data types) of the input arguments.
  def inputSchema: StructType

  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
  // Schema (data types) of the intermediate values held in the aggregation buffer.
  def bufferSchema: StructType

  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
  // Data type of the final aggregation result.
  def dataType: DataType

  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
  // Determinism flag: if true, the same input always produces the same result.
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  // Sets the initial ("zero") value of the aggregation buffer. The contract is that merging two
  // initial buffers must again yield an initial buffer. For example, if the initial value were 1
  // and merge added the buffers, merge(initial, initial) would give 2, which is no longer the
  // initial buffer; such an initial value violates the contract, hence the name "zero value".
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
  // Updates the buffer with one input row; conceptually similar to the per-element step of combineByKey.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
  // Merges buffer2 into buffer1; used when combining partial aggregates from different partitions,
  // similar to reduceByKey. Note that the method has no return value: the implementation must
  // write the merged result back into buffer1.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
  // Computes and returns the final aggregation result.
  def evaluate(buffer: Row): Any

  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
  // Aggregates over all input values.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
  // Aggregates over the distinct input values only.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * :: Experimental ::
 * A [[Row]] representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}
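Before the concrete Java implementation below, it may help to trace how Spark drives these callbacks. The following is a plain-Java simulation (no Spark dependency) of an average over five values split across two partitions; the values and the partition split are invented purely for illustration.
public class UdafLifecycleSketch {
    public static void main(String[] args) {
        // initialize(): one "zero" buffer per partition, here (count = 0, sum = 0.0)
        long count1 = 0; double sum1 = 0.0;   // buffer for partition 1
        long count2 = 0; double sum2 = 0.0;   // buffer for partition 2

        // update(): called once per input row within each partition
        for (double v : new double[]{80, 87, 88}) { count1 += 1; sum1 += v; }  // partition 1
        for (double v : new double[]{96, 70})     { count2 += 1; sum2 += v; }  // partition 2

        // merge(): partial buffers are combined, buffer 2 folded into buffer 1
        long count = count1 + count2;
        double sum = sum1 + sum2;

        // evaluate(): final result computed from the merged buffer
        System.out.println("avg = " + (sum / count));   // prints avg = 84.2
    }
}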
First, a UDAF that averages a single column and formats the result to two decimal places:
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType().add("myinput", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero") value of the buffer; merging two initial buffers must again yield
    // an initial buffer, which holds here because 0 + 0 = 0 for both fields.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: row count, a Long zero
        buffer.update(1, 0D); // mysum: running sum, a Double zero
    }

    /**
     * Combine step within a partition.
     */
    // Updates the buffer with one input row, similar to combineByKey.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1);                    // increment the row count
        buffer.update(1, buffer.getDouble(1) + input.getDouble(0)); // add the input value to the sum
    }

    /**
     * Merge step across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1, similar to reduceByKey. Note there is no return value:
    // the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // average = sum / count, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
The following test class shows how to register and call the custom UDAF:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80",
                "2,zhangsan,History,87",
                "3,zhangsan,Chinese,88",
                "4,zhangsan,Chemistry,96",
                "5,lisi,English,70",
                "6,lisi,Chinese,74",
                "7,lisi,History,75",
                "8,lisi,Chemistry,77",
                "9,lisi,Physics,79",
                "10,lisi,Biology,82",
                "11,wangwu,English,96",
                "12,wangwu,Chinese,98",
                "13,wangwu,History,91",
                "14,zhaoliu,English,68",
                "15,zhaoliu,Chinese,66"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve = Double.parseDouble(fields[3]);
                return RowFactory.create(id, name, subject, achieve);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new SimpleAvg();
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
        rows2.show();
    }
}
Output:
+---+--------+---------+-------+
| id| name| subject|achieve|
+---+--------+---------+-------+
| 1|zhangsan| English| 80.0|
| 2|zhangsan| History| 87.0|
| 3|zhangsan| Chinese| 88.0|
| 4|zhangsan|Chemistry| 96.0|
| 5| lisi| English| 70.0|
| 6| lisi| Chinese| 74.0|
| 7| lisi| History| 75.0|
| 8| lisi|Chemistry| 77.0|
| 9| lisi| Physics| 79.0|
| 10| lisi| Biology| 82.0|
| 11| wangwu| English| 96.0|
| 12| wangwu| Chinese| 98.0|
| 13| wangwu| History| 91.0|
| 14| zhaoliu| English| 68.0|
| 15| zhaoliu| Chinese| 66.0|
+---+--------+---------+-------+
+--------+-----------------+
| name| avg_achieve|
+--------+-----------------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi|76.16666666666667|
+--------+-----------------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi| 76.17|
+--------+-----------+
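The same UDAF instance can also be used through the DataFrame API instead of a SQL string, via the apply(Column...) method shown in the abstract class above. The helper class below is only a sketch; it assumes the ds Dataset and SimpleAvg instance from the example above are passed in.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;

public class UdafDataFrameApiExample {
    // Assumes `ds` is the (id, name, subject, achieve) Dataset built in TestUDAF1
    // and `udaf` is an instance of SimpleAvg; the wiring here is hypothetical.
    public static Dataset<Row> averageByName(Dataset<Row> ds, UserDefinedAggregateFunction udaf) {
        return ds.groupBy("name")
                 .agg(udaf.apply(ds.col("achieve")).as("avg_achieve"));
    }
}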
Next, a UDAF that sums several columns of each row and then averages those per-row sums:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilAvg(2);
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
        rows2.show();
    }
}
The code above builds a Dataset with columns id, name, subject, achieve1, and achieve2. MutilAvg sums the configured number of columns in each row and then averages those per-row sums within each group.
MutilAvg.java (the UDAF implementation):
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutilAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;

    public MutilAvg(int columnSize) {
        this.columnSize = columnSize;
    }

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType();
        for (int i = 0; i < columnSize; i++) {
            // StructType is immutable: add(...) returns a new instance, so reassign it
            // (dropping the return value here would leave the schema empty).
            structType = structType.add("myinput" + i, DataTypes.DoubleType);
        }
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero") value of the buffer; merging two initial buffers must again yield
    // an initial buffer, which holds here because 0 + 0 = 0 for both fields.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: row count, a Long zero
        buffer.update(1, 0D); // mysum: running sum, a Double zero
    }

    /**
     * Combine step within a partition.
     */
    // Updates the buffer with one input row, similar to combineByKey.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1); // increment the row count

        // One input row carries several columns, so first sum the columns of this row.
        Double currentLineSumValue = 0d;
        for (int i = 0; i < columnSize; i++) {
            currentLineSumValue += input.getDouble(i);
        }
        buffer.update(1, buffer.getDouble(1) + currentLineSumValue); // add the row sum to the running sum
    }

    /**
     * Merge step across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1, similar to reduceByKey; the result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // average = sum / count, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
Test output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
Finally, a UDAF that takes the per-group maximum of each input column and then returns the largest of those per-column maxima:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilMax(2, 0);
        sparkSession.udf().register("max_vals", udaf);

        Dataset<Row> rows1 = sparkSession.sql(""
                + "select name,max(achieve) as max_achieve "
                + "from "
                + "("
                + "select name,max(achieve1) achieve from user group by name "
                + "union all "
                + "select name,max(achieve2) achieve from user group by name "
                + ") t10 "
                + "group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
        rows2.show();
    }
}
The code above again builds a Dataset with columns id, name, subject, achieve1, and achieve2. MutilMax computes the maximum of each configured column within a group and then returns the largest of those per-column maxima.
MutilMax.java (the UDAF implementation):
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MutilMax extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;
    private Double defaultValue;

    public MutilMax(int columnSize, double defaultValue) {
        this.columnSize = columnSize;
        this.defaultValue = defaultValue;
    }

    @Override
    public StructType inputSchema() {
        List<StructField> inputFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
        }
        StructType inputSchema = DataTypes.createStructType(inputFields);
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> bufferFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
        }
        StructType bufferSchema = DataTypes.createStructType(bufferFields);
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    // Sets the initial ("zero") value of the buffer. Each per-column maximum starts at 0.0,
    // which assumes the aggregated values are non-negative (as the scores in this example are).
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        for (int i = 0; i < this.columnSize; i++) {
            buffer.update(i, 0d);
        }
    }

    /**
     * Combine step within a partition.
     */
    // Updates each per-column maximum with the corresponding column of the input row.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > input.getDouble(i)) {
                buffer.update(i, buffer.getDouble(i));
            } else {
                buffer.update(i, input.getDouble(i));
            }
        }
    }

    /**
     * Merge step across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1 column by column, keeping the larger value; the result must be
    // written back into buffer1 (the method has no return value).
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer1.getDouble(i) > buffer2.getDouble(i)) {
                buffer1.update(i, buffer1.getDouble(i));
            } else {
                buffer1.update(i, buffer2.getDouble(i));
            }
        }
    }

    // Computes and returns the final aggregation result: the largest of the per-column maxima.
    @Override
    public Object evaluate(Row buffer) {
        Double max = Double.MIN_VALUE; // smallest positive double, used here as a sentinel
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > max) {
                max = buffer.getDouble(i);
            }
        }
        if (max == Double.MIN_VALUE) {
            max = this.defaultValue;
        }
        return max;
    }
}
Output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
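As an aside, for this particular two-column case a custom UDAF is not strictly necessary: the built-in greatest() and max() functions give the same per-group result. The sketch below is only a comparison point; it assumes the same user temporary view (with achieve1/achieve2) registered in TestUDAF2, and the helper class name is made up.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class BuiltInMaxExample {
    // Assumes the "user" view with columns achieve1/achieve2 registered in TestUDAF2.
    public static Dataset<Row> maxByName(SparkSession sparkSession) {
        return sparkSession.sql(
            "select name, max(greatest(achieve1, achieve2)) as max_achieve from user group by name");
    }
}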
Writing a Spark Aggregator (Agg) function
The following implements an avg function with the typed Aggregator API.
Step 1: define an Average class that stores the running count and sum;
import java.io.Serializable;

public class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }

    public Average() {
    }

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }
}
Step 2: define an Employee class holding the employee's name and salary;
import java.io.Serializable;

public class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getSalary() {
        return salary;
    }

    public void setSalary(long salary) {
        this.salary = salary;
    }

    public Employee() {
    }

    public Employee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
}
Step 3: define the aggregator itself, which averages the employees' salaries;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
Step 4: call the aggregator from Spark and verify the result.
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;

import java.util.ArrayList;
import java.util.List;

public class SparkClient {
    public static void main(String[] args) {
        final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<Employee> employeeList = new ArrayList<Employee>();
        employeeList.add(new Employee("Michael", 3000L));
        employeeList.add(new Employee("Andy", 4500L));
        employeeList.add(new Employee("Justin", 3500L));
        employeeList.add(new Employee("Berta", 4000L));

        JavaRDD<Employee> rows = ctx.parallelize(employeeList);
        Dataset<Employee> ds = spark.createDataFrame(rows, Employee.class).map(new MapFunction<Row, Employee>() {
            @Override
            public Employee call(Row row) throws Exception {
                return new Employee(row.getString(0), row.getLong(1));
            }
        }, Encoders.bean(Employee.class));
        ds.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // +-------+------+

        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // |        3750.0|
        // +--------------+
    }
}
Output:
+-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
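One closing note: the Aggregator above is bound to a typed Dataset<Employee> through toColumn(). On Spark 3.0 and later an Aggregator can also be exposed to SQL via functions.udaf, but in that case its input type is usually the column type rather than the whole bean. The sketch below is only an option on 3.x and uses a hypothetical Long-input variant (MyLongAverage, not part of the original example) to show the registration pattern.
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

// Hypothetical Long-input variant of MyAverage, reusing the Average buffer class above.
class MyLongAverage extends Aggregator<Long, Average, Double> {
    @Override public Average zero() { return new Average(0L, 0L); }
    @Override public Average reduce(Average b, Long value) {
        b.setSum(b.getSum() + value);
        b.setCount(b.getCount() + 1);
        return b;
    }
    @Override public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }
    @Override public Double finish(Average r) { return ((double) r.getSum()) / r.getCount(); }
    @Override public Encoder<Average> bufferEncoder() { return Encoders.bean(Average.class); }
    @Override public Encoder<Double> outputEncoder() { return Encoders.DOUBLE(); }
}

public class RegisterAggregatorForSql {
    // Spark 3.0+ only: wrap the typed Aggregator as a UDF callable from SQL.
    public static void register(SparkSession spark) {
        spark.udf().register("my_average", functions.udaf(new MyLongAverage(), Encoders.LONG()));
        // e.g. spark.sql("select my_average(salary) as average_salary from employees").show();
    }
}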