基于Spark ML的Titanic Challenge (Top 6%)
下面代码按照之前参加Kaggle的python代码改写,只完成了模型的训练过程,还需要对test集的数据进行转换和对test集进行预测。
scala 2.11.12
spark 2.2.2
package ML.Titanic
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.types._
/**
* GBTClassifier for predicting survival in the Titanic ship
*/
object TitanicChallenge {
def main(args: Array[String]) {
val spark = SparkSession.builder.
master("local[*]")
.appName("example")
.config("spark.sql.shuffle.partitions", 20)
.config("spark.default.parallelism", 20)
.config("spark.driver.memory", "4G")
.config("spark.memory.fraction", 0.75)
.getOrCreate()
val sc = spark.sparkContext
spark.sparkContext.setLogLevel("ERROR")
val schemaArray = StructType(Array(
StructField("PassengerId", IntegerType, true),
StructField("Survived", IntegerType, true),
StructField("Pclass", IntegerType, true),
StructField("Name", StringType, true),
StructField("Sex", StringType, true),
StructField("Age", FloatType, true),
StructField("SibSp", IntegerType, true),
StructField("Parch", IntegerType, true),
StructField("Ticket", StringType, true),
StructField("Fare", FloatType, true),
StructField("Cabin", StringType, true),
StructField("Embarked", StringType, true)
))
val path = "Titanic/"
val df = spark.read
.option("header", "true")
.schema(schemaArray)
.csv(path + "train.csv")
.drop("PassengerId")
// df.cache()
val utils = new TitanicChallenge(spark)
val df2 = utils.transCabin(df)
val df3 = utils.transTicket(sc, df2)
val df4 = utils.transEmbarked(df3)
val df5 = utils.extractTitle(sc, df4)
val df6 = utils.transAge(sc, df5)
val df7 = utils.categorizeAge(df6)
val df8 = utils.createFellow(df7)
val df9 = utils.categorizeFellow(df8)
val df10 = utils.extractFName(df9)
val df11 = utils.transFare(df10)
val prePipelineDF = df11.select("Survived", "Pclass", "Sex",
"Age_categorized", "fellow_type", "Fare_categorized",
"Embarked", "Cabin", "Ticket",
"Title", "family_type")
// prePipelineDF.show(1)
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
// |Survived|Pclass| Sex|Age_categorized|fellow_type|Fare_categorized|Embarked|Cabin|Ticket|Title|family_type|
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
// | 0| 3|male| 3.0| Small| 0.0| S| U| 0| Mr| 0|
// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
val (df_indexed, colsTrain) = utils.index_onehot(prePipelineDF)
df_indexed.cache()
//训练模型
val validatorModel = utils.trainData(df_indexed, colsTrain)
//打印最优模型的参数
val bestModel = validatorModel.bestModel
println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap)
//打印各模型的成绩和参数
val paramsAndMetrics = validatorModel.validationMetrics
.zip(validatorModel.getEstimatorParamMaps)
.sortBy(-_._1)
paramsAndMetrics.foreach { case (metric, params) =>
println(metric)
println(params)
println()
}
validatorModel.write.overwrite().save(path + "Titanic_gbtc")
spark.stop()
}
}
class TitanicChallenge(private val spark: SparkSession) extends Serializable {
import spark.implicits._
//Cabin,用“U”填充null,并提取Cabin的首字母
def transCabin(df: Dataset[Row]): Dataset[Row] = {
df.na.fill("U", Seq("Cabin"))
.withColumn("Cabin", substring($"Cabin", 0, 1))
}
//
def transTicket(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
////提取船票的号码,如“A/5 21171”中的21171
val medDF1 = df.withColumn("Ticket", split($"Ticket", " "))
.withColumn("Ticket", $"Ticket"(size($"Ticket").minus(1)))
.filter($"Ticket" =!= "LINE")//去掉某种特殊的船票
//对船票号进行分类,小于四位号码的为“1”,四位号码的以第一个数字开头,后面接上“0”,大于4位号码的,取前三个数字开头。如21171变为211
val ticketTransUdf = udf((ticket: String) => {
if (ticket.length < 4) {
"1"
} else if (ticket.length == 4){
ticket(0)+"0"
} else {
ticket.slice(0, 3)
}
})
val medDF2 = medDF1.withColumn("Ticket", ticketTransUdf($"Ticket"))
//将数量小于等于5的类别统一归为“0”。先统计小于5的名单,然后用udf进行转换。
val filterList = medDF2.groupBy($"Ticket").count()
.filter($"count" <= 5)
.map(row => row.getString(0))
.collect.toList
val filterList_bc = sc.broadcast(filterList)
val ticketTransAdjustUdf = udf((subticket: String) => {
if (filterList_bc.value.contains(subticket)) "0"
else subticket
})
medDF2.withColumn("Ticket", ticketTransAdjustUdf($"Ticket"))
}
//用“S”填充null
def transEmbarked(df: Dataset[Row]): Dataset[Row] = {
df.na.fill("S", Seq("Embarked"))
}
def extractTitle(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
val regex = ".*, (.*?)\\..*"
//对头衔进行归类
val titlesMap = Map(
"Capt"-> "Officer",
"Col"-> "Officer",
"Major"-> "Officer",
"Jonkheer"-> "Royalty",
"Don"-> "Royalty",
"Sir" -> "Royalty",
"Dr"-> "Officer",
"Rev"-> "Officer",
"the Countess"->"Royalty",
"Mme"-> "Mrs",
"Mlle"-> "Miss",
"Ms"-> "Mrs",
"Mr" -> "Mr",
"Mrs" -> "Mrs",
"Miss" -> "Miss",
"Master" -> "Master",
"Lady" -> "Royalty"
)
val titlesMap_bc = sc.broadcast(titlesMap)
df.withColumn("Title", regexp_extract(($"Name"), regex, 1))
.na.replace("Title", titlesMap_bc.value)
}
//根据null age的records对应的Pclass和Name_final分组后的平均来填充缺失age。
// 首先,生成分组key,并获取分组后的平均年龄map。然后广播map,当Age为null时,用udf返回需要填充的值。
def transAge(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
val medDF = df.withColumn("Pclass_Title_key", concat($"Title", $"Pclass"))
val meanAgeMap = medDF.groupBy("Pclass_Title_key")
.mean("Age")
.map(row => (row.getString(0), row.getDouble(1)))
.collect().toMap
val meanAgeMap_bc = sc.broadcast(meanAgeMap)
val fillAgeUdf = udf((comb_key: String) => meanAgeMap_bc.value.getOrElse(comb_key, 0.0))
medDF.withColumn("Age", when($"Age".isNull, fillAgeUdf($"Pclass_Title_key")).otherwise($"Age"))
}
//对Age进行分类
def categorizeAge(df: Dataset[Row]): Dataset[Row] = {
val ageBucketBorders = 0.0 +: (10.0 to 60.0 by 5.0).toArray :+ 150.0
val ageBucketer = new Bucketizer().setSplits(ageBucketBorders).setInputCol("Age").setOutputCol("Age_categorized")
ageBucketer.transform(df).drop("Pclass_Title_key")
}
//将SibSp和Parch相加,得出同行人数
def createFellow(df: Dataset[Row]): Dataset[Row] = {
df.withColumn("fellow", $"SibSp" + $"Parch")
}
//fellow_type, 对fellow进行分类。此处其实可以留到pipeline部分一次性完成。
def categorizeFellow(df: Dataset[Row]): Dataset[Row] = {
df.withColumn("fellow_type", when($"fellow" === 0, "Alone")
.when($"fellow" <= 3, "Small")
.otherwise("Large"))
}
def extractFName(df: Dataset[Row]): Dataset[Row] = {
//检查df是否有Survived和fellow列
if (!df.columns.contains("Survived") || !df.columns.contains("fellow")){
throw new IllegalArgumentException(
"""
|Check if the argument is a training set or if this training set contains column named \"fellow\"
""".stripMargin)
}
//FName,提取家庭名称。例如:"Johnston, Miss. Catherine Helen ""Carrie""" 提取出Johnston
// 由于spark的读取csv时,如果有引号,读取就会出现多余的引号,所以除了split逗号,还要再split一次引号。
val medDF = df
.withColumn("FArray", split($"Name", ","))
.withColumn("FName", expr("FArray[0]"))
.withColumn("FArray", split($"FName", "\""))
.withColumn("FName", $"FArray"(size($"FArray").minus(1)))
//family_type,分为三类,第一类是60岁以下女性遇难的家庭,第二类是18岁以上男性存活的家庭,第三类其他。
val femaleDiedFamily_filter = $"Sex" === "female" and $"Age" < 60 and $"Survived" === 0 and $"fellow" > 0
val maleSurvivedFamily_filter = $"Sex" === "male" and $"Age" >= 18 and $"Survived" === 1 and $"fellow" > 1
val resDF = medDF.withColumn("family_type", when(femaleDiedFamily_filter, 1)
.when(maleSurvivedFamily_filter, 2).otherwise(0))
//familyTable,家庭分类名单,用于后续test集的转化。此处用${FName}_${family_type}的形式保存。
resDF.filter($"family_type".isin(1,2))
.select(concat($"FName", lit("_"), $"family_type"))
.dropDuplicates()
.write.format("text").mode("overwrite").save("familyTable")
//如果需要直接收集成Map的话,可用下面代码。
// 此代码先利用mapPartitions对各分块的数据进行聚合,降低直接调用count而使driver挂掉的风险。
//另外新建一个默认Set是为了防止某个partition并没有数据的情况(出现概率可能比较少),
// 从而使得Set的类型变为Set[_>:Tuple]而不能直接flatten
// val familyMap = df10
// .filter($"family_type" === 1 || $"family_type" === 2)
// .select("FName", "family_type")
// .rdd
// .mapPartitions{iter => {
// if (!iter.isEmpty) {
// Iterator(iter.map(row => (row.getString(0), row.getInt(1))).toSet)}
// else Iterator(Set(("defualt", 9)))}
// }
// .collect()
// .flatten
// .toMap
resDF
}
//Fare。首先去掉缺失的(test集合中有一个,如果量多的话,也可以像Age那样通过头衔,年龄等因数来推断)
//然后对Fare进行分类
def transFare(df: Dataset[Row]): Dataset[Row] = {
val medDF = df.na.drop("any", Seq("Fare"))
val fareBucketer = new QuantileDiscretizer()
.setInputCol("Fare")
.setOutputCol("Fare_categorized")
.setNumBuckets(4)
fareBucketer.fit(medDF).transform(medDF)
}
def index_onehot(df: Dataset[Row]): Tuple2[Dataset[Row], Array[String]] = {
val stringCols = Array("Sex","fellow_type", "Embarked", "Cabin", "Ticket", "Title")
val subOneHotCols = stringCols.map(cname => s"${cname}_index")
val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringCols.map(
cname => new StringIndexer()
.setInputCol(cname)
.setOutputCol(s"${cname}_index")
.setHandleInvalid("skip")
)
val oneHotCols = subOneHotCols ++ Array("Pclass", "Age_categorized", "Fare_categorized", "family_type")
val vectorCols = oneHotCols.map(cname => s"${cname}_encoded")
val encode_transformers: Array[org.apache.spark.ml.PipelineStage] = oneHotCols.map(
cname => new OneHotEncoder()
.setInputCol(cname)
.setOutputCol(s"${cname}_encoded")
)
val pipelineStage = index_transformers ++ encode_transformers
val index_onehot_pipeline = new Pipeline().setStages(pipelineStage)
val index_onehot_pipelineModel = index_onehot_pipeline.fit(df)
val resDF = index_onehot_pipelineModel.transform(df).drop(stringCols:_*).drop(subOneHotCols:_*)
println(resDF.columns.size)
(resDF, vectorCols)
}
def trainData(df: Dataset[Row], vectorCols: Array[String]): TrainValidationSplitModel = {
//separate and model pipeline,包含划分label和features,机器学习模型的pipeline
val vectorAssembler = new VectorAssembler()
.setInputCols(vectorCols)
.setOutputCol("features")
val gbtc = new GBTClassifier()
.setLabelCol("Survived")
.setFeaturesCol("features")
.setPredictionCol("prediction")
val pipeline = new Pipeline().setStages(Array(vectorAssembler, gbtc))
val paramGrid = new ParamGridBuilder()
.addGrid(gbtc.stepSize, Seq(0.1))
.addGrid(gbtc.maxDepth, Seq(5))
.addGrid(gbtc.maxIter, Seq(20))
.build()
val multiclassEval = new MulticlassClassificationEvaluator()
.setLabelCol("Survived")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val tvs = new TrainValidationSplit()
.setTrainRatio(0.75)
.setEstimatorParamMaps(paramGrid)
.setEstimator(pipeline)
.setEvaluator(multiclassEval)
tvs.fit(df)
}
}
基于Spark ML的Titanic Challenge (Top 6%)的更多相关文章
- 使用spark ml pipeline进行机器学习
一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...
- spark ml 的例子
一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...
- spark ml pipeline构建机器学习任务
一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...
- Spark ML Pipeline简介
Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...
- 基于Spark的电影推荐系统(推荐系统~4)
第四部分-推荐系统-模型训练 本模块基于第3节 数据加工得到的训练集和测试集数据 做模型训练,最后得到一系列的模型,进而做 预测. 训练多个模型,取其中最好,即取RMSE(均方根误差)值最小的模型 说 ...
- 基于Spark ALS构建商品推荐引擎
基于Spark ALS构建商品推荐引擎 一般来讲,推荐引擎试图对用户与某类物品之间的联系建模,其想法是预测人们可能喜好的物品并通过探索物品之间的联系来辅助这个过程,让用户能更快速.更准确的获得所需 ...
- 推荐系统那点事 —— 基于Spark MLlib的特征选择
在机器学习中,一般都会按照下面几个步骤:特征提取.数据预处理.特征选择.模型训练.检验优化.那么特征的选择就很关键了,一般模型最后效果的好坏往往都是跟特征的选择有关系的,因为模型本身的参数并没有太多优 ...
- Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用
1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...
- Spark ML源码分析之一 设计框架解读
本博客为作者原创,如需转载请注明参考 在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...
随机推荐
- mysql常用命令介绍
mysql适用于在Internet上存取数据,支持多种平台 1.主键:唯一标识表中每行的这个列,没有主键更新或删除表中的特定行很困难. 2.连接mysql可以用Navicat 要读取数据库中的内容先要 ...
- MYSQL数据库迁移到ORACLE数据库
一.环境和需求1.环境 MySQL数据库服务器: OS version:Linux 5.3 for 64 bit mysql Server version: 5.0.45 Oracle数据库服务器: ...
- C# 获取当年的周六周日
public void GetWMDay() { List<string> list = new List<string>(); "; DateTime counYe ...
- day11-函数对象、名称空间和作用域
目录 函数对象 函数的嵌套 名称空间和作用域 内置名称空间 全局名称空间 局部名称空间 作用域 全局作用域 局部作用域 global和nonlocal 函数对象 在Python中,一切皆对象,函数也是 ...
- Linux内核系统调用处理过程
原创作品转载请注明出处 + https://github.com/mengning/linuxkernel/ 学号末三位:168 下载并编译Linux5.0 xz -d linux-.tar.xz . ...
- SweetAlert弹出框
以前也用过,那个时候没有写过,突然看见了,就写上了. 网址:http://mishengqiang.com/sweetalert2/ swal({ title: '确定删除吗?', text: '你将 ...
- CVPR2016 Paper list
CVPR2016 Paper list ORAL SESSIONImage Captioning and Question Answering Monday, June 27th, 9:00AM - ...
- Win32 线程同步
Win32 线程同步 ## Win32线程同步 ### 1. 原子锁 ### 2. 临界区 {全局变量} CRITICAL_SECTION CS = {0}; // 定义并初始化临界区结构体变量 {线 ...
- 如何在redhat 7上安装VNC服务器
平时我们基本上都是用xshell或者用putty远程我们的linux服务器,如果我们的linux服务器安装了图型化界面那我们又该如何远程使用我们的图形化界面呢?下面我们用vnc来实现远程我们的linu ...
- Django cookie、session使用
一.cookie Cookie是key-value结构,类似于一个python中的字典.随着服务器端的响应发送给客户端浏览器.然后客户端浏览器会把Cookie保存起来,当下一次再访问服务器时把Cook ...