Creating a Java Spark project with Maven and junit
This blog post shows how to write some Spark code with the Java API and run a simple test.
The code snippets in this post are from this GitHub repo.
Project setup
Start by creating a pom.xml file for Maven.
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>mrpowers</groupId>
  <artifactId>JavaSpark</artifactId>
  <version>1.0-SNAPSHOT</version>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.7.0</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>2.4.0</version>
    </dependency>
  </dependencies>
</project>
This build file adds Spark SQL as a dependency and configures the Maven compiler plugin to target Java 8, which supports the language features we’ll need for creating DataFrames.
Write some code
Let’s create a Transformations class with a myCounter method that returns the number of rows in a DataFrame. myCounter would never be useful in a real project, but it’s best to get started with a simple example.
Create the src/main/java/mrpowers/javaspark/Transformations.java file.
package mrpowers.javaspark;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class Transformations {

  public long myCounter(Dataset<Row> df) {
    return df.count();
  }

}
The Transformations class lives in the mrpowers.javaspark package. Namespacing is important to prevent class name collisions (we don’t want Java to get our Transformations class confused with another library that has a class with the same name).
We import the Dataset and Row classes from Spark so they can be accessed in the myCounter function.
We could have imported all of the Spark SQL code, including Dataset and Row, with a single wildcard import: import org.apache.spark.sql.*. Wildcard imports make it harder to identify where classes are defined, so it’s generally best to avoid them.
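Here’s a quick illustration of the two import styles. Since we don’t need Spark on the classpath for this, java.util stands in for org.apache.spark.sql, and the ImportDemo class is made up for the example.

```java
// The explicit form documents exactly where List and ArrayList come from.
import java.util.ArrayList;
import java.util.List;
// A wildcard would also compile, but hides the classes' origin:
// import java.util.*;

public class ImportDemo {

  // Build a small list and return its size, mirroring our row-counting theme.
  public static int rowCount() {
    List<String> rows = new ArrayList<>();
    rows.add("bar1.1");
    rows.add("bar1.2");
    return rows.size();
  }

  public static void main(String[] args) {
    System.out.println(rowCount()); // prints 2
  }
}
```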
Write a test
Let’s use junit to test the myCounter function.
Add junit as a dependency in the pom.xml file.
<dependency>
  <groupId>junit</groupId>
  <artifactId>junit</artifactId>
  <version>4.13-beta-1</version>
</dependency>
Our test will create a DataFrame with two rows and verify that the myCounter function returns 2 when it’s passed our DataFrame as an input.
Here’s our test logic at a high level.
// create a DataFrame called df
Transformations transformations = new Transformations();
long result = transformations.myCounter(df);
assertEquals(2, result);
The junit assertEquals function is where we actually make our assertion to verify that the actual output and expected output match.
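Conceptually, assertEquals compares the expected and actual values and throws an AssertionError when they differ. The sketch below is illustrative only, not JUnit’s real implementation (the MiniAssert name is made up); in real tests, use org.junit.Assert.assertEquals.

```java
// A simplified sketch of what an assertEquals for longs does.
public class MiniAssert {

  public static void assertEquals(long expected, long actual) {
    if (expected != actual) {
      throw new AssertionError(
          "expected:<" + expected + "> but was:<" + actual + ">");
    }
    // when the values match, the assertion passes silently
  }

  public static void main(String[] args) {
    assertEquals(2, 2); // passes
    try {
      assertEquals(2, 3); // fails
    } catch (AssertionError e) {
      System.out.println(e.getMessage());
    }
  }
}
```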
Let’s take a look at the whole test file in src/test/java/mrpowers/javaspark/TransformationsTest.java.
Brace yourself for some verbose code!
package mrpowers.javaspark;

import org.junit.Test;
import static org.junit.Assert.*;

import java.util.List;
import java.util.ArrayList;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TransformationsTest implements SparkSessionTestWrapper {

  @Test
  public void testMyCounter() {
    List<String[]> stringAsList = new ArrayList<>();
    stringAsList.add(new String[] { "bar1.1", "bar2.1" });
    stringAsList.add(new String[] { "bar1.2", "bar2.2" });

    JavaSparkContext sparkContext = new JavaSparkContext(spark.sparkContext());

    JavaRDD<Row> rowRDD = sparkContext
        .parallelize(stringAsList)
        .map((String[] row) -> RowFactory.create(row));

    // Create schema
    StructType schema = DataTypes
        .createStructType(new StructField[] {
            DataTypes.createStructField("foe1", DataTypes.StringType, false),
            DataTypes.createStructField("foe2", DataTypes.StringType, false)
        });

    Dataset<Row> df = spark.sqlContext().createDataFrame(rowRDD, schema).toDF();

    Transformations transformations = new Transformations();
    long result = transformations.myCounter(df);
    assertEquals(2, result);
  }

}
Let’s create a SparkSessionTestWrapper interface to access the SparkSession in our test. The SparkSession is defined in an interface so multiple test files can use the same SparkSession.
package mrpowers.javaspark;

import org.apache.spark.sql.SparkSession;

public interface SparkSessionTestWrapper {

  SparkSession spark = SparkSession
      .builder()
      .appName("Build a DataFrame from Scratch")
      .master("local[*]")
      .getOrCreate();

}
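This sharing works because fields declared in a Java interface are implicitly public, static, and final, so every test class that implements SparkSessionTestWrapper sees the same single spark instance. Here’s a minimal sketch of that behavior with a shared counter standing in for the SparkSession (the CounterWrapper, FirstTest, and SecondTest names are made up for illustration):

```java
import java.util.concurrent.atomic.AtomicInteger;

// Interface fields are implicitly public static final, so both
// implementing classes below share the exact same AtomicInteger.
interface CounterWrapper {
  AtomicInteger counter = new AtomicInteger(0); // one shared instance
}

class FirstTest implements CounterWrapper {
  int bump() { return counter.incrementAndGet(); }
}

class SecondTest implements CounterWrapper {
  int bump() { return counter.incrementAndGet(); }
}

public class SharedStateDemo {
  public static void main(String[] args) {
    System.out.println(new FirstTest().bump());  // increments the shared counter
    System.out.println(new SecondTest().bump()); // same counter, one higher
  }
}
```

Because the field is initialized once, the (expensive) SparkSession is built a single time and reused across every test file that implements the interface.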
Run the tests with the mvn test command.
Next steps
This tutorial gives us a great foundation to explore more features that all Java Spark programmers need to master. Here are the next steps:
- Building JAR files with Maven (similar to building JAR files with SBT)
- Chaining custom transformations (we already know how to do this with the Scala API and with PySpark)
- Making DataFrame comparisons in the test suite with spark-fast-tests
- Using spark-daria in application code
P.S. This is the first Java code I’ve ever written. Please post a comment or email me if you have any suggestions on how to make this code better.