MutableList And The Short Path To A StackOverflowError

When working with collections in Scala, or any other high-level programming language, we don't always stop to think about the underlying implementation of the collection. We want to find the right tool for the job, and we want to do it as fast as we can. The Scala collection library offers a wealth of options, be they mutable or immutable, and it can sometimes be confusing to decide which one to choose.

An interesting case came up when a colleague of mine at work ran into a weird StackOverflowError while running a Spark job. We were seeing a long stack trace when trying to serialize with both the Kryo and the Java serializers.

The program looked rather innocent. Here is a close reproduction:

import scala.collection.mutable

object Foo {
  sealed trait A
  case class B(i: Int, s: String) extends A
  case class C(i: Int, x: Long) extends A
  case class D(i: Int) extends A

  case class Result(list: mutable.MutableList[Int])

  def doStuff(a: Seq[A]): Result = {
    val list = new mutable.MutableList[Int]

    // Extract the Int from each subtype and append it to the MutableList
    a.foreach {
      case B(i, _) => list += i
      case C(i, _) => list += i
      case D(i)    => list += i
    }

    Result(list)
  }
}

The code iterated over a Seq[A], extracted an Int from each element and appended it to a MutableList[Int]. The doStuff method was part of a Spark operation which consumed records from Kafka, parsed them and then handed them off for some more stateful computation, which looked like:

object SparkJob {
  def main(args: Array[String]): Unit = {
    KafkaUtil.createDStream(...)
      .mapPartitions(it => Iterator(Foo.doStuff(it.toSeq)))
      .foreachRDD { rdd => /* Stuff */ }
  }
}

One important thing to note is that this issue suddenly started appearing in their QA environment. Not much had changed: some of the classes inheriting from A had been extended, and more data had started flowing in from Kafka, but there weren't any major changes to the code base.

This was the exception we were seeing (this one for JavaSerializer):

Exception in thread “main” java.lang.StackOverflowError
    at …$WeakClassKey.(…
    at …
    at …
    …

The class hierarchy was a bit beefier than my simplified example. It consisted of a few nested case classes, but none of them hid a recursive data structure. This was weird.

We took a divide and conquer approach, eliminating the code piece by piece to figure out which class was causing the trouble. Then, by mere chance, I looked at the MutableList and told him: “Let's try running this with an ArrayBuffer instead and see what happens”. To our surprise, the StackOverflowError was gone and no longer reproduced.

So what’s the deal with MutableList[A]?

After that long intro, let's get down to the gory details: what's up with MutableList? If we peek under the hood at the implementation, MutableList[A] is a thin wrapper around a LinkedList[A] that keeps references to its first and last nodes. It has a method head of type A and a method tail of type MutableList[A] (similar to List[A]):

class MutableList[A]
extends AbstractSeq[A]
   with LinearSeq[A]
   with LinearSeqOptimized[A, MutableList[A]]
   with GenericTraversableTemplate[A, MutableList]
   with Builder[A, MutableList[A]]
   with Serializable
{
  override def companion: GenericCompanion[MutableList] = MutableList
  override protected[this] def newBuilder: Builder[A, MutableList[A]] = new MutableList[A]

  protected var first0: LinkedList[A] = new LinkedList[A]
  protected var last0: LinkedList[A] = first0
  protected var len: Int = 0

  /** Returns the first element in this list
   */
  override def head: A = if (nonEmpty) first0.head else throw new NoSuchElementException

  /** Returns the rest of this list
   */
  override def tail: MutableList[A] = {
    val tl = new MutableList[A]
    // ... shortened for brevity
  }

  // Shortened for brevity
}

When we serialize a recursive data structure such as LinkedList[A], or any deeply nested structure, both Kryo and Java serialization need to traverse it all the way down. Since a LinkedList[A] node holds an element of type A and a reference to the next LinkedList[A] node, the traversal goes one level deeper for every element. Each time the serializer encounters the next node, it recurses into it, pushing new frames onto the stack before it can write that node and move on. With enough elements, the stack eventually blows up.
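To see the recursion in isolation, here is a minimal sketch using a hypothetical Link class (not part of the Scala library) that relies on Java's default serialization. Because each node's next field is itself a serializable object, ObjectOutputStream descends one level per node. The chain length at which it fails depends on the JVM's thread stack size, so treat the sizes below as illustrative assumptions rather than exact limits.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical singly linked node using Java's default serialization.
// Writing `next` makes ObjectOutputStream recurse once per node.
final class Link(val value: Int, val next: Link) extends Serializable

object DeepSerialization {
  // Builds the chain 1 -> 2 -> ... -> n iteratively (back to front).
  def chain(n: Int): Link = {
    var head: Link = null
    var i = n
    while (i > 0) {
      head = new Link(i, head)
      i -= 1
    }
    head
  }

  // Returns true if serialization succeeds, false if it overflows the stack.
  def serializes(head: Link): Boolean =
    try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(head)
      out.close()
      true
    } catch {
      case _: StackOverflowError => false
    }
}
```

A short chain such as DeepSerialization.chain(100) serializes fine, while a chain of a million nodes should overflow a default-sized stack, even though building that chain iteratively is no problem at all.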

I played around to see how many elements fit in a MutableList[Int] before it explodes. I ran this test on Scala 2.12.0 and Java 1.8.0_91 in an x64 process (which should have a 1MB default thread stack, AFAIK), and it took exactly 1335 elements to make it blow up:

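import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object StackoverflowTest {
  def main(args: Array[String]): Unit = {
    val mutable = new collection.mutable.MutableList[Int]
    (1 to 1335).foreach(x => mutable += x)
    val ous = new ObjectOutputStream(new ByteArrayOutputStream())
    ous.writeObject(mutable) // threw java.lang.StackOverflowError in the run described above
    ous.close()
  }
}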

But wait, isn’t List[A] in Scala also defined as a linked list?

How can it be that when using the equivalent with a List[A] in Scala, this doesn’t blow up?

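import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object ListSerializationTest {
  def main(args: Array[String]): Unit = {
    val list = List.range(0, 1500)
    val ous = new ObjectOutputStream(new ByteArrayOutputStream())
    ous.writeObject(list) // completes without a StackOverflowError
    ous.close()
  }
}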
And well, it doesn't. It turns out that List[A] has some secret sauce!

writeObject and readObject:

Every class that uses Java serialization can define a private writeObject and readObject pair that lays out exactly how its instances are serialized and deserialized. Scala uses a custom class called SerializationProxy to provide an iterative (rather than recursive) version of serialization and deserialization for List[A]:

private class SerializationProxy[A](@transient private var orig: List[A]) extends Serializable {

  private def writeObject(out: ObjectOutputStream) {
    var xs: List[A] = orig
    while (!xs.isEmpty) {
      out.writeObject(xs.head)
      xs = xs.tail
    }
    out.writeObject(ListSerializeEnd)
  }

  // Java serialization calls this before readResolve during de-serialization.
  // Read the whole list and store it in `orig`.
  private def readObject(in: ObjectInputStream) {
    val builder = List.newBuilder[A]
    while (true) in.readObject match {
      case ListSerializeEnd =>
        orig = builder.result()
        return
      case a =>
        builder += a.asInstanceOf[A]
    }
  }

  // Provide the result stored in `orig` for Java serialization
  private def readResolve(): AnyRef = orig
}

At runtime, the Java serializer reflects over the class to find these methods and uses them to serialize or deserialize the object. Because the proxy walks the list in a loop instead of recursing into each tail, the stack depth stays constant, and that is exactly why we don't get a StackOverflowError in our face.
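The same technique can be applied to your own linked structures. Below is a minimal sketch, assuming a hypothetical IntChain wrapper around hypothetical Node cells (neither is a library type). The cells themselves are not serializable; the wrapper's private writeObject/readObject pair writes a count followed by each value in a loop (List's proxy uses an end marker instead of a count), so the stack depth stays constant regardless of the chain's length.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical linked cell; deliberately NOT Serializable.
final class Node(val value: Int, var next: Node)

// Hypothetical wrapper that serializes its cells iteratively.
final class IntChain(@transient private var head: Node) extends Serializable {

  def toList: List[Int] = {
    val builder = List.newBuilder[Int]
    var n = head
    while (n != null) { builder += n.value; n = n.next }
    builder.result()
  }

  // Found by Java serialization via reflection: write a count, then each value in a loop.
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    var size = 0
    var n = head
    while (n != null) { size += 1; n = n.next }
    out.writeInt(size)
    n = head
    while (n != null) { out.writeInt(n.value); n = n.next }
  }

  // Rebuilds the chain front to back, keeping a pointer to the last cell.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    val size = in.readInt()
    var last: Node = null
    var i = 0
    while (i < size) {
      val cell = new Node(in.readInt(), null)
      if (last == null) head = cell else last.next = cell
      last = cell
      i += 1
    }
  }
}

object IntChain {
  // Builds the chain back to front so each prepend is O(1).
  def fromRange(r: Range): IntChain = {
    var head: Node = null
    r.reverseIterator.foreach(v => head = new Node(v, head))
    new IntChain(head)
  }

  // Serializes and deserializes a chain in memory.
  def roundTrip(c: IntChain): IntChain = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(c)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject().asInstanceOf[IntChain] finally in.close()
  }
}
```

Marking head as @transient keeps defaultWriteObject from trying to write the non-serializable cells; the loops take over that job.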


The first and most obvious lesson: always consider the data structures you're using and make sure they're the right ones for the job! Although the language provides us with high-level abstractions over collections, strive to know what's going on under the covers and make sure the collection you're using won't come back to bite you. If you're reaching for a mutable collection for performance reasons, consider encapsulating an ArrayBuffer or ListBuffer internally and exposing an immutable collection (for example, a Vector[A] built from the ArrayBuffer, or a List[A] built from the ListBuffer) that you create once you're done filling it up.
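As a minimal sketch of that advice, here is a variant of the earlier doStuff that fills a ListBuffer internally and exposes only an immutable List. The names SafeFoo, SafeResult and serializedSize are hypothetical, and the trait and case classes are redeclared so the block stands alone. Because List[A] comes with its own iterative serialization, the result can go through Java serialization even when it holds many elements.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.mutable.ListBuffer

object SafeFoo {
  sealed trait A
  case class B(i: Int, s: String) extends A
  case class C(i: Int, x: Long) extends A
  case class D(i: Int) extends A

  // Only the immutable List escapes; the ListBuffer stays an implementation detail.
  case class SafeResult(list: List[Int])

  def doStuff(a: Seq[A]): SafeResult = {
    val buffer = ListBuffer.empty[Int]
    a.foreach {
      case B(i, _) => buffer += i
      case C(i, _) => buffer += i
      case D(i)    => buffer += i
    }
    SafeResult(buffer.toList)
  }

  // Returns the number of bytes written by Java serialization.
  def serializedSize(value: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(value)
    out.close()
    bytes.size()
  }
}
```

Callers never see the buffer, so there is no way for a mutable, recursively serialized collection to leak into a Spark closure or a shuffled record.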

And generally, remember that MutableList[A] internally uses a LinkedList[A], and that the Scala collection library doesn't implement a custom writeObject and readObject pair for them, so watch out if you need to transfer one over the wire.

These kinds of bugs are hard to discover, and we were lucky to find this one early and not in production.

The Case Of The Immutable Map and Object Who Forgot To Override HashCode

Disclaimer: What we're about to look at is an implementation detail valid at the time of writing for Scala 2.10.x and 2.11.x (I'm not sure about previous versions). It is subject to change at any time, and you should definitely not rely on these side effects when writing code.

Consider the following example. Given a class Foo:

class Foo(val value: Int) {
    override def equals(obj: scala.Any): Boolean = obj match {
      case other: Foo => value == other.value
      case _ => false
    }
}
What would you expect the following to print?

import scala.collection.mutable

val immutableMap = Map(new Foo(1) -> 1)
val mutableMap = mutable.Map(new Foo(1) -> 1)

immutableMap.getOrElse(new Foo(1), -1)
mutableMap.getOrElse(new Foo(1), -1)

If you're thinking: “Well, he didn't override Object.hashCode, so neither the immutable nor the mutable Map is going to find the value, and both should fall back to -1”, you might be surprised:

scala> immutableMap.getOrElse(new Foo(1), -1)
res3: Int = 1

scala> mutableMap.getOrElse(new Foo(1), -1)
res4: Int = -1

Hmmm.. What?

How is immutable.Map retrieving 1? And why is mutable.Map outputting the expected result (-1)? Maybe we're just lucky and both objects have the same hash code?

scala> val first = new Foo(1)
first: Foo = Foo@687b0ddc

scala> val second = new Foo(1)
second: Foo = Foo@186481d4

scala> first.hashCode == second.hashCode
res6: Boolean = false

scala> first.hashCode
res7: Int = 1752894940

scala> second.hashCode
res8: Int = 409240020

Doesn’t seem so.

This is precisely the point of this post. We’ll see how Scala has a special implementation for immutable.Map and what side effects that might have.

The ground rules for custom objects as Map keys

Anyone accustomed to the Map data structure knows that any object used as a key should obey the following rules:

  1. Override Object.equals - Equality, if not explicitly overridden, is reference equality. We want not only the same instance to be equal to itself, but also two distinct objects to be equal when our custom equality semantics say so (here, when their value fields are equal).
  2. Override Object.hashCode - Any two equal objects must yield the same hash code, but not vice versa: two objects with the same hash code need not be equal (see the Pigeonhole principle). This is extremely important for objects used as keys of a Map, since (most) implementations rely on the hash code of the key to determine where the value will be stored. That same hash code is used later when one requests a lookup by a given key.
  3. Hash code should be generated from immutable fields - It is common to use the object's fields as part of the hash code algorithm. If our value field were mutable, one could mutate it at runtime, causing a different hash code to be generated, and as a side effect we would no longer be able to retrieve the entry from the Map.

But our custom object doesn’t exactly follow these rules. It does override equals, but not hashCode.

This is where things get interesting.

The secret sauce of immutable.Map

Scala has custom implementations for immutable Maps of up to 4 key value pairs (Map1, …, Map4). These implementations don't rely on hashCode to find an entry; they simply store the key value pairs as fields and do an equality check on the key:

class Map1[A, +B](key1: A, value1: B) extends AbstractMap[A, B] with Map[A, B] with Serializable {
    override def size = 1
    def get(key: A): Option[B] =
      if (key == key1) Some(value1) else None
    // Shortened for brevity
}

You can see that the key is compared directly to key1 with ==, which delegates to our equals, and this is exactly why immutable.Map retrieves the 1: our equals implementation is in order.

If we ran some REPL tests for maps with up to 4 and with more than 4 elements, we'd see inconsistent results caused by this implementation detail:

scala> val upToFourMap = Map(new Foo(1) -> 1, new Foo(2) -> 2, new Foo(3) -> 3, new Foo(4) -> 4)

scala> upToFourMap.getOrElse(new Foo(1), -1)
res2: Int = 1

scala> val upToFiveMap = Map(new Foo(1) -> 1, new Foo(2) -> 2, new Foo(3) -> 3, new Foo(4) -> 4, new Foo(5) -> 5)

scala> upToFiveMap.getOrElse(new Foo(1), -1)
res1: Int = -1

Once we leave the custom realm of the optimized immutable.Map implementations (beyond four entries, immutable.Map is backed by a hash-based HashMap, which does use hashCode), we see the results we expect.
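To make the difference concrete, here is a simplified model of the two lookup strategies. It is a hypothetical sketch, not the actual Map1 or HashMap source: linearLookup compares the probe with each stored key using equals only, in the spirit of Map1 through Map4, while bucketLookup first picks a bucket from hashCode and only compares with equals inside that bucket, roughly what hash-based maps do. The hypothetical SaltedKey makes the broken contract deterministic: two instances with the same value are equal, yet their hash codes differ.

```scala
// Hypothetical key that overrides equals but derives hashCode from an unrelated
// field, so equal keys can (deterministically) have different hash codes.
final class SaltedKey(val value: Int, val salt: Int) {
  override def equals(obj: Any): Boolean = obj match {
    case other: SaltedKey => value == other.value
    case _                => false
  }
  override def hashCode(): Int = salt
}

object LookupModels {
  // Equality-only scan, in the spirit of Map1..Map4.
  def linearLookup[K, V](entries: Seq[(K, V)], key: K): Option[V] =
    entries.collectFirst { case (k, v) if k == key => v }

  // Hash-first lookup: only keys that land in the same bucket are compared with equals.
  def bucketLookup[K, V](entries: Seq[(K, V)], key: K, buckets: Int): Option[V] = {
    def bucketOf(k: K): Int = Math.floorMod(k.hashCode, buckets)
    val table = entries.groupBy { case (k, _) => bucketOf(k) }
    table.getOrElse(bucketOf(key), Seq.empty).collectFirst { case (k, v) if k == key => v }
  }
}
```

With a stored key new SaltedKey(1, salt = 10) and a probe new SaltedKey(1, salt = 11), the linear scan finds the entry while the bucketed lookup looks in the wrong bucket and misses it, mirroring the 4 versus 5 element results above.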

Correcting our Foo implementation

To set the record straight, let's put our object in order and override Object.hashCode:

override def hashCode(): Int = value.hashCode()

And now let’s re-run our tests:

scala> val upToFourMap = Map(new Foo(1) -> 1, new Foo(2) -> 2, new Foo(3) -> 3, new Foo(4) -> 4)

scala> upToFourMap.getOrElse(new Foo(1), -1)
res1: Int = 1

scala> val upToFiveMap = Map(new Foo(1) -> 1, new Foo(2) -> 2, new Foo(3) -> 3, new Foo(4) -> 4, new Foo(5) -> 5)

scala> upToFiveMap.getOrElse(new Foo(1), -1)
res3: Int = 1

We now see that once hashCode is in order, the results line up.
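Putting the pieces together, a corrected Foo might look like the sketch below. Deriving hashCode from the same immutable field that equals compares keeps the two consistent (a case class would generate both for you from its constructor fields). The FooChecks helper is hypothetical and only there to exercise maps of different sizes.

```scala
import scala.collection.mutable

// Foo with both equals and hashCode derived from the immutable `value` field.
class Foo(val value: Int) {
  override def equals(obj: Any): Boolean = obj match {
    case other: Foo => value == other.value
    case _          => false
  }
  override def hashCode(): Int = value.hashCode()
}

object FooChecks {
  // Builds immutable and mutable maps of n entries and looks up a fresh, equal key in each.
  def lookups(n: Int): (Int, Int) = {
    val pairs = (1 to n).map(i => new Foo(i) -> i)
    val immutableMap = Map(pairs: _*)
    val mutableMap = mutable.Map(pairs: _*)
    (immutableMap.getOrElse(new Foo(1), -1), mutableMap.getOrElse(new Foo(1), -1))
  }
}
```

FooChecks.lookups(4) and FooChecks.lookups(5) should both return (1, 1) now, whichever Map implementation ends up backing the data.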

Summing up

In this post we looked at a corner case we discovered through our faulty implementation combined with the special immutable.Map implementations. When using a custom object as a key, make sure to implement the prerequisites we've mentioned. This inconsistency, although caused by our flawed implementation, can be quite surprising.

Leveraging Spark Speculation To Identify And Re-Schedule Slow Running Tasks


Spark's speculation engine is a tool that detects slow running tasks (we'll soon see what that means) and re-schedules their execution. This can be especially helpful in jobs that require a strict batch to processing time ratio. In this post, we'll go through the speculation algorithm and see an example of how one can put it to good use.

The speculation algorithm

The description of the “spark.speculation” key on Spark's Configuration page says:

If set to “true”, performs speculative execution of tasks. This means if one or more tasks are running slowly in a stage, they will be re-launched.

Spark's speculation lets us define what a “task running slowly in a stage” means through the following flags:

Key                             Default   Description
spark.speculation               false     If set to “true”, performs speculative execution of tasks. This means if one or more tasks are running slowly in a stage, they will be re-launched.
spark.speculation.interval      100ms     How often Spark will check for tasks to speculate.
spark.speculation.multiplier    1.5       How many times slower a task is than the median to be considered for speculation.
spark.speculation.quantile      0.75      Fraction of tasks which must be complete before speculation is enabled for a particular stage.

You can pass these flags to spark-submit with --conf (pretty straightforward, as with other configuration values), e.g.:

spark-submit \
--conf "spark.speculation=true" \
--conf "spark.speculation.multiplier=5" \
--conf "spark.speculation.quantile=0.90" \
--class "org.asyncified.myClass" "path/to/uberjar.jar"

Each flag takes part in the general algorithm in the detection of a slow task. Let’s now break down the algorithm and see how it works.

For each stage in the current batch:

  1. Check whether the number of finished tasks in the stage is at least spark.speculation.quantile * numTasks; otherwise, don't speculate (e.g., if a stage has 20 tasks and the quantile is 0.75, at least 15 tasks need to finish for speculation to start).
  2. Scan all successfully executed tasks in the given stage and calculate the median task execution time.
  3. Calculate the threshold a task has to exceed in order to be eligible for re-launching. This is defined as spark.speculation.multiplier * median.
  4. Iterate over all the currently running tasks in the stage. If there is a task whose running time is larger than the threshold, re-launch it.
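The four steps can be sketched as a pure function over task timings. This is a simplified model written for illustration, not Spark's TaskSetManager: the TaskTiming type and the speculatable function are hypothetical, times are plain milliseconds, the median is the middle element of the sorted durations, and Spark's extra lower bound on the threshold (minTimeToSpeculation) is left out.

```scala
// Hypothetical snapshot of a task: for finished tasks `elapsedMs` is the total
// duration, for running tasks it is the time spent running so far.
final case class TaskTiming(index: Int, finished: Boolean, elapsedMs: Long)

object SpeculationModel {
  // Returns the indices of running tasks that would be marked for re-launch.
  def speculatable(tasks: Seq[TaskTiming], quantile: Double, multiplier: Double): Seq[Int] = {
    val finished = tasks.filter(_.finished)
    // Step 1: enough tasks must have finished before speculating at all.
    val minFinished = math.floor(quantile * tasks.size).toInt
    if (tasks.size <= 1 || finished.isEmpty || finished.size < minFinished) Seq.empty
    else {
      // Step 2: median duration of the finished tasks (upper median for even counts).
      val sorted = finished.map(_.elapsedMs).sorted
      val median = sorted(sorted.size / 2)
      // Step 3: the threshold a running task must exceed.
      val threshold = multiplier * median
      // Step 4: running tasks slower than the threshold are speculation candidates.
      tasks.filter(t => !t.finished && t.elapsedMs > threshold).map(_.index)
    }
  }
}
```

With 20 tasks, 15 of them finished in 1 second each, the default multiplier of 1.5 gives a 1.5 second threshold, so running tasks at 1.6 and 9 seconds would both be picked. Raising the multiplier to 8 pushes the threshold to 8 seconds and leaves only the 9 second task.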

That sounds great and all, but why would I want to use this?

Allow me to tell a short story: We run a solution based on Spark Streaming which needs to operate 24x7. In preparation for making this job production ready, I stress tested our streaming job to make sure we'd be able to handle the foreseen production scale. Our service is part of a longer pipeline, and is responsible for consuming data from Kafka, doing some stateful processing and passing data along to the next component in the pipeline. Our stateful transformations use Amazon S3 to store intermediate checkpoints in case of job failure.

During the stress test, after approximately 15-20 hours of running, I noticed some strange behavior. During the execution of an arbitrary batch, one task would block indefinitely on a socketRead0 call with the following stack trace:

    …Method)
    …
    org.apache.hadoop.fs.s3native.NativeS3FileSystem.getFileStatus(…
    org.apache.hadoop.fs.FileSystem.exists(…
    org.apache.spark.rdd.ReliableCheckpointRDD$.writePartitionToCheckpointFile(ReliableCheckpointRDD.scala:168)
    org.apache.spark.rdd.ReliableCheckpointRDD$$anonfun$writeRDDToCheckpointDirectory$1.apply(ReliableCheckpointRDD.scala:136)
    org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
    org.apache.spark.executor.Executor$…
    java.util.concurrent.ThreadPoolExecutor.runWorker(…
    java.util.concurrent.ThreadPoolExecutor$…

If I killed and re-ran the same task, it would run fine without hanging. This wasn't an easily reproducible bug. (It turns out we were hitting a bug, HTTPCLIENT-1478 to be precise, in HttpComponents.HttpClient, a library used by Spark and org.apache.hadoop.fs.s3native.NativeS3FileSystem to perform HTTP/API calls to S3.)

The problem with the hang was that Spark has to wait for every stage in a batch to complete before it starts processing the next batch in the queue, so if a single task in a batch is stuck, the entire streaming job starts to accumulate incoming batches, only to resume once the stuck batch completes.

The problem was that the task itself was actually hanging indefinitely. We were in quite a pickle…

There were a few possible solutions to the problem:

  1. Pass the bugless HttpComponents.HttpClient to Spark's master/worker nodes' classpath - If we can get Spark to see the newer version of the library before loading its own problematic version containing the bug (hoping that API compatibility was kept between versions), this might work. Problem is, anyone who has had to go head to head with Spark and class loading prioritization knows this is definitely not a pleasant journey to get into.

  2. Upgrade to Spark 2.0.0 - Spark 2.0.0 comes with v4.5.4 of the library, which fixed the bug we were hitting. This was problematic, as we were already running Spark in production, and we'd have to run a full regression, not to mention making sure everything is compatible between our currently running version (1.6.1) and 2.0.0.

  3. Turn on speculation - Turn on speculation and “teach it” to identify the rebellious task, kill it and re-schedule it.

We split the decision in two. The long-run solution is of course upgrading to Spark 2.0.0. But at this point in time, we already had a streaming job running in production, and we needed this workaround to work. So we chose to turn on speculation while planning the upgrade.

But wait… isn’t this a hack?

Yes, this is definitely a workaround for the actual problem. It doesn't solve HttpClient arbitrarily hanging every now and then. But when running in production, you sometimes need to “buy time” until you can come up with a suitable solution, and speculation hit the nail on the head and allowed the streaming job to run continuously.

Speculation doesn’t magically make your job work faster. Re-scheduling actually has an overhead that will cause the particular batch to incur a scheduling and processing delay, so it needs to be used wisely and with caution.

How will I know a task was marked for re-scheduling due to speculation?

In the Spark UI, when drilling down into a specific stage, one can view all the running/completed tasks for that stage. If a task was re-scheduled, its name will be decorated with “(speculated)”, and you'll usually see a different locality for that task (usually, it will run on locality ANY).

Tuning speculation to fit your requirements

When turning on speculation, we tried hard to configure it to spot the task that was getting blocked due to the bug. Our bug would reproduce once every 15-20 hours, and it would cause the task to block indefinitely, which is pretty easy to identify. We set the multiplier so it would identify tasks taking 8 times longer than the median, and kept the quantile as is (0.75). We didn't mind that it would take some time until the problematic task was identified and became “eligible” for re-scheduling, as long as other tasks weren't mistakenly marked for re-scheduling for no reason.

Wrapping up

Spark speculation is a great tool to have at your disposal. It can identify out-of-the-ordinary slow running tasks and re-schedule them to execute on a different worker. As we've seen, this can be helpful in cases where a slow running task is caused by a bug. If you decide to turn it on, take the time to tune it with the right configuration for you; it is a trial and error path until you feel comfortable with the settings. It is important to note that speculation isn't going to fix your slow job, and that's not its purpose at all. If your overall job performance is lacking, speculation is not the weapon of choice.


For the brave, here’s the actual piece of code that does the calculation for speculation. It is defined inside TaskSetManager.checkSpeculatableTasks:

  /**
   * Check for tasks to be speculated and return true if there are any. This is called periodically
   * by the TaskScheduler.
   *
   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
   * we don't scan the whole task set. It might also help to make this sorted by launch time.
   */
  override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
    // Can't speculate if we only have one task, and no need to speculate if the task set is a
    // zombie.
    if (isZombie || numTasks == 1) {
      return false
    }
    var foundTasks = false
    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val time = clock.getTimeMillis()
      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
      Arrays.sort(durations)
      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation)
      // TODO: Threshold should also look at standard deviation of task durations and have a lower
      // bound based on that.
      logDebug("Task length threshold for speculation: " + threshold)
      for ((tid, info) <- taskInfos) {
        val index = info.index
        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
          !speculatableTasks.contains(index)) {
          logInfo(
            "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms"
              .format(index, …, …, threshold))
          speculatableTasks += index
          foundTasks = true
        }
      }
    }
    foundTasks
  }
Exploring Stateful Streaming with Apache Spark


Apache Spark is composed of several modules, each serving a different purpose. One of its powerful modules is the Streaming API, which gives the developer the power of working with a continuous stream (or micro batches, to be accurate) under an abstraction called Discretized Stream, or DStream.

In this post, I'm going to dive into a particular property of Spark Streaming: its stateful streaming API. Stateful streaming enables us to maintain state between micro batches, allowing us to sessionize our data.

Disclaimer - One should have a basic understanding of how Spark works and a general understanding of the DStream abstraction in order to follow the flow of this post. If not, go ahead and read this, don't worry, I'll wait…

Welcome back! Let's continue.

Understanding via an example

In order to understand how to work with the APIs, let’s create a simple example of incoming data which requires us to sessionize. Our input stream of data will be that of a UserEvent type:

case class UserEvent(id: Int, data: String, isLast: Boolean)

Each event belongs to a single user, whom we identify by their id, and carries a String representing the content of the event that occurred. We also want to know when the user has ended their session, so we're provided with an isLast flag which indicates the end of the session.

Our state, responsible for aggregating all the user events, will be that of a UserSession type:

case class UserSession(userEvents: Seq[UserEvent])

Which contains the sequence of events that occurred for a particular user. For this example, we’ll assume our data source is a stream of JSON encoded data consumed from Kafka.

Our id property will be used as the key, and the UserEvent will be our value. Together, we get a DStream[(Int, UserEvent)].

Before we get down to business, two important points:

1. Checkpointing is a prerequisite for stateful streaming

From the Spark documentation:

A streaming application must operate 24/7 and hence must be resilient to failures unrelated to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible, Spark Streaming needs to checkpoint enough information to a fault-tolerant storage system such that it can recover from failures.

Spark's checkpointing mechanism is the framework's way of guaranteeing fault tolerance throughout the lifetime of our Spark job. When we're operating 24/7, things that might not be directly under our control will fail, such as network failures or datacenter crashes. To allow a clean recovery, Spark can checkpoint our data at an interval of our choosing to a persistent data store, such as Amazon S3, HDFS or Azure Blob Storage, if we tell it to do so.

Checkpointing is optional for non-stateful transformations, but it is mandatory to provide a checkpoint directory for stateful streams; otherwise your application won't be able to start.

Providing a checkpoint directory is as easy as calling checkpoint on the StreamingContext with the directory location:

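val sparkContext = new SparkContext()
val ssc = new StreamingContext(sparkContext, Duration(4000))
ssc.checkpoint("hdfs:///path/to/checkpoint") // hypothetical directory; point it at your own S3/HDFS/Azure location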

One important thing to note is that checkpointed data is only usable as long as you haven't modified existing code, and is mainly suitable for recovering from job failures. Once you've modified your code (e.g., uploaded a new version to the Spark cluster), the checkpointed data is no longer compatible and must be deleted for your job to be able to start.

2. Key value pairs in the DStream

A common mistake is to wonder why we're not seeing the stateful transformation methods (updateStateByKey and mapWithState, as we'll soon see) when working with a DStream. These methods are defined in PairDStreamFunctions, which Spark makes available (through an implicit conversion) only on DStreams of key value pairs, in the form of DStream[(K, V)], where K is the type of the key and V is the type of the value. Working with such a stream allows Spark to shuffle data based on the key, so all data for a given key is available on the same worker node, allowing you to do meaningful aggregations.

Ok, we’re ready. Let’s go write some code.

A Brief Look At The Past

Until Spark 1.6.0, the sole stateful transformation available was PairDStreamFunctions.updateStateByKey.

The signature for the simplest form (which we’ll look at) looks like this:

def updateStateByKey[S](updateFunc: (Seq[V], Option[S]) => Option[S])

updateStateByKey requires a function which accepts:

  1. Seq[V] - The list of new values received for the given key in the current batch
  2. Option[S] - The state we’re updating on every iteration.

The first time our function is invoked for a given key, the state is going to be None, signaling it is the first batch for that key. After that, it's entirely up to us to manage its value. Once we're done with a particular state for a given key, we need to return None to indicate to Spark we don't need the state anymore.

A naïve implementation for our scenario would look like this:

def updateUserEvents(newEvents: Seq[UserEvent],
                    state: Option[UserSession]): Option[UserSession] = {
  // Append the new events to the state. If this is the first time we're invoked for the key,
  // we fall back to creating a new UserSession with the new events.
  val newState = state
    .map(prev => UserSession(prev.userEvents ++ newEvents))
    .orElse(Some(UserSession(newEvents)))

  // If we received the `isLast` event in the current batch, save the session to the underlying
  // store and return None to delete the state.
  // Otherwise, return the accumulated state so we can keep updating it in the next batch.
  if (newEvents.exists(_.isLast)) {
    saveUserSession(newState) // hypothetical helper that persists the finished session
    None
  } else newState
}

At each batch, we want to take the state for the given user and concatenate the old events and the new events into a new Option[UserSession]. Then, we want to check whether we've reached the end of this user's session, so we check the newly arrived sequence for the isLast flag on any of the UserEvents. If we received the last message, we save the user session to some persistent storage and then return None to indicate we're done. If we haven't received an end message, we simply return the newly created state for the next iteration.
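To make these semantics concrete without a cluster, here is a hypothetical local model of how updateStateByKey drives an update function across batches. It is not Spark code: UpdateStateModel and runBatch are made up for illustration, a batch is a plain Seq of (key, value) pairs, and the state store is an immutable Map. It re-declares the post's UserEvent and UserSession types and uses a simplified update function that drops a finished session instead of saving it. Note that every key already in the store is visited on every batch, even without new values, which is the first caveat listed below.

```scala
final case class UserEvent(id: Int, data: String, isLast: Boolean)
final case class UserSession(userEvents: Seq[UserEvent])

object UpdateStateModel {
  // Same shape as updateUserEvents, minus the persistence side effect.
  def update(newEvents: Seq[UserEvent], state: Option[UserSession]): Option[UserSession] = {
    val newState = state
      .map(prev => UserSession(prev.userEvents ++ newEvents))
      .orElse(Some(UserSession(newEvents)))
    if (newEvents.exists(_.isLast)) None else newState
  }

  // Applies `f` to every key that has state OR new values in this batch;
  // keys for which `f` returns None are dropped from the store.
  def runBatch[K, V, S](store: Map[K, S], batch: Seq[(K, V)])
                       (f: (Seq[V], Option[S]) => Option[S]): Map[K, S] = {
    val grouped: Map[K, Seq[V]] = batch.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }
    val keys = store.keySet ++ grouped.keySet
    keys.flatMap(k => f(grouped.getOrElse(k, Seq.empty), store.get(k)).map(s => k -> s)).toMap
  }
}
```

Feeding a first batch with events for users 1 and 2, and then a batch containing only user 1's isLast event, leaves just user 2's session in the store, even though user 2 had no new data in the second batch and was still visited.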

Our Spark DAG (Directed Acyclic Graph) looks like this:

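val kafkaStream =
      KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc, kafkaParams, topics)

kafkaStream
  .map(deserializeUserEvent) // hypothetical helper: parses the JSON into an (Int, UserEvent) tuple
  .updateStateByKey(updateUserEvents)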

The first map parses the JSON into a tuple of (Int, UserEvent), where the Int is the user's id. Then we pass the tuple to our updateStateByKey to do the rest.

Caveats of updateStateByKey

  1. A major downside of updateStateByKey is that for each new incoming batch, the transformation iterates over the entire state store, regardless of whether a new value for a given key has been consumed or not. This can affect performance, especially when dealing with a large amount of state over time. There are various techniques for improving performance, but this is still a pain point.

  2. No built-in timeout mechanism - Think about what would happen in our example if the event signaling the end of the user session was lost or hadn't arrived for some reason. One upside of updateStateByKey iterating over all keys is that we can implement such a timeout ourselves, but this should definitely be a feature of the framework.

  3. What you receive is what you return - The update function must return the same type as the state we're storing, in our case Option[UserSession], so we're forced to pass it downstream as is. But what happens if, once the state is completed, I want to output a different type and use that in another transformation? Currently, that's not possible.

Introducing mapWithState

mapWithState is updateStateByKey's successor, released in Spark 1.6.0 as an experimental API. It embodies the lessons learned down the road from working with stateful streams in Spark, and brings with it new and promising goodies.

mapWithState comes with features we’ve been missing from updateStateByKey:

  1. Built in timeout mechanism - We can tell mapWithState the period we’d like to hold our state for in case new data doesn’t come. Once that timeout is hit, mapWithState will be invoked one last time with a special flag (which we’ll see shortly).

  2. Partial updates - Only keys for which new data arrived in the current batch will be iterated. This means no longer needing to iterate over the entire state store at every batch interval, which is a great performance optimization.

  3. Choose your return type - We can now choose a return type of our desire, regardless of what type our state object holds.

  4. Initial state - We can provide a custom RDD to initialize our stateful transformation on startup.

Let’s take a look at the different parts that form the new API.

The signature for mapWithState:

mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType])

As opposed to the updateStateByKey that required us to pass a function taking a sequence of messages and the state in the form of an Option[S], we’re now required to pass a StateSpec:

Abstract class representing all the specifications of the DStream transformation mapWithState operation of a pair DStream (Scala) or a JavaPairDStream (Java). Use the StateSpec.apply() or StateSpec.create() to create instances of this class.

Example in Scala:

// A mapping function that maintains an integer state and returns a String
def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
  // Use state.exists(), state.get(), state.update() and state.remove()
  // to manage state, and return the necessary string
  ???
}

val spec = StateSpec.function(mappingFunction _)

The interesting bit is StateSpec.function, a factory method for creating the StateSpec. It requires a function with the following signature:

mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType

mappingFunction takes several arguments. Let’s match them to our example:

  1. KeyType - Obviously the key type, Int
  2. Option[ValueType] - Incoming data type, Option[UserEvent]
  3. State[StateType] - State to keep between iterations, State[UserSession]
  4. MappedType - Our return type, which can be anything. For our example we’ll return an Option[UserSession].

Differences between mapWithState and updateStateByKey

  1. The key itself, which previously wasn’t exposed.
  2. The incoming value in the form of Option[ValueType], where previously we received a Seq[ValueType].
  3. Our state is now encapsulated in an object of type State[StateType]
  4. We can return any type we’d like from the transformation, no longer bound to the type of the state we’re holding.

(There is also a more advanced overload where we additionally receive a Time object, but we won’t go into that here. Feel free to check out the different overloads here.)

Exploring state management with the State object

Previously, managing our state meant working with an Option[S]. In order to update our state, we would create a new instance and return that from our transformation. When we wanted to remove the state, we’d return None. Since we’re now free to return any type from mapWithState, we need a way to interact with Spark to express what we wish to do with the state in every iteration. For that, we have the State[S] object.

There are several methods exposed by the object:

  1. def exists(): Boolean - Checks whether a state exists
  2. def get(): S - Get the state if it exists, otherwise it will throw java.util.NoSuchElementException. (We need to be careful with this one!)
  3. def isTimingOut(): Boolean - Whether the state is timing out and going to be removed by the system after the current batch.
  4. def remove(): Unit - Remove the state if it exists.
  5. def update(newState: S): Unit - Update the state with a new value
  6. def getOption(): Option[S] - Get the state as a scala.Option.
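Since State can only be exercised inside a running Spark job, here is a hypothetical in-memory stand-in named SimpleState (not Spark’s class) that follows the method list above, so the semantics can be tried out in isolation: get() throws NoSuchElementException on an empty state, and a removed or timing-out state cannot be changed. Using require (and therefore IllegalArgumentException) to enforce that last rule is an assumption of this sketch.

```scala
// Hypothetical stand-in for Spark's State[S], for experimenting outside Spark.
final class SimpleState[S](initial: Option[S] = None, timingOut: Boolean = false) {
  private var current: Option[S] = initial
  private var removed: Boolean = false

  def exists(): Boolean = current.isDefined

  // Like State.get(): throws when there is no state, so prefer getOption()
  def get(): S = current.getOrElse(throw new NoSuchElementException("State is not set"))

  def isTimingOut(): Boolean = timingOut

  def getOption(): Option[S] = current

  // Fails fast once removed or timing out; the exception type is an assumption
  def update(newState: S): Unit = {
    require(!removed && !timingOut, "Cannot update a removed or timing-out state")
    current = Some(newState)
  }

  def remove(): Unit = {
    require(!removed && !timingOut, "Cannot remove a removed or timing-out state")
    current = None
    removed = true
  }
}
```

Calling get() on a fresh SimpleState fails, while getOption() safely returns None, which is why getOption() is the more comfortable choice in the mapping functions that follow.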

We’ll see these in action shortly.

Changing our code to conform to the new API

Let’s rebuild our previous updateUserEvents to conform to the new API. Our new method signature now looks like this:

def updateUserEvents(key: Int, value: Option[UserEvent], state: State[UserSessions]): Option[UserSessions]

Instead of receiving a Seq[UserEvent], we now receive each event individually.

Let’s go ahead and make those changes:

def updateUserEvents(key: Int,
                     value: Option[UserEvent],
                     state: State[UserSessions]): Option[UserSessions] = {
  // Get existing user events, or if this is our first iteration
  // create an empty sequence of events.
  val existingEvents: Seq[UserEvent] =
    state
      .getOption()
      .map(_.userEvents)
      .getOrElse(Seq[UserEvent]())

  // Extract the new incoming value, appending the new event to the old
  // sequence of events.
  val updatedUserSessions: UserSessions =
    value
      .map(newEvent => UserSessions(newEvent +: existingEvents))
      .getOrElse(UserSessions(existingEvents))

  // Look for the end event. If found, remove the state and return the final
  // `UserSessions`. If not, update the internal state and return `None`.
  updatedUserSessions.userEvents.find(_.isLast) match {
    case Some(_) =>
      state.remove()
      Some(updatedUserSessions)
    case None =>
      state.update(updatedUserSessions)
      None
  }
}

For each iteration of mapWithState:

  1. If this is our first iteration, the state will be empty, so we create it and append the new event. If it isn’t, we already have existing events: we extract them from the State[UserSession] and append the new event to them.
  2. Look for the isLast event flag. If it exists, remove the UserSession state and return an Option[UserSession]. Otherwise, update the state and return None.
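The two steps above can be traced without a cluster. The following self-contained sketch replays updateUserEvents for a single key against a hypothetical minimal stand-in for State. The userEvents and isLast names come from the code above; UserEvent’s id and data fields, the SimpleState class and the run helper are assumptions for illustration.

```scala
object SessionFlowDemo {
  // Hypothetical shapes: `userEvents` and `isLast` follow the code above,
  // `id` and `data` are assumed for illustration.
  final case class UserEvent(id: Int, data: String, isLast: Boolean)
  final case class UserSessions(userEvents: Seq[UserEvent])

  // Minimal hypothetical stand-in for Spark's State[S], just enough for this flow.
  final class SimpleState[S] {
    private var value: Option[S] = None
    def exists(): Boolean = value.isDefined
    def getOption(): Option[S] = value
    def update(newState: S): Unit = { value = Some(newState) }
    def remove(): Unit = { value = None }
  }

  // Same logic as updateUserEvents above, written against the stand-in.
  def updateUserEvents(key: Int,
                       value: Option[UserEvent],
                       state: SimpleState[UserSessions]): Option[UserSessions] = {
    val existingEvents =
      state.getOption().map(_.userEvents).getOrElse(Seq.empty[UserEvent])
    val updated = value
      .map(newEvent => UserSessions(newEvent +: existingEvents))
      .getOrElse(UserSessions(existingEvents))

    updated.userEvents.find(_.isLast) match {
      case Some(_) =>
        state.remove()        // session complete: drop it from the store
        Some(updated)
      case None =>
        state.update(updated) // still in progress: keep accumulating
        None
    }
  }

  // Feeds one key's events one at a time, as mapWithState calls the function per record.
  def run(events: Seq[UserEvent]): (Seq[Option[UserSessions]], SimpleState[UserSessions]) = {
    val state = new SimpleState[UserSessions]
    val outputs = events.map(e => updateUserEvents(key = 1, value = Some(e), state = state))
    (outputs, state)
  }
}
```

Feeding two ordinary events followed by an isLast event yields None, None and then the completed session, after which the state no longer exists. Because new events are prepended with +:, the session lists the newest event first.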

Returning Option[UserSession] from the transformation is our choice. We could have chosen to return Unit and send the complete UserSession from within mapWithState, as we did with updateStateByKey. But I prefer being able to pass the UserSession down the line to another transformation to do more work as needed.

Our new Spark DAG now looks like this:

val stateSpec = StateSpec.function(updateUserEvents _)


But there’s one more thing to add. Since we don’t save the UserSession inside the transformation, we need an additional transformation to store it in persistent storage. For that, we can use foreachRDD:

  .foreachRDD { rdd =>
    if (!rdd.isEmpty()) {
      rdd.foreach(maybeUserSession => maybeUserSession.foreach(saveUserSession))
    }
  }

(If the connection to the underlying persistent storage is expensive and you don’t want to open one for each value in the RDD, consider using rdd.foreachPartition instead of rdd.foreach, but that is beyond the scope of this post.)

Finishing off with timeout

In reality, when working with large amounts of data we have to shield ourselves from data loss. With our current implementation, if the isLast event never shows up, we’ll end up with that user’s actions “stuck” in the state.

Adding a timeout is simple:

  1. Add the timeout when constructing our StateSpec.
  2. Handle the timeout in the stateful transformation.

The first step is easily achieved by:

import org.apache.spark.streaming._

val stateSpec =
  StateSpec
    .function(updateUserEvents _)
    .timeout(Minutes(5)) // illustrative duration; pick one that fits your sessions

(Minutes is a helper object in org.apache.spark.streaming that creates Spark’s own Duration class, not Scala’s scala.concurrent.duration.Duration.)

In updateUserEvents, we need to check the State[S].isTimingOut flag to know whether we’re timing out. Two things worth mentioning regarding timeouts:

  1. Once a timeout occurs, our value argument will be None (which explains why value is an Option[UserEvent] rather than a plain UserEvent. More on that here).
  2. If mapWithState is invoked due to a timeout, we must not call state.remove(); the framework does that on our behalf. From the documentation of State.remove:

State cannot be updated if it has been already removed (that is, remove() has already been called) or it is going to be removed due to timeout (that is, isTimingOut() is true).

Let’s modify the code:

def updateUserEvents(key: Int,
                     value: Option[UserEvent],
                     state: State[UserSessions]): Option[UserSessions] = {
  def updateUserSessions(newEvent: UserEvent): Option[UserSessions] = {
    val existingEvents: Seq[UserEvent] =
      state
        .getOption()
        .map(_.userEvents)
        .getOrElse(Seq[UserEvent]())

    val updatedUserSessions = UserSessions(newEvent +: existingEvents)

    updatedUserSessions.userEvents.find(_.isLast) match {
      case Some(_) =>
        state.remove()
        Some(updatedUserSessions)
      case None =>
        state.update(updatedUserSessions)
        None
    }
  }

  value match {
    case Some(newEvent) => updateUserSessions(newEvent)
    case _ if state.isTimingOut() => state.getOption()
    // Defensive fallback so the match is exhaustive
    case None => None
  }
}

I’ve extracted the updating of the user actions into a local method, updateUserSessions, which we call when we’re invoked with a new incoming value. Otherwise, we’re timing out, and we need to return the user events we’ve accumulated thus far.
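The timeout branch can be exercised the same way. In this sketch, a hypothetical SimpleState exposes a mutable timingOut flag that plays the role of the framework: once it is set, the mapping function is called with value = None, and calling update or remove fails, mirroring the restriction quoted above. The names and the exact exception behaviour are assumptions, not Spark’s API.

```scala
object TimeoutDemo {
  // Hypothetical shapes, as in the earlier sketch
  final case class UserEvent(id: Int, data: String, isLast: Boolean)
  final case class UserSessions(userEvents: Seq[UserEvent])

  // Hypothetical stand-in for Spark's State; `timingOut` is flipped by our toy "framework"
  final class SimpleState[S] {
    private var value: Option[S] = None
    var timingOut: Boolean = false
    def getOption(): Option[S] = value
    def isTimingOut(): Boolean = timingOut
    def update(newState: S): Unit = {
      require(!timingOut, "Cannot update the state that is timing out")
      value = Some(newState)
    }
    def remove(): Unit = {
      require(!timingOut, "Cannot remove the state that is timing out")
      value = None
    }
  }

  def updateUserEvents(key: Int,
                       value: Option[UserEvent],
                       state: SimpleState[UserSessions]): Option[UserSessions] = {
    def updateUserSessions(newEvent: UserEvent): Option[UserSessions] = {
      val existing = state.getOption().map(_.userEvents).getOrElse(Seq.empty[UserEvent])
      val updated = UserSessions(newEvent +: existing)
      if (updated.userEvents.exists(_.isLast)) { state.remove(); Some(updated) }
      else { state.update(updated); None }
    }

    value match {
      case Some(newEvent)           => updateUserSessions(newEvent)
      case _ if state.isTimingOut() => state.getOption() // flush what we have; no remove()
      case None                     => None
    }
  }
}
```

After two ordinary events the function has emitted nothing; once the toy framework marks the key as timing out and calls it with None, it returns the two accumulated events without touching the state.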

Wrapping Up

I hope I’ve managed to convey the general use of Spark’s stateful streams. Stateful streams, especially the new mapWithState transformation, bring a lot of power to end users who wish to work with stateful data in Spark while enjoying the guarantees of resiliency, distribution and fault tolerance that Spark provides.

There are still improvements to come in the forthcoming Spark 2.0.0 release and beyond, such as state versioning, which will enable us to label our accumulated data and only persist a subset of the state we store. If you’re interested in more, see the “State Store for Streaming Aggregation” proposal.

Additionally, there is a great comparison of updateStateByKey and mapWithState performance characteristics in this Databricks post.

Decompiling Traits with Scala 2.11 and 2.12


When working with Scala, one quickly becomes familiar with the notion of traits. Traits enable powerful language features such as subtype polymorphism and mixins.

Scala 2.12 (currently at preview stage M4) takes advantage of Java 8 default methods to generate optimized JVM byte code for traits. We’ll take a short detour through how Scala 2.11 generates byte code for traits and then look at Scala 2.12, both compiled with Java 8.

Compiling traits with Scala 2.11

Traits in Scala allow us to provide a default implementation for interface methods:

package yuvie

trait X {
  def helloWorld: String = "hello world"
}

When targeting the JVM, the Scala compiler needs to work around the fact that Java 7 and earlier don’t allow default implementations on interfaces, the closest JVM relative of traits. Let’s see what the Scala compiler does to work around that. We’ll compile trait X and then look at the generated byte code:

[root@localhost yuvie]# scalac X.scala
[root@localhost yuvie]# cd yuvie/
[root@localhost yuvie]# ll
total 12
-rw-r--r--. 1 root   root   448 Jun 24 17:58 X.class
-rw-r--r--. 1 root   root   416 Jun 24 17:58 X$class.class

Scala generated two class files for our trait X: X.class and X$class.class. Let’s look at what each of them contains:

[root@localhost yuvie]# javap -c -p X
Warning: Binary file X contains yuvie.X
Compiled from "X.scala"
public interface yuvie.X {
  public abstract java.lang.String helloWorld();
}

[root@localhost yuvie]# javap -c -p X\$class.class 
Compiled from "X.scala"
public abstract class yuvie.X$class {
  public static java.lang.String helloWorld(yuvie.X);
       0: ldc           #9                  // String hello world
       2: areturn

  public static void $init$(yuvie.X);
       0: return
}

Looking at the output, we can see that Scala’s compiler generates:

  1. An interface called X, matching the trait declaration. This interface has a single method called helloWorld without the implementation.
  2. An abstract class called X$class with a static helloWorld method, which provides the implementation.

Let’s see what happens when some class M extends X:

class M extends X


[root@localhost yuvie]# javap -c -p M.class 
Compiled from "M.scala"
public class yuvie.M implements yuvie.X {
  public java.lang.String helloWorld();
       0: aload_0
       1: invokestatic  #17                 // Method yuvie/X$class.helloWorld:(Lyuvie/X;)Ljava/lang/String;
       4: areturn

  public yuvie.M();
       0: aload_0
       1: invokespecial #23                 // Method java/lang/Object."<init>":()V
       4: aload_0
       5: invokestatic  #27                 // Method yuvie/X$class.$init$:(Lyuvie/X;)V
       8: return
}

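When M extends (mixes in) X, the compiler adds a helloWorld forwarder to M that calls the static method on the abstract class X$class, passing this as the argument, and that static method provides the implementation. M’s constructor likewise calls X$class.$init$ to run the trait’s initialization code. This duo of interface and abstract class allows Scala to generate valid byte code for JVMs without default methods.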
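To see the 2.11 scheme in source form, here is a hand-written approximation of what scalac generates. XLike plays the role of the generated interface X, XImpl stands in for X$class, and ManualM carries the forwarder that the compiler writes into M. The names are hypothetical, and this mirrors the shape of the encoding rather than reproducing its exact byte code.

```scala
// Plays the role of the generated interface X: abstract method only
trait XLike {
  def helloWorld: String
}

// Plays the role of X$class: the implementation receives the instance explicitly,
// like the static X$class.helloWorld(yuvie.X) seen in the javap output
object XImpl {
  def helloWorld(self: XLike): String = "hello world"
}

// What the compiler effectively writes into M: a forwarder to the implementation
class ManualM extends XLike {
  def helloWorld: String = XImpl.helloWorld(this)
}
```

Because XLike declares only an abstract method, it compiles to a plain JVM interface, and ManualM answers helloWorld purely by delegating, just like the invokestatic forwarder in M above.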

Leveraging Default Methods with Scala 2.12

We just saw how Scala 2.11 compiles traits; now let’s see how Scala 2.12 leverages default methods. Compiling the same code with Scala 2.12.0-M4 yields the following:

[root@localhost test-project]# cd target/scala-2.12.0-M4/classes/yuvie/
[root@localhost yuvie]# ll
total 16
-rw-rw-r--. 1 yuvali yuvali  680 Jun 24 18:17 X.class

[root@localhost yuvie]# javap -c -p X.class
Compiled from "X.scala"
public interface yuvie.X {
  public java.lang.String helloWorld();
       0: ldc           #12                 // String hello world
       2: areturn

  public void $init$();
       0: return
}

The compiler generates a single interface called X, which carries the trait’s default implementation. Let’s see what happens now when M extends X:

[root@localhost yuvie]# javap -c -p M.class 
Compiled from "M.scala"
public class yuvie.M implements yuvie.X {
  public java.lang.String helloWorld();
       0: aload_0
       1: invokespecial #14                 // Method yuvie/X.helloWorld:()Ljava/lang/String;
       4: areturn

  public yuvie.M();
       0: aload_0
       1: invokespecial #20                 // Method java/lang/Object."<init>":()V
       4: aload_0
       5: invokespecial #23                 // Method yuvie/X.$init$:()V
       8: return
}

The compiler now emits an invokespecial instruction targeting the default method defined in interface X, instead of the invokestatic call to the static method in X$class that Scala 2.11 emitted.
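If you want to check this on your own toolchain without javap, Java’s reflection API can report whether an interface method carries a body through java.lang.reflect.Method.isDefault (Java 8+). This is a small sketch, assuming Scala 2.12 or later, since on 2.11 the interface method is abstract and the check would return false. Greeter and Greeted are hypothetical names mirroring X and M.

```scala
// Mirrors trait X: a concrete method in a trait
trait Greeter {
  def helloWorld: String = "hello world"
}

// Mirrors class M
class Greeted extends Greeter

object DefaultMethodCheck {
  // The trait still compiles to a JVM interface
  def traitIsInterface: Boolean = classOf[Greeter].isInterface

  // On Scala 2.12+, the concrete trait method is emitted as a Java default method
  def isDefaultMethod: Boolean =
    classOf[Greeter].getMethod("helloWorld").isDefault
}
```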


Thanks to Java 8 default methods, Scala 2.12 can emit less byte code for traits. My example was very simple; in real-world code this should save a decent amount of generated byte code, and it lets Scala align more closely with Java without jumping through hoops to make traits work.