Reconciling Spark APIs for Scala

Scala 3 gives you the tools to design the perfect Spark API. We proved it by creating the open source library Iskra.


Spark provides Scala programmers with more than one API for building big data pipelines. However, each of them requires some sacrifice – worse performance, additional boilerplate or lack of type safety. We propose a new Spark API for Scala 3 that solves all of these problems.

Spark API for Scala: Same goal – different paths

Scala is a statically typed language. However, it’s up to programmers how much type information they preserve at compile time. For example, given an integer and a string, we could store them in a tuple of type (Int, String) or in a more general collection like Seq[Any]. Following this philosophy, and partly for historical reasons, Spark offers two flavours of high-level APIs for Scala:

  • a precisely typed one based on Datasets
  • a loosely typed one based on DataFrames or SQL queries. As the two share most of their strengths and weaknesses, we will ignore SQL and focus on DataFrames from here on.

Unfortunately, neither of these options is perfect.

Let’s say we have some data representing measurements of temperature and air pressure from weather stations. We can model a single measurement as a case class like this:

case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)

We would like to find the IDs and average air pressure of all stations whose temperature amplitude (the difference between the highest and the lowest recorded temperature) is less than 20°C.
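Before turning to Spark, it may help to pin down what the query should return. The sketch below runs the same logic on an ordinary in-memory Seq instead of a Spark data set. The helper name stableStations is hypothetical. The sample rows are the four measurements listed in the Iskra example at the end of this article.

```scala
// Mirrors the Measurement case class above.
final case class Measurement(stationId: Long, temperature: Int, pressure: Int, timestamp: Long)

// IDs and average pressure of stations whose temperature amplitude is below 20 °C,
// computed with plain Scala collections (no Spark involved).
def stableStations(measurements: Seq[Measurement]): Seq[(Long, Double)] =
  measurements
    .groupBy(_.stationId)
    .toSeq
    .map { case (stationId, group) =>
      val temperatures = group.map(_.temperature)
      val pressures = group.map(_.pressure)
      (stationId, temperatures.min, temperatures.max, pressures.sum.toDouble / pressures.length)
    }
    .filter { case (_, minTemperature, maxTemperature, _) => maxTemperature - minTemperature < 20 }
    .map { case (stationId, _, _, avgPressure) => (stationId, avgPressure) }
    .sortBy(_._1) // deterministic order, since groupBy returns an unordered Map

val sample = Seq(
  Measurement(1, 10, 1020, 1641399107),
  Measurement(2, -5, 1036, 1647015112),
  Measurement(1, 19, 996, 1649175104),
  Measurement(2, 25, 1015, 1657030348)
)
```

For these rows, station 1 has an amplitude of 9°C and qualifies, with an average pressure of (1020 + 996) / 2 = 1008.0 hPa. Station 2 has an amplitude of 30°C and is filtered out.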

Let’s try to solve our problem using each of the approaches mentioned above.

Spark Dataset API: Using idiomatic Scala

We’ll go with Datasets first. Our solution will look very much like operating on standard Scala collections. This API lets us use ordinary Scala functions to manipulate our data model specified by tuples or case classes. 

We’ll skip the boilerplate of setting up a Spark application and assume measurements is a Dataset[Measurement] containing our data. We can now implement our core logic as shown below:

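import spark.implicits._ // encoders for Long keys and tuples

measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, group) =>
    val groupMeasurements = group.toVector // the iterator can be traversed only once
    val temperatures = groupMeasurements.map(_.temperature)
    val pressures = groupMeasurements.map(_.pressure)
    (stationId, temperatures.min, temperatures.max, pressures.sum.toDouble / pressures.length)
  }
  .filter(entry => entry._3 - entry._2 < 20)
  .map(entry => (entry._1, entry._4))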

Using tuples may seem convenient at first, but they will make our codebase hard to read once it grows bigger. Alternatively, we could replace tuples with case classes, as shown below. However, having to define a case class for every intermediate step of our computations might quickly become a burden as well.

case class AggregatedMeasurement(
  stationId: Long,
  minTemperature: Int,
  maxTemperature: Int,
  avgPressure: Double
)

/* … */

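measurements
  .groupByKey(_.stationId)
  .mapGroups { (stationId, group) =>
    val groupMeasurements = group.toVector // the iterator can be traversed only once
    val temperatures = groupMeasurements.map(_.temperature)
    val pressures = groupMeasurements.map(_.pressure)
    AggregatedMeasurement(
      stationId = stationId,
      minTemperature = temperatures.min,
      maxTemperature = temperatures.max,
      avgPressure = pressures.sum.toDouble / pressures.length
    )
  }
  .filter(aggregated => aggregated.maxTemperature - aggregated.minTemperature < 20)
  .map(aggregated => (aggregated.stationId, aggregated.avgPressure))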

In this approach, the compiler knows the exact type of our data model after each transformation, so it can verify the correctness of our program, at least to some extent. For example, it will raise an error if we refer to a column that doesn’t exist in our model, or use a column whose type doesn’t make sense in a given context. We even get code completions for column names, which helps us eliminate many potential errors before we compile our entire codebase.

Despite these amenities, our application might still surprise us with a runtime error. This would happen, for example, if we defined our helper case class inside a method instead of an object or a package. Doing so would cause serialisation problems that wouldn’t be detected until we ran our program.

A major problem with the Dataset API that must also be mentioned is its performance. Because the bodies of our lambdas can execute arbitrary Scala code, Spark treats them as black boxes and cannot apply many of its optimizations.

Spark DataFrame API: Pretending to be Python

DataFrames, in contrast to Datasets, are not parameterised with the type of data they contain, so the compiler knows nothing about it. This design is similar to that of Spark’s Python API, where we refer to columns by their names as strings.

This might bring some flexibility in certain cases, as we’re free from the tuple vs case class dilemma. We could also compute names of columns dynamically. 

However, in most other cases it’s rather annoying. More importantly, it’s dangerous. As our data schema is only known at runtime, we typically won’t learn about many problems in our code until we deploy and run the entire application (or at least run our tests).
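The failure mode can be illustrated without Spark. The sketch below models an untyped row as a Map[String, Any], a hypothetical simplification of a DataFrame row. The misspelled key compiles without complaint, and the mistake surfaces only as an exception when the code runs.

```scala
import scala.util.Try

// Hypothetical simplification of an untyped DataFrame row:
// column names are plain strings and values are typed as Any.
type UntypedRow = Map[String, Any]

val aggregatedRow: UntypedRow =
  Map("stationId" -> 1L, "minTemperature" -> 10, "maxTemperature" -> 19, "avgPressure" -> 1008.0)

// Compiles fine: the compiler cannot check string keys or casts against a schema.
def temperatureAmplitude(row: UntypedRow): Int =
  row("maxTemperature").asInstanceOf[Int] - row("minTemperture").asInstanceOf[Int]

// The typo only shows up when the code actually runs.
val outcome: Try[Int] = Try(temperatureAmplitude(aggregatedRow))
```

Here outcome is a Failure wrapping a NoSuchElementException for the missing key "minTemperture".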

Our example rewritten to the DataFrame based API could look like this:

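measurements.toDF().groupBy($"stationId")
  .agg(min($"temperature").as("minTemperature"), max($"temperature").as("maxTemperature"), avg($"pressure").as("avgPressure"))
  .where($"maxTemperature" - $"minTemperture" < lit(20))
  .select($"stationId", $"avgPressure")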

As you can see, the names of the methods resemble SQL keywords more than anything one might know from Scala’s standard library. If you take a closer look at the snippet, you might even spot that it’s going to crash at runtime because of the typo in minTemperture. And even if we fix it, something might go wrong again if we later decide to rename one of the columns but forget to update some of its usages.

We gave up most type safety, but at least we got something in exchange. Because we are restricted to column transformations defined inside Spark, its optimization engine can heavily speed up the computations, as long as our program doesn’t crash at runtime.

Designing a perfect Spark API for Scala

You might ask yourself: can we design a better API for Spark that doesn’t force users to choose between type safety, convenience and efficiency?

Yes, we can! 

Scala 3 provides all the tools required to achieve that. Let’s take the DataFrame approach as a starting point and try to improve it.

First, let’s stop referring to columns by their stringy names:

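measurements.toDF().groupBy($.stationId)
  .agg(min($.temperature).as("minTemperature"), max($.temperature).as("maxTemperature"), avg($.pressure).as("avgPressure"))
  .where($.maxTemperature - $.minTemperature < lit(20))
  .select($.stationId, $.avgPressure)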

The only thing that changed in our code so far (apart from fixing the typo) is that we’ve replaced all $"foo"-like references with $.foo. Our snippet now looks more like vanilla Scala syntax, where one refers to nested parts of data structures using the dot operator. We could make this compile without much hassle even in Scala 2 by using the Dynamic marker trait.

import scala.language.dynamics
import org.apache.spark.sql.functions.col

object $ extends Dynamic {
  def selectDynamic(name: String) = col(name)
}

This might seem like magic, but it’s actually rather straightforward: every expression like $.foo gets rewritten by the compiler into $.selectDynamic("foo"), provided that $ has no statically known member called foo.
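To see this rewriting in isolation, here is a minimal Scala 3 sketch without a Spark dependency. ColumnRef is a hypothetical case class standing in for Spark’s Column so that the example runs on its own. Note that the misspelled column compiles just as happily as the correct one, which is exactly the limitation discussed next.

```scala
import scala.language.dynamics

// Hypothetical stand-in for Spark's Column: it only remembers the column name.
final case class ColumnRef(name: String)

// The compiler rewrites $.foo into $.selectDynamic("foo"),
// because $ has no statically known member called foo.
object $ extends Dynamic {
  def selectDynamic(name: String): ColumnRef = ColumnRef(name)
}

val correct = $.minTemperature
// The misspelling compiles too: Dynamic accepts any member name.
val misspelled = $.minTemperture
```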

However, more convenient column access by itself isn’t much of a game changer, since we still get feedback about errors only at runtime. But it turns out that in Scala 3 we can overcome this problem by using Selectable instead of Dynamic.

import org.apache.spark.sql.{ Column => UntypedColumn }
import org.apache.spark.sql.functions.col

class Column[T](val untyped: UntypedColumn) extends AnyVal

trait RowModel extends Selectable {
  def selectDynamic(name: String): Column[?] = Column(col(name))
}

def $: RowModel { /* ... */ } = new RowModel { /* ... */ }

Now, the type of $ is RowModel with some type refinement. Let’s say it was RowModel { def foo: Column[Int] }. Then $.foo would turn into $.selectDynamic("foo").asInstanceOf[Column[Int]]. The desugaring contains an extra type cast, but it’s safe, because the compiler takes the type of foo from the refinement. If foo were not declared there, compilation would fail.
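The same mechanism can be tried without Spark. This sketch assumes a simplified Column case class that only records the column name, standing in for the value class wrapping Spark’s column above. FooRow is a hand-written refinement that a library would generate instead. The standard scala.compiletime.testing.typeChecks helper is used only to show that selecting an undeclared member is rejected at compile time.

```scala
import scala.compiletime.testing.typeChecks

// Simplified stand-in for the typed Column wrapper: it only records the column name.
final case class Column[T](name: String)

class RowModel extends Selectable {
  // Invoked for every member selected through a structural refinement of RowModel.
  // The type argument is irrelevant at runtime; the call site casts the result.
  def selectDynamic(name: String): Column[?] = Column[Any](name)
}

// Hand-written refinement; a library would generate it from the data model.
type FooRow = RowModel { def foo: Column[Int] }

val row: FooRow = new RowModel().asInstanceOf[FooRow]

// Desugars to row.selectDynamic("foo").asInstanceOf[Column[Int]].
val foo: Column[Int] = row.foo

// bar is not declared in FooRow, so row.bar does not compile.
val barRejected: Boolean = !typeChecks("row.bar")
```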

Context does matter

The issue that we still need to solve is that the type refinement of RowModel has to be different depending on the circumstances in which we refer to $. These include the shape of our initial data model and the stage of the transformation pipeline we’re currently in. 

For instance, selecting avgPressure should be invalid before it gets computed inside the agg block. Similarly, we shouldn’t be allowed to refer to the pressure of a single measurement after the aggregation. So how can we get the compiler to track the correct type of $ at each step of our computations?

First, we need a refined type that represents the initial structure of our data frame. As Measurement is a case class, we can use Scala 3’s metaprogramming capabilities to construct it. We won’t go deeper into the implementation details here, but what we would like to get as the result is something like:

RowModel {
  def stationId: Column[Long]
  def temperature: Column[Int]
  def pressure: Column[Int]
  def timestamp: Column[Long]
}

Later on, when we perform a transformation such as selection or aggregation, we pass a block of code returning a column or a tuple of columns, which determines the shape of our data row in the next step of computations. Inside this block, $ needs to have the right, context-specific type. So why don’t we use context functions, another Scala 3 feature, to achieve that?

Even if you’ve never heard of context functions before, you might have come across a Scala 2 programming pattern like:

def bar(fun: Context => Int) = ???
def baz(implicit context: Context): Int = ???

bar { implicit context =>
  baz
}

In the code above, fun is a function from Context to Int, and the implicit keyword before the lambda’s parameter makes it available to implicit search inside the lambda’s body. Context functions are defined almost like ordinary functions, but with ?=> instead of =>. Making our function contextual lets us get rid of the implicit context => boilerplate at the beginning of the closure. Effectively, our auxiliary snippet gets simplified to:

def bar(fun: Context ?=> Int) = ???
def baz(using context: Context): Int = ???

bar {
  baz
}

If you aren’t fully familiar with Scala 3 syntax: the keyword using replaces implicit when declaring an implicit parameter.
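Both styles can be compared side by side in plain Scala 3, which still accepts implicit lambda parameters for compatibility. Context, barScala2 and the constant 41 are made up for this illustration. In both calls, baz receives the same context, and only the call site changes.

```scala
// Hypothetical context type, used only for this comparison.
final case class Context(base: Int)

def baz(using context: Context): Int = context.base + 1

// Scala 2 style: an ordinary function; the lambda marks its parameter implicit.
def barScala2(fun: Context => Int): Int = fun(Context(41))

// Scala 3 style: a context function; the compiler introduces the parameter itself.
def bar(fun: Context ?=> Int): Int = fun(using Context(41))

val viaImplicitLambda = barScala2 { implicit context => baz }
val viaContextFunction = bar { baz }
```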

Let’s get back to our Spark API. We’ll treat RowModel with its precisely refined type as our implicitly passed context. Then we’ll use the $ method to capture it.

def $(using rowModel: RowModel): rowModel.type = rowModel

Note that the return type is rowModel.type instead of just RowModel. This lets us preserve the precise type with the refinements, which guarantees that every reference to a column in the form of $.foo is valid in the given context. We also know the exact types of data in each column at compile time. Going further, we could use this information to ensure that operations on columns make sense as well, e.g. that the condition inside .where(...) indeed represents a boolean, or that we don’t attempt to divide a number by a string.
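Putting the pieces together, the following self-contained sketch shows how the type of $ can follow the context. Column, Frame and frameOf are hypothetical simplifications with no Spark involved. The two refinements are written by hand rather than derived by metaprogramming. Each stage accepts only the columns its refinement declares: avgPressure is rejected on the raw measurements, and pressure is rejected after aggregation.

```scala
import scala.compiletime.testing.typeChecks

// Simplified stand-in for a typed column; T only exists at compile time.
final case class Column[T](name: String)

class RowModel extends Selectable {
  // Every member selected through a refinement of RowModel ends up here.
  def selectDynamic(name: String): Column[?] = Column[Any](name)
}

// Captures the row model given in the current context, keeping its refined type.
def $(using rowModel: RowModel): rowModel.type = rowModel

// Hand-written refinements for two stages of the pipeline (assumed shapes).
type MeasurementRow = RowModel {
  def stationId: Column[Long]
  def temperature: Column[Int]
  def pressure: Column[Int]
}

type AggregatedRow = RowModel {
  def stationId: Column[Long]
  def avgPressure: Column[Double]
}

// Hypothetical typed frame: R lists the columns available at this stage.
final class Frame[R <: RowModel](rowModel: R) {
  // The block passed here sees rowModel as a given of its precise type R.
  def select[T](column: R ?=> Column[T]): Column[T] = column(using rowModel)
}

// A real library would derive the refinement via metaprogramming; here we just cast.
def frameOf[R <: RowModel]: Frame[R] = new Frame[R](new RowModel().asInstanceOf[R])

val raw = frameOf[MeasurementRow]
val aggregated = frameOf[AggregatedRow]

val pressure: Column[Int] = raw.select($.pressure)
val avgPressure: Column[Double] = aggregated.select($.avgPressure)

// Columns from the wrong stage do not compile.
val stagesEnforced: Boolean =
  !typeChecks("raw.select($.avgPressure)") && !typeChecks("aggregated.select($.pressure)")
```

The dependent return type rowModel.type is what makes this work: inside each block, $ has the singleton type of the given, so member selection sees that stage’s refinement.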

Making our dreams come true

Now you know the most important concepts and syntactic patterns you could use to implement a type-safe wrapper around the loosely typed API that Spark provides for DataFrames. So why not try it yourself? This might be a good exercise, but let us temper your enthusiasm for a moment. The type system Spark uses internally turns out to be hard to model statically. Also, the number of operations one can perform on data frames and columns is huge, and covering them all would require a lot of work. But we still believe the goal is reachable, so we started a joint initiative in the form of an open source library called Iskra.

The goal of the project is to provide a Spark API for Scala 3 that:

  • is type safe, providing meaningful compilation errors
  • avoids boilerplate
  • is intuitive to use for people already familiar with Spark
  • works well with IDEs, e.g. providing code completions for methods and names of columns
  • is efficient, taking advantage of all optimizations Spark offers for DataFrame and SQL based APIs
  • is extensible, giving library users the possibility to easily define their own typed wrappers for methods from the API not yet covered by the library

You can try it out right away! Here’s the complete solution to our weather station problem:

//> using scala "3.2.0"
//> using lib "org.virtuslab::iskra:0.0.2"

import org.virtuslab.iskra.api.*

case class Measurement(
  stationId: Long,
  temperature: Int /* in °C */,
  pressure: Int /* in hPa */,
  timestamp: Long
)

@main def run() =
  given spark: SparkSession = SparkSession.builder()

  val measurements = Seq(
    Measurement(1, 10, 1020, 1641399107),
    Measurement(2, -5, 1036, 1647015112),
    Measurement(1, 19, 996, 1649175104),
    Measurement(2, 25, 1015, 1657030348),
    /* more data … */

  import functions.{avg, min, max, lit}

    .where($.maxTemperature - $.minTemperature < lit(20))
    .select($.stationId, $.avgPressure)

You can run the snippet using scala-cli. This is how:

  1. Install scala-cli by following its official installation instructions
  2. Save the code to a file (let’s call it SparkWeather.scala)
  3. Run scala-cli --jvm temurin:11 SparkWeather.scala from the command line

If you use VS Code with Metals as your IDE, you can also see how code completions work. To do so:

  1. Open the directory containing SparkWeather.scala in VS Code
  2. Run scala-cli setup-ide SparkWeather.scala from the command line
  3. Click Connect to build server in Metals’ sidebar menu
  4. Start typing

Try Iskra out and share your feelings about it with us! Contributions are welcome as well. Let’s make Spark in Scala better together.

Written by

Michał Pałka Senior Scala Developer Sep 7, 2022