DataSet Operations – Pivoting Data

Pivoting is the process of re-arranging data from rows into columns based on the values of one or more columns. It is also sometimes called reshaping data.

Pivoting in Spark is quite simple and is accomplished via the pivot method. Before you pivot, make sure the dataset contains only the columns the pivot actually needs; drop any columns that do not take part in it.

Pivoting in Spark involves the following three methods (a minimal sketch follows this list).

  • groupBy – is passed the column(s) used for grouping; these are not pivoted and remain as rows.
  • pivot – is passed the column whose distinct values are turned into new columns.
  • agg – is passed the aggregation(s) applied to fill each pivoted cell.
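
The overall shape is always the same. Below is a minimal, self-contained sketch of the pattern; the data and column names (region, year, amount) are made up for illustration and are not part of the orders example that follows.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object PivotSketch {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Pivot Sketch")
      .getOrCreate

    import spark.implicits._

    //Hypothetical data: one row per (region, year) sale amount
    val sales = Seq(
      ("EAST", 2020, 100.0), ("EAST", 2021, 150.0),
      ("WEST", 2020, 80.0),  ("WEST", 2021, 120.0)
    ).toDF("region", "year", "amount")

    sales.groupBy("region")   // grouping column - stays as rows
      .pivot("year")          // distinct years become new columns
      .agg(sum("amount"))     // aggregation filling each (region, year) cell
      .show()
  }
}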

Pivoting a Single Column

Let’s take a look at a complete example.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

object SparkDataSetsPivot {
  def main(args: Array[String]): Unit = {

    //Step 1 - Create a spark session
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Spark DataSet Pivot")
      .getOrCreate

    //Step 2 - Create a schema for the file to be read
    val ordersSchema = StructType(
      List(
        StructField("orderkey", IntegerType, true),
        StructField("custkey", StringType, true),
        StructField("orderstatus", StringType, true),
        StructField("totalprice", DoubleType, true),
        StructField("orderdate", DateType, true),
        StructField("order_priority", StringType, true),
        StructField("clerk", StringType, true),
        StructField("ship_priority", StringType, true),
        StructField("comment", StringType, true)
      )
    )

    //Step 3 - Read the CSV file - with Options
    val ds = spark.read
      .option("header",false)
      .option("delimiter","|")
      .schema(ordersSchema)
      .csv("orders.tbl")

    //Step 4 - Use order date, order status and total price to create a new dataset
    import spark.implicits._
    val baseDS = ds.select("orderdate", "orderstatus","totalprice")
      .withColumn("orderYear",year($"orderdate"))
      .select("orderstatus","orderYear","totalprice")

    //Step 5 - Pivot the data
    val pivotDS = baseDS.groupBy("orderstatus")
      .pivot("orderYear")
      .agg(round(avg("totalprice"),2) as "avg_value")

    //Step 6 - Show the dataset in the log
    pivotDS.show(truncate = false)
  }
}

Let’s analyse the above example:

  • Step 1 – Creates a Spark session
  • Step 2 – Creates a custom schema for the orders file
  • Step 3 – Reads the file and creates a DataSet from it
  • Step 4 – Builds a new dataset with only the three fields needed for the pivot (orderstatus, orderYear, totalprice)
  • Step 5 – Pivots the data. The groupBy method is passed a column (orderstatus) which is not pivoted. The pivot method is passed the column (orderYear) whose distinct values become columns. The agg method is passed an aggregation that averages and rounds the total price for each cell.
  • Step 6 – Shows the pivoted data in the log

In the output, each distinct order year becomes a column, and each cell holds the rounded average total price for that order status and year.
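
A side note, not shown in the original example: when the distinct pivot values are known in advance, they can be supplied to pivot explicitly, which saves Spark a pass over the data to discover them. The years listed here are placeholders, not values taken from the orders file.

//Hedged sketch - restrict the pivot to known order years
val pivotKnownYearsDS = baseDS.groupBy("orderstatus")
  .pivot("orderYear", Seq(1995, 1996, 1997))   // hypothetical years; only these become columns
  .agg(round(avg("totalprice"),2) as "avg_value")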

Multiple columns can also be used when pivoting data. Depending on the use case, several columns may be passed to groupBy, several aggregates to agg, and (with a small workaround shown later) more than one column can drive the pivot itself; a short sketch follows.

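For instance, here is a hedged sketch (not from the original post) that keeps two grouping columns while still pivoting on the order year. It assumes a dataset with orderstatus, orderYear, orderMonth and totalprice columns, like the baseDS built in the multiple column pivot example further below.

//Hedged sketch - two grouping columns, one pivot column
val multiGroupDS = statusYearMonthDS          // hypothetical dataset with the four columns above
  .groupBy("orderstatus", "orderMonth")       // both grouping columns stay as rows
  .pivot("orderYear")                         // order years still become columns
  .agg(round(avg("totalprice"),2) as "avg_price")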

Single Column Pivot – Multiple Aggregates

Let’s see an example with more than one aggregate in the agg method. It is similar to the previous example, but in addition to the average it also computes the minimum total price per year.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

object SparkDataSetsPivot {
  def main(args: Array[String]): Unit = {

    //Step 1 - Create a spark session
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Spark DataSet Pivot")
      .getOrCreate

    //Step 2 - Create a schema for the file to be read
    val ordersSchema = StructType(
      List(
        StructField("orderkey", IntegerType, true),
        StructField("custkey", StringType, true),
        StructField("orderstatus", StringType, true),
        StructField("totalprice", DoubleType, true),
        StructField("orderdate", DateType, true),
        StructField("order_priority", StringType, true),
        StructField("clerk", StringType, true),
        StructField("ship_priority", StringType, true),
        StructField("comment", StringType, true)
      )
    )

    //Step 3 - Read the CSV file - with Options
    val ds = spark.read
      .option("header",false)
      .option("delimiter","|")
      .schema(ordersSchema)
      .csv("orders.tbl")

    //Step 4 - Select order date, order status and total price to create a new dataset
    import spark.implicits._
    val baseDS = ds.select("orderdate", "orderstatus","totalprice")
      .withColumn("orderYear",year($"orderdate"))
      .select("orderstatus","orderYear","totalprice")

    //Step 5 - Pivot the data
    val pivotDS = baseDS.groupBy("orderstatus")
      .pivot("orderYear")
      .agg(round(avg("totalprice"),2) as "avg_value",min("totalprice") as "min_price")

    //Step 6 - Show the dataset in the log
    pivotDS.show(truncate = false)

  }
}

In the output, each order year now produces two columns, one for each aggregate. Spark names them by combining the pivoted value with the aggregate alias, so every year gets an _avg_value and a _min_price column.


Multiple Column Pivot

The pivot method does not accept multiple columns directly, but there is a workaround: because pivot also accepts a Column expression, we can concatenate the values of several columns into a single expression and pivot on that.

See the example below

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

object SparkDataSetsPivot {
  def main(args: Array[String]): Unit = {

    //Step 1 - Create a spark session
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Spark DataSet Pivot")
      .getOrCreate

    //Step 2 - Create a schema for the file to be read
    val ordersSchema = StructType(
      List(
        StructField("orderkey", IntegerType, true),
        StructField("custkey", StringType, true),
        StructField("orderstatus", StringType, true),
        StructField("totalprice", DoubleType, true),
        StructField("orderdate", DateType, true),
        StructField("order_priority", StringType, true),
        StructField("clerk", StringType, true),
        StructField("ship_priority", StringType, true),
        StructField("comment", StringType, true)
      )
    )

    //Step 3 - Read the CSV file - with Options
    val ds = spark.read
      .option("header",false)
      .option("delimiter","|")
      .schema(ordersSchema)
      .csv("orders.tbl")

    //Step 4 - Select order date, order status and total price, and derive the order year and month
    import spark.implicits._
    val baseDS = ds.select("orderdate", "orderstatus","totalprice")
      .withColumn("orderYear",year($"orderdate"))
      .withColumn("orderMonth",month($"orderdate"))
      .select("orderstatus","orderYear","orderMonth","totalprice")

    //Step 5 - Pivot the data
    val pivotDS = baseDS.groupBy("orderMonth")
      .pivot(concat($"orderstatus",lit("_"),$"orderYear"))
      .agg(min("totalprice") as "min_price")
      .orderBy($"orderMonth")

    //Step 6 - Show the dataset in the log
    pivotDS.show(truncate = false)

  }
}

In the output, each concatenated orderstatus and orderYear combination becomes a column, with one row per order month.

Observe how the values of two columns are concatenated to create each pivoted column. In this way multiple columns can be pivoted and interesting datasets created. You can also combine this with the multiple-aggregate approach shown earlier to add several aggregates to a multiple-column pivot, as sketched below.
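As an illustration, here is a hedged sketch (not part of the original post) that combines the concatenated-column pivot with two aggregates. It assumes the baseDS built in Step 4 of the example above.

//Hedged sketch - multiple column pivot with multiple aggregates
val multiPivotDS = baseDS.groupBy("orderMonth")
  .pivot(concat($"orderstatus",lit("_"),$"orderYear"))   // concatenated pivot column
  .agg(round(avg("totalprice"),2) as "avg_price",        // first aggregate
       min("totalprice") as "min_price")                 // second aggregate
  .orderBy($"orderMonth")

multiPivotDS.show(truncate = false)

Each concatenated orderstatus and orderYear value would then produce two columns, one per aggregate.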

Hope you have found this blog entry interesting and useful.

Till next time….Byeeeee!
