Load data from a CSV file to train a model, then use the trained model to predict the output. The data used here is from https://www.kaggle.com/c/titanic.
To run Spark from Eclipse, set the following VM argument:
-Xmx512m
Some important tips below are taken from the Kaggle forums.
Following is the source code.
package com.spark.titanic.main;

import java.util.HashMap;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

public class TitanicMain {

    private static final Logger logger = LoggerFactory
            .getLogger(TitanicMain.class);

    private static JavaSparkContext sc;

    public static void main(String[] args) {
        // Point Hadoop at a local binary distribution (needed on Windows)
        System.setProperty("hadoop.home.dir",
                "D:/sw/hadoop/hadoop-common-2.2.0-bin-master/");
        // Define a configuration to use to interact with Spark
        SparkConf conf = new SparkConf().setMaster("local").setAppName(
                "Titanic Survival Analytics App");
        // Create a Java version of the Spark Context from the configuration
        sc = new JavaSparkContext(conf);
        loadTrainingData();
    }

    private static void loadTrainingData() {
        // train.csv columns: PassengerId, Survived, Pclass, Name, Sex, Age,
        // SibSp, Parch, Ticket, Fare, Cabin, Embarked
        JavaRDD<LabeledPoint> trainingRDD = sc
                .textFile("file:///D:/tech/data/titanic/train.csv")
                // Skip the header row, which starts with "PassengerId"
                .filter(new Function<String, Boolean>() {
                    @Override
                    public Boolean call(String line) throws Exception {
                        return line.charAt(0) != 'P';
                    }
                }).map(new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) throws Exception {
                        // Split on commas outside quoted fields (the Name
                        // column contains commas inside quotes)
                        String[] fields = line.split(
                                ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)", -1);
                        // Label: Survived; features: Pclass, Age (0 when
                        // missing), SibSp, Parch, Fare, Sex (male = 1)
                        LabeledPoint point = new LabeledPoint(
                                Double.valueOf(fields[1]),
                                Vectors.dense(
                                        Double.valueOf(fields[2]),
                                        Double.valueOf(fields[5] == null
                                                || "".equals(fields[5].trim())
                                                ? "0" : fields[5]),
                                        Double.valueOf(fields[6]),
                                        Double.valueOf(fields[7]),
                                        Double.valueOf(fields[9]),
                                        "male".equalsIgnoreCase(fields[4])
                                                ? 1d : 0d));
                        return point;
                    }
                });

        Integer numClasses = 2;
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        Integer numTrees = 3; // Use more in practice.
        String featureSubsetStrategy = "auto"; // Let the algorithm choose.
        String impurity = "gini";
        Integer maxDepth = 5;
        Integer maxBins = 32;
        Integer seed = 12345;

        // 70/30 random split; splits[1] can serve as a held-out set
        double split[] = { 0.7, 0.3 };
        logger.info("No of splits {}", split.length);
        JavaRDD<LabeledPoint> splits[] = trainingRDD.randomSplit(split);
        logger.info("Count in splits is {} {}", splits[0].count(),
                splits[1].count());

        final RandomForestModel model = RandomForest.trainClassifier(
                trainingRDD, numClasses, categoricalFeaturesInfo, numTrees,
                featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

        // test.csv has no Survived column, so every index shifts down by one:
        // PassengerId, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, ...
        JavaRDD<LabeledPoint> testData = sc
                .textFile("file:///D:/tech/data/titanic/test.csv")
                .filter(new Function<String, Boolean>() {
                    @Override
                    public Boolean call(String line) throws Exception {
                        return line.charAt(0) != 'P';
                    }
                }).map(new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) throws Exception {
                        String[] fields = line.split(
                                ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)", -1);
                        // Use PassengerId as the "label" so each prediction
                        // can be printed against its passenger
                        LabeledPoint point = new LabeledPoint(
                                Double.valueOf(fields[0]),
                                Vectors.dense(
                                        Double.valueOf(fields[1]), // Pclass
                                        Double.valueOf(fields[4] == null
                                                || "".equals(fields[4].trim())
                                                ? "0" : fields[4]), // Age
                                        Double.valueOf(fields[5]), // SibSp
                                        Double.valueOf(fields[6]), // Parch
                                        Double.valueOf(fields[8] == null
                                                || "".equals(fields[8].trim())
                                                ? "0" : fields[8]), // Fare
                                        "male".equalsIgnoreCase(fields[3])
                                                ? 1d : 0d)); // Sex
                        return point;
                    }
                });

        JavaRDD<Tuple2<Double, Double>> testDataPred = testData
                .map(new Function<LabeledPoint, Tuple2<Double, Double>>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint point)
                            throws Exception {
                        Tuple2<Double, Double> tuple = new Tuple2<Double, Double>(
                                point.label(), model.predict(point.features()));
                        return tuple;
                    }
                });

        // Print PassengerId and the predicted Survived value
        testDataPred.foreach(new VoidFunction<Tuple2<Double, Double>>() {
            @Override
            public void call(Tuple2<Double, Double> arg0) throws Exception {
                logger.info("{}\t{}", arg0._1, arg0._2);
            }
        });
    }
}
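Note that the listing creates a 70/30 random split but then trains on the full training RDD, so splits[1] is counted and never used. As a minimal sketch (an addition, not part of the original program, reusing the variables from the listing above), the held-out 30% could be used to estimate accuracy before predicting on test.csv:

        // Sketch: train on splits[0] only and evaluate on held-out splits[1]
        final RandomForestModel heldOutModel = RandomForest.trainClassifier(
                splits[0], numClasses, categoricalFeaturesInfo, numTrees,
                featureSubsetStrategy, impurity, maxDepth, maxBins, seed);
        // Count held-out points whose predicted class matches the true label
        long correct = splits[1]
                .filter(new Function<LabeledPoint, Boolean>() {
                    @Override
                    public Boolean call(LabeledPoint p) throws Exception {
                        return heldOutModel.predict(p.features()) == p.label();
                    }
                }).count();
        logger.info("Held-out accuracy {}", (double) correct / splits[1].count());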
Note that if Apache Hadoop is not installed on the local machine, just download the binaries and set the system property hadoop.home.dir. If you are running standalone code and do not want to hardcode the path, set the property using the -D option.
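For example, assuming the binaries are unpacked under the same path used in the code above, the VM argument would be:
-Dhadoop.home.dir=D:/sw/hadoop/hadoop-common-2.2.0-bin-master/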