Exploring Stateful Streaming with Spark Structured Streaming

In a previous post, we explored how to do stateful streaming using Spark's Streaming API with the DStream abstraction. Today, I'd like to sail out on a journey with you to explore Spark 2.2 with its new support for stateful streaming under the Structured Streaming API. In this post, we'll see how the API has matured and evolved, look at the differences between the two approaches (Streaming vs. Structured Streaming), and see what changes were made to the API. We'll do so by taking the example from my previous post and adapting it to the new API.

A recap of state management with Spark Streaming

If you needed stateful streaming with Spark (up until Spark 2.2), you had to choose between two abstractions: updateStateByKey and mapWithState, where the latter is more or less an improved version of the former (both API- and performance-wise, with slightly different semantics). In order to utilize state between micro-batches, you provided a StateSpec function to mapWithState, which would be invoked per key-value pair that arrived in the current micro-batch (a minimal sketch follows the list below). With mapWithState, the main advantages are:

  1. Initial RDD for state - One can load up an RDD with previously saved state
  2. Timeout - Timeout management was handled by Spark. You can set a single timeout for all key value pairs.
  3. Partial updates - Only keys which were “touched” in the current micro batch were iterated for update
  4. Return type - You can choose any return type.
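
To make the recap concrete, here is a minimal, hedged sketch of how mapWithState is wired up (UserEvent and UserSession are the classes from the previous post, userEventsDStream is assumed to be a DStream[(Int, UserEvent)] keyed by user id, and the merge logic is elided):

import org.apache.spark.streaming.{Minutes, State, StateSpec}

// Merge the incoming event into the stored session; emit the session once it completes.
def updateUserEvents(key: Int,
                     value: Option[UserEvent],
                     state: State[UserSession]): Option[UserSession] = {
  // ... merge `value` into `state`, return Some(session) when it's done, otherwise None ...
  None
}

val stateSpec =
  StateSpec
    .function(updateUserEvents _)
    .timeout(Minutes(30)) // a single timeout shared by all keys

val sessions = userEventsDStream.mapWithState(stateSpec)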

But things aren't always perfect…

Pain points of mapWithState

mapWithState was a big improvement over the previous updateStateByKey API. But there are a few caveats I’ve experienced over the last year while using it:

Checkpointing

To ensure Spark can recover from failed tasks, it has to checkpoint data to a distributed file system from which it can consume upon failure. When using mapWithState, each executor process holds a HashMap, in memory, of all the state you've accumulated. At every checkpoint, Spark serializes the entire state, each time. If you're holding a lot of state in memory, this can cause significant processing latencies. For example, under the following setup:

  • Batch interval: 4 seconds
  • Checkpointing interval: 40 seconds (4 second batch x 10 constant spark factor)
  • 80,000 messages/second
  • Message size: 500B - 2KB
  • 5 m4.2xlarge machines (8 vCPUs, 32GB RAM)
  • 2 Executors per machine
  • Executor storage size ~ 7GB (each)
  • Checkpointing data to HDFS

I've experienced accumulated delays of up to 4 hours, since each checkpoint under high load was taking between 30 seconds and 1 minute for the entire state, while we're generating batches every 4 seconds. I've also seen people confused by this on StackOverflow, as it really isn't obvious why some batches take drastically longer than others.

If you’re planning on using stateful streaming for high throughput you have to consider this as a serious caveat. This problem was so severe that it sent me looking for an alternative to using in memory state with Spark. But we’ll soon see that things are looking bright ;)

Saving state between version updates

Software is an evolving process: we constantly improve, enhance and implement new feature requirements. As such, we need to be able to upgrade from one version to another, preferably without affecting existing data. This becomes quite tricky with in-memory data. How do we preserve the current state of things? How do we ensure that we continue from where we left off?

Out of the box, mapWithState doesn't support evolving our data structure. If you've modified the data structure you're storing state with, you have to delete all previously checkpointed data since the serialVersionUID will differ between object versions. Additionally, any change to the execution graph defined on the StreamingContext won't take effect since we're restoring the lineage from the checkpoint.

mapWithState does provide a method for viewing the current state snapshot of our data via MapWithStateDStream.stateSnapshot(). This enables us to store state in an external repository and recover from it using StateSpec.initialRDD. However, storing data externally can add to the already significant delays caused by checkpointing.
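
As a hedged sketch of that snapshot/restore dance (reusing stateSpec and updateUserEvents from the sketch above; saveSnapshot and loadSavedState are hypothetical helpers for whatever external store you pick, and sc is the underlying SparkContext):

import org.apache.spark.rdd.RDD

val statefulStream = userEventsDStream.mapWithState(stateSpec)

// Periodically persist the full (key, state) snapshot to an external store.
statefulStream.stateSnapshots().foreachRDD { rdd: RDD[(Int, UserSession)] =>
  saveSnapshot(rdd) // hypothetical: write the pairs to your store of choice
}

// On startup, seed the state from the last saved snapshot.
val initialState: RDD[(Int, UserSession)] = loadSavedState(sc) // hypothetical loader
val stateSpecWithInitialState =
  StateSpec
    .function(updateUserEvents _)
    .timeout(Minutes(30))
    .initialState(initialState)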

Separate timeout per state object

mapWithState allows us to set a default timeout for all states via StateSpec.timeout. However, at times it may be desirable to have a separate state timeout for each state object. For example, assume we have a requirement that a user session be no longer than 30 minutes. Then a new client comes along who wants user sessions to end every 10 minutes. What do we do? Well, we can't handle this out of the box, and we have to implement our own mechanism for timeout (a rough sketch of such a mechanism follows). The bigger problem is that mapWithState only touches key-value pairs we have data for in the current batch, it doesn't touch all the keys. This means that we have to roll back to updateStateByKey, which by default iterates the entire state, and that may be bad for performance (depending on the use case, of course).
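
For illustration only, here is one hedged way to approximate per-key timeouts with mapWithState by keeping the deadline inside the state object (reusing State and the example classes from above; the TTL value is hypothetical). Note that it only fires for keys that receive traffic in the current batch, which is exactly the limitation described above:

case class TimedSession(events: Seq[UserEvent], deadlineMs: Long)

val sessionTtlMs = 10 * 60 * 1000L // hypothetical per-client session TTL

def updateWithDeadline(key: Int,
                       value: Option[UserEvent],
                       state: State[TimedSession]): Option[UserSession] = {
  val now = System.currentTimeMillis()
  state.getOption match {
    case Some(s) if now > s.deadlineMs =>
      // The stored session has expired: emit it and start a new one with the incoming event.
      state.update(TimedSession(value.toSeq, now + sessionTtlMs))
      Some(UserSession(s.events))
    case existing =>
      val events = existing.map(_.events).getOrElse(Seq.empty) ++ value.toSeq
      state.update(TimedSession(events, now + sessionTtlMs))
      None
  }
}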

Single executor failure causing data loss

Executors are Java processes, and as with any process they can fail. I've had heap corruptions in production cause a single executor to die. The problem with this is that once a new executor is created by the Worker process, it does not recover the state from the checkpoint. If you look at the CheckpointSuite tests, you'll see that all of them deal with StreamingContext recovery, but none with single executor failure.

Ok, downsides, fine. Where is this new API you’re talking about?

Hold your horses, we’re just getting to it… :)

Introducing Structured Streaming

Structured Streaming is Spark's shiny new tool for reasoning about streaming.

From the Structured Streaming Documentation - Overview:

Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive.

Spark's authors realize that reasoning about a distributed streaming application has many hidden concerns one may not even realize they have to deal with, beyond maintaining the business domain logic. Instead of making us take care of all these concerns, they want us to reason about our stream processing the same way we'd query a static SQL table, while Spark takes care of running our queries over new data as it arrives in the stream. Think about it as an unbounded table of data.

For a deeper explanation of Structured Streaming and the Dataset[T] abstraction, see this great post by Databricks. Don't worry, I'll wait…

Welcome back. Let's continue on to see what the new stateful streaming abstraction looks like in Structured Streaming.

Learning via an example

The example I’ll use here is the same example I’ve used in my previous post regarding stateful streaming. To recap (and for those unfamiliar with the previous post), the example talks about a set of incoming user events which we want to aggregate as they come in from the stream. Our events are modeled in the UserEvent class:

case class UserEvent(id: Int, data: String, isLast: Boolean)

We uniquely identify a user by their id. This id will be used to group the incoming data together, so that all events for a given user end up at the same executor process, which handles the state. A user event also has a data field holding some String content and an additional Boolean field indicating whether this is the last message for the current session.

The way we aggregate user events as they come in the stream is by using a UserSession class:

case class UserSession(userEvents: Seq[UserEvent])

Which holds together all events by a particular user.

Introducing mapGroupsWithState

If you think “Hmm, this name sounds familiar”, you're right, it's almost identical to the mapWithState abstraction we had in Spark Streaming, with minor changes to the user-facing API. But first, let's talk about some key differences between the two.

Differences between mapWithState and mapGroupsWithState (and generally Spark Streaming vs. Structured Streaming)

  1. Keeping state between application updates - One of the biggest caveats of mapWithState is the fact that unless you roll your own bookkeeping of state, you're forced to drop the in-memory data between upgrades. Not only that, but if anything inside the Spark DAG changes you have to drop that data as well. From my experiments with mapGroupsWithState, it seems that using the Kryo encoder combined with versioning your data structures correctly (i.e. using default values for newly added state) allows you to keep the data between application upgrades, and even to change the Spark DAG defining your transformations and still keep the state. This is major news for anyone using mapWithState. The reason I'm cautious in saying this is that I haven't seen any official documentation or statements from the developers of Structured Streaming to support this claim.

  2. Micro batch execution - Spark Streaming requires a fixed batch interval in order to generate and execute micro batches with data from the source. Even if no new data has arrived, the micro batch still executes, causing the entire graph to execute. Structured Streaming is different: it has a dedicated thread which constantly checks the source to see if new data has arrived. If no new data is available, the query is not executed. What does that mean? It means, for example, that if you set a timeout interval of X seconds but no new data has come into the stream, no state will time out, because the query won't run.

  3. Internal data storage for state - mapWithState is based on an OpenHashMapBasedStateMap[K, V] implementation to store the in-memory state. mapGroupsWithState uses java.util.ConcurrentHashMap[K, V]. Additionally, the latter uses an underlying structure called UnsafeRow for both key and value instead of a plain JVM object. These unsafe rows are wrappers around the bytes of data generated by the encoder for the keys and values, and conversions between the unsafe representation and our JVM object structure are applied on demand, when the key-value pair needs to be passed to our state function.

  4. State versioning - there is no longer a single in-memory state store per executor. The new implementation uses an HDFSBackedStateStore per version of the state, meaning it only needs to keep the latest version of the state in memory while letting older versions reside in the backing state store, loading them on demand as needed.

  5. Timeout - mapWithState has a single timeout set for all state objects. mapGroupsWithState enables state timeout per group, meaning you can create more complex configurations for timeouts. Additionally, a timeout based on event time and watermarks is available.

  6. Checkpointing - mapWithState checkpointing occurs at a fixed interval. If our batch time is 4 seconds, a checkpoint of the entire in-memory state occurs every 40 seconds. Not only that, but checkpointing is a blocking operation, meaning that until it finishes we cannot process incoming events of the next batch. mapGroupsWithState checkpointing is done incrementally, for updated keys only, and this is abstracted away by the implementation of the underlying FileSystem used as the state store. This means there should be a significant reduction in checkpointing overhead.

  7. Offset handling (for replayable sources, such as Kafka) - Using the DStream API with a replayable source such as Kafka required us to handle offset storage ourselves, in persistent storage such as ZooKeeper, S3 or HDFS. Replaying a certain source from a particular offset means reading the data from the persistent storage and passing it to KafkaUtils.createDirectStream upon initialization of the stream. Structured Streaming stores and retrieves the offsets on our behalf when re-running the application, meaning we no longer have to store them externally (a brief sketch of a Kafka source follows).
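
As a hedged sketch (assuming the spark-sql-kafka-0-10 connector is on the classpath and a SparkSession named spark; the broker address and topic are placeholders), a Kafka source in Structured Streaming looks roughly like this, with offsets tracked in the checkpoint location for us:

val kafkaStream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "broker1:9092") // placeholder broker
  .option("subscribe", "user-events")                // placeholder topic
  .load() // offsets are stored and restored by Structured Streaming itself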

Ok, enough with the comparisons, let’s get to business.

Analyzing the API

Let’s look at the method signature for mapGroupsWithState:

def mapGroupsWithState[S: Encoder, U: Encoder](
      timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => U)

Let's break down each argument and see what we can do with it. The first argument list contains timeoutConf, which selects the timeout configuration we want. We have two options:

  1. Processing Time Based (GroupStateTimeout.ProcessingTimeTimeout) - Timeout based on a constant interval (similar to calling the timeout function on StateSpec in Spark Streaming)

  2. Event Time Based (GroupStateTimeout.EventTimeTimeout) - Timeout based on user defined event time and watermarks (read this for more about handling late data using watermarks). A brief sketch follows this list.
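
For illustration, here is a hedged sketch of what the event-time variant could look like. It assumes typedUserEvents is a Dataset[UserEvent] encoded with a case-class (product) encoder so a timestamp field is visible as a column, and updateSessionEventsByEventTime is a hypothetical state function that calls state.setTimeoutTimestamp instead of setTimeoutDuration; none of this is part of the example repo:

import org.apache.spark.sql.streaming.GroupStateTimeout

val sessionsByEventTime =
  typedUserEvents                             // hypothetical Dataset[UserEvent] with a timestamp column
    .withWatermark("timestamp", "10 minutes") // tolerate up to 10 minutes of late data
    .groupByKey(_.id)
    .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(updateSessionEventsByEventTime)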

In the second argument list we have our state function. Let’s examine each argument and what it means:

func: (K, Iterator[V], GroupState[S]) => U

There are three argument types, K, Iterator[V], and GroupState[S], and a return type U. Let's map each of these arguments to our example and fill in the types.

As we’ve seen, we have a stream of incoming messages of type UserEvent. This class has a field called id of type Int which we’ll use as our key to group user events together. This means we substitute K with Int:

(Int, Iterator[V], GroupState[S]) => U

Next up is Iterator[V]. V is the type of the values we'll be aggregating. We'll be receiving a stream of UserEvent, which means we substitute V with UserEvent:

(Int, Iterator[UserEvent], GroupState[S]) => U

Great! Which class describes our state? If you scroll up a bit, you'll see we've defined a class called UserSession which portrays the entire session of the user, and that's what we'll use as our state type! Let's substitute S with UserSession:

(Int, Iterator[UserEvent], GroupState[UserSession]) => U

Awesome, we've managed to fill in the types of the arguments. The return type, U, is what we have left. We only want to return a UserSession once it's complete, either by the user session timing out or by receiving the isLast flag set to true. We'll set the return type to be an Option[UserSession], which will be filled iff we've completed the session. This means substituting U with Option[UserSession]:

(Int, Iterator[UserEvent], GroupState[UserSession]) => Option[UserSession]

Hooray!

The GroupState API

For those of you acquainted with the mapWithState API, GroupState should feel very familiar, resembling the State class. Let's see what the API is like:

  1. def exists: Boolean - Whether state exists or not.
  2. def get: S - Get the state value if it exists, or throw NoSuchElementException.
  3. def getOption: Option[S] - Get the state value as a Scala Option[T].
  4. def update(newState: S): Unit - Update the value of the state. Note that null is not a valid value, and it throws IllegalArgumentException.
  5. def remove(): Unit - Remove this state.
  6. def hasTimedOut: Boolean - Whether the function has been called because the key has timed out. This can return true only when timeouts are enabled in [map/flatmap]GroupsWithStates.
  7. def setTimeoutDuration(durationMs: Long): Unit - Set the timeout duration in milliseconds for this key. Note that ProcessingTimeTimeout must be enabled in [map/flatmap]GroupsWithStates, otherwise it throws an UnsupportedOperationException.

The newest member of the state API is the setTimeoutDuration method. I've only included a single overload, but there are others taking various argument types, such as String and timestamps. Note that since each group state can have its own timeout, we have to set it explicitly inside mapGroupsWithState, for each group state. This means that each time our method is called, we'll have to set the timeout again using setTimeoutDuration, as we'll see when we implement that method.

In addition, there are several restrictions on calling setTimeoutDuration. If we haven't set the timeoutConf argument in mapGroupsWithState, calling this method will throw an UnsupportedOperationException, so make sure you've configured the timeout.

I’ve summarized API documentation here but if you want the full details see the Scala Docs.

Creating custom encoders

For the perceptive amongst the readers, you’ve probably noticed the following constraint on the type parameters of mapGroupsWithState:

def mapGroupsWithState[S: Encoder, U: Encoder]

What is this Encoder class required via a context bound on the types S and U? The class documentation says:

Used to convert a JVM object of type T to and from the internal Spark SQL representation.

Spark SQL is layered on top of an optimizer called the Catalyst Optimizer, which was created as part of Project Tungsten. Spark SQL (and Structured Streaming) deals, under the covers, with raw bytes instead of JVM objects, in order to optimize for space and efficient data access. For that, we have to tell Spark how to convert our JVM object structure into binary, and that is exactly what these encoders do.

Without going to lengths about encoders, here is the method signature for the trait:

trait Encoder[T] extends Serializable {

  /** Returns the schema of encoding this type of object as a Row. */
  def schema: StructType

  /**
   * A ClassTag that can be used to construct an Array to contain a collection of `T`.
   */
  def clsTag: ClassTag[T]
}

Any encoder has to provide two things: a schema of the class, described via a StructType, which is a recursive data structure laying out the schema of each field in the object we're describing, and a ClassTag[T] for constructing collections of T.

If we look closely again at the signature of mapGroupsWithState, we see that we need to supply two encoders, one for our state class represented by type S, and one for our return type represented by type U. In our example, that means providing implicit evidence for UserSession in the form of an Encoder[UserSession]. But how do we generate such encoders? Spark comes packed with encoders for primitives via the SQLImplicits object; for our own case classes we have to bring an implicit encoder into scope ourselves. The easiest way to do so is by creating a custom encoder using Encoders.kryo[T]. We use it like this:

object StatefulStructuredSessionization {
  implicit val userEventEncoder: Encoder[UserEvent] = Encoders.kryo[UserEvent]
  implicit val userSessionEncoder: Encoder[UserSession] = Encoders.kryo[UserSession]
}

For more on custom encoders, see this StackOverflow answer.

Implementing our state method

After figuring out what the signature of our state method is, let's go ahead and implement it:

def updateSessionEvents(
    id: Int,
    userEvents: Iterator[UserEvent],
    state: GroupState[UserSession]): Option[UserSession] = {
  if (state.hasTimedOut) {
    // We've timed out: capture the state and send it down the stream before removing it
    val timedOutSession = state.getOption
    state.remove()
    timedOutSession
  } else {
    /*
      New data has come in for the given user id. We'll look up the current state
      to see if we already have something stored. If not, we'll just take the current user events
      and update the state, otherwise we'll concatenate the user events we already have with the
      new incoming events.
    */
    val currentState = state.getOption
    val updatedUserSession =
      currentState.fold(UserSession(userEvents.toSeq))(currentUserSession =>
        UserSession(currentUserSession.userEvents ++ userEvents.toSeq))

    if (updatedUserSession.userEvents.exists(_.isLast)) {
      /*
        If we've received a flag indicating this should be the last event batch, let's close
        the state and send the user session downstream.
      */
      state.remove()
      Some(updatedUserSession)
    } else {
      state.update(updatedUserSession)
      state.setTimeoutDuration("1 minute")
      None
    }
  }
}

Our updateSessionEvents method has to deal with a couple of flows. We check whether our method was invoked because the state timed out; if so, state.hasTimedOut will return true and our userEvents iterator will be empty. All we have to do is extract the session, remove the state, and send the Option[UserSession] down the stream. If it hasn't timed out, that means the method was invoked because new values have arrived.

We extract the current state into the currentState value and deal with two cases by means of Option.fold:

  1. The state is empty - this means that this is the first batch of values for the given key and all we have to do is take the user events we’ve received and lift them into the UserSession class.

  2. The state has a value - we extract the existing user events from the UserSession object (this is the second argument list in the fold method) and append them to the new values we just received.

The UserEvent class contains a field named isLast which we check to see if this is the last incoming event for the session. After aggregating the values, we scan the user events sequence to see if we’ve received the flag. If we did, we remove the session from the state and return the session, otherwise we update the state and set the timeout duration and return None indicating the session isn’t complete yet.

Fitting it all together

We've seen how to define our state method, which means we're ready to create the execution graph. In this example, I'll be consuming data from a socket in a JSON structure which matches our UserEvent.

I apologize in advance for not including all the code and imports in this gist. At the bottom of the post you’ll find a link to a full working repo on GitHub containing all the code for you to try out.

To start things off, we create a SparkSession instance with the details of the master URI and application name:

val spark: SparkSession = SparkSession.builder
      .master("local[*]")
      .appName("Stateful Structured Streaming")
      .getOrCreate()

SparkSession is our gateway to interacting with the streaming graph, just as StreamingContext previously was. After our session is defined, we express the format of our source for consuming the data; in this case it's a socket:

import spark.implicits._

val userEventsStream: Dataset[String] = spark.readStream
    .format("socket")
    .option("host", host)
    .option("port", port)
    .load()
    .as[String]

Importing spark.implicits._ brings in the encoders defined for primitives (we use it here for String). The host and port variables come from the command line arguments. Once we call the load() method, we get back a DataFrame. Think of it as a generic representation of the data containing rows and columns. In order to convert a DataFrame into a Dataset[String], we use the as[T] method, which tells Spark we want to get back a typed data set.

After we have the data set at hand, we can map over it to deserialize our JSON into a UserEvent and apply our state method to it:

val finishedUserSessionsStream: Dataset[UserSession] =
  userEventsStream
    .map(deserializeUserEvent)
    .groupByKey(_.id)
    .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout())(updateSessionEvents)
    .flatMap(userSession => userSession)

After mapping over the data set and deserializing all events, we use groupByKey to group user events by their id to get back a KeyValueGroupedDataset[K, V]. The grouping is the key (no pun intended) to exposing mapGroupsWithState, which is defined on the key-value grouped data set type. We then call mapGroupsWithState and pass GroupStateTimeout.ProcessingTimeTimeout as the first argument to indicate to Spark how we want to time out our state, and pass in the updateSessionEvents method we've defined beforehand. After we finish applying the stateful transformation, we output the completed sessions to the rest of the graph, but we also output None in case the session isn't complete. This means we have to make sure only the Some[UserSession] values keep flowing, which is the reason for the flatMap.
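
The deserializeUserEvent function isn't spelled out in this post (the real one lives in the linked repo); as a hedged sketch, using json4s for example, it might look roughly like this:

import org.json4s._
import org.json4s.jackson.JsonMethods._

def deserializeUserEvent(json: String): UserEvent = {
  implicit val formats: Formats = DefaultFormats
  parse(json).extract[UserEvent] // assumes the JSON fields match the UserEvent case class
}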

We’re left to define the output of our stream. For this example I chose the “console” format which just prints out values to the console, but you may use any of the existing output sinks. Additionally, we have to specify which type of output mode we choose (we can only use OutputMode.Update with mapGroupsWithState) and the checkpoint directory location:

finishedUserSessionsStream.writeStream
  .outputMode(OutputMode.Update())
  .format("console")
  .option("checkpointLocation", "hdfs:///path/to/checkpoint/location")
  .start()
  .awaitTermination()

And voila, watch the data start pouring into the stream!

Wrapping up

Structured Streaming brings a new mental model for reasoning about streams. The most promising feature I see in mapGroupsWithState is the fact that we no longer suffer the checkpointing penalty, due to the way stateful aggregations are handled. Additionally, being able to save state between version upgrades and getting automatic offset management is also very appealing. Time will tell if this is THE stateful management framework we've been waiting for; it sure looks and feels promising.

There are more interesting internal implementation details, which I encourage you to explore and share in the comments as you start using Structured Streaming.

You can find the full working example in my GitHub stateful streaming repository.

When Life Gives You Options, Make Sure Not To Choose Some(null)

It was a casual afternoon at the office, writing some code, working on features. Then, all of a sudden, Spark started getting angry. If you run a production environment with Spark, you probably know that feeling when those red dots start to accumulate in the Spark UI:

Ain't nobody got time for that!

It was an old friend, one I haven't seen in a while. It's one of those friends that creeps up on you from behind, when you're not looking. Even though he knows he's not invited, he still feels comfortable coming back, peeping his head in every now and then:

java.lang.NullPointerException

Hello nullness my old friend

When we moved from C# to Scala, it was written in bold everywhere: "Use Option[A]! Avoid null at all costs". And that actually made sense. You have this perfectly good data structure (one might even throw the M word) which makes it easy not to use null. It takes some time getting used to, especially for someone making their first steps in a functional programming language. So we embraced the advice and started using Option[A] everywhere we wanted to convey the absence of a value, and so far it's worked great. So how did our old friend still manage to creep in and cause our Spark job to scream and yell at us?

Diagnosing the issue

Spark wasn't being very helpful here. Mostly those red dots were accompanied by the NullPointerException and a useless driver stack trace, but all the actual action was happening inside the Executor nodes running the code. After some investigation, I managed to get a hold of the actual stack trace causing the problem:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 15345.0 failed 4 times, most recent failure: Lost task 7.3 in stage 15345.0 (TID 27874, XXX.XXX.XXX.XXX): java.lang.NullPointerException
	at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47)
	at scala.collection.immutable.StringOps.length(StringOps.scala:47)
	at scala.collection.IndexedSeqOptimized$class.segmentLength(IndexedSeqOptimized.scala:193)
	at scala.collection.immutable.StringOps.segmentLength(StringOps.scala:29)
	at scala.collection.GenSeqLike$class.prefixLength(GenSeqLike.scala:93)
	at scala.collection.immutable.StringOps.prefixLength(StringOps.scala:29)
	at scala.collection.IndexedSeqOptimized$class.span(IndexedSeqOptimized.scala:159)
	at scala.collection.immutable.StringOps.span(StringOps.scala:29)
	at argonaut.PrettyParams.appendJsonString$1(PrettyParams.scala:131)
	at argonaut.PrettyParams.argonaut$PrettyParams$$encloseJsonString$1(PrettyParams.scala:148)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$4.apply(PrettyParams.scala:187)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$4.apply(PrettyParams.scala:187)
	at argonaut.Json$class.fold(Json.scala:32)
	at argonaut.JString.fold(Json.scala:472)
	at argonaut.PrettyParams.argonaut$PrettyParams$$trav$1(PrettyParams.scala:178)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$6$$anonfun$apply$3.apply(PrettyParams.scala:204)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$6$$anonfun$apply$3.apply(PrettyParams.scala:198)
	at scala.collection.TraversableOnce$$anonfun$foldLeft$1.apply(TraversableOnce.scala:157)
	at scala.collection.TraversableOnce$$anonfun$foldLeft$1.apply(TraversableOnce.scala:157)
	at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:221)
	at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
	at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
	at scala.collection.AbstractTraversable.foldLeft(Traversable.scala:104)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$6.apply(PrettyParams.scala:198)
	at argonaut.PrettyParams$$anonfun$argonaut$PrettyParams$$trav$1$6.apply(PrettyParams.scala:197)
	at argonaut.Json$class.fold(Json.scala:34)
	at argonaut.JObject.fold(Json.scala:474)
	at argonaut.PrettyParams.argonaut$PrettyParams$$trav$1(PrettyParams.scala:178)
	at argonaut.PrettyParams.pretty(PrettyParams.scala:211)
	at argonaut.Json$class.nospaces(Json.scala:422)
	at argonaut.JObject.nospaces(Json.scala:474)
	at argonaut.Json$class.toString(Json.scala:464)
	at argonaut.JObject.toString(Json.scala:474)
	at com.our.code.SomeClass.serialize(SomeClass.scala:12)

Most of this stack trace comes from Argonaut, a purely functional JSON parsing library (if you don't know Argonaut and have to do some JSON parsing in Scala, you should definitely check it out).

We were serializing a class to JSON and somewhere along the line, a String was null. This was weird, especially considering our class looked like this:

case class Foo(bar: Option[String], baz: Option[Int])

Not only that, but Argonaut handles options out of the box via its DSL for serialization:

EncodeJson(
  foo => {
    ("bar" :?= foo.bar) ->?:
    ("baz" :?= foo.baz) ->?:
    jEmptyObject
  }
)

Where :?= knows how to handle the Option[A] inside bar and baz. (Yes, I know there is shorter syntax for serialization, and yes I’m aware of argonaut-shapeless :), but for the sake of the example)

So WTF is going on? We don’t use null in our code, everything is perfectly wrapped in options, the mighty functional gods are happy, where is this coming from?

Some(null) is never EVER what you wanted

It took me a couple of hours to realize what was happening, and I have to say it quite surprised me. Foo is the product of a Map[String, String] lookup. What we do prior to generating Foo is parse a String into key-value pairs and then extract specific values which generate Foo.

A rough sketch of the code looks like this:

val s: String = ???
val kvps: scala.collection.immutable.Map[String, String] = parseLongStringIntoKeyValuePairs(s)

val foo = Foo(kvps.get("bar"), kvps.get("baz"))

If you're familiar with Scala's immutable Map[A, B], you know that its get method returns an Option[B] (where B is the type of the value). The documentation looks like this:

abstract def get(key: K): Option[V]

Optionally returns the value associated with a key.
  key: the key value
  returns: an option value containing the value associated with key in this map, or None if none exists.

“or None if none exists”. OK, that makes sense. But what happens if null creeps in as a value? What would you expect the following to return?

val map: Map[String, String] = Map("bar" -> null)
map.get("bar")

If you guessed None, you’re wrong:

val map: Map[String, String] = Map("bar" -> null)
map.get("bar")

// Exiting paste mode, now interpreting.

map: Map[String,String] = Map(bar -> null)
res1: Option[String] = Some(null)

Some(null) - YAY! BEST OF BOTH WORLDS.

But just a minute ago I told you guys "We never use null, always Option[A]". Was I lying to you? No, I wasn't. The problem is that parseLongStringIntoKeyValuePairs is actually an interop call to a Java library which parses the string and may definitely return null in the presence of a key with no value.

This feels weird though

This has been discussed many times in the Scala ecosystem. I guess the TLDR is that Some(null) may actually convey something under specific contexts that None cannot, such as the existence of an empty value (where None may convey no value at all). This leads to a long philosophical discussion about the meaning of null, None and the essence of the human race. Be it correct or not, this definitely gave me a good bite in the a** and is something everyone should be aware of.

Fixing the problem

A quick fix for this problem is trivial. The first thing that comes to mind is Option.apply which handles null values gracefully by returning None:

map.get("bar").flatMap(Option.apply)
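
A quick check in the REPL shows what each step returns, and how the flatMap collapses the problematic value:

val map: Map[String, String] = Map("bar" -> null)

Option(null)                          // None - Option.apply turns null into None
map.get("bar")                        // Some(null) - the raw lookup
map.get("bar").flatMap(Option.apply)  // None - the null no longer leaks downstream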

Wrapping up

Some(null) is evil sorcery IMO. I would never use it to convey the emptiness of an existing value; I can think of many other ways to encode such a value (I like to use the Foo.empty pattern where empty is a lazy val), especially in the Scala ecosystem.

Of course a trivial unit test could have shown that this happens, but many times in Scala I get the warm feeling that Option[A] means "this can never be null", when we should always keep in mind that something like the above may happen.

Improving Spark Streaming Checkpointing Performance With AWS EFS

Update 10.03.2017 - There is a “gotcha” when using EFS for checkpointing which can be a deal breaker, pricing wise. I have updated the last part of the post (Downsides) to reflect the problem and a possible workaround.

Introduction

When doing stream processing with Apache Spark, in order to be resilient to failures one must provide a checkpointing endpoint for Spark to save its internal state to. When using Spark with YARN as the resource manager, one can easily checkpoint to HDFS. But what happens if we don't need HDFS, or don't want to use YARN? What if we're using Spark Standalone as our cluster manager?

Up until recently, checkpointing to S3 was the de-facto storage choice for Spark with the Standalone cluster manager running on AWS. This has certain caveats which make it problematic in a production environment (or any other environment, really). What I'd like to walk you through today is how I've experienced working with S3 as a checkpoint backing store in production, and to introduce an alternative approach using AWS EFS (Elastic File System). Let's get started.

Choosing a solution for checkpointing

Learning what checkpointing is and how it helps, you very soon realise that you need a distributed file system to store the intermediate data for recovery. Since our product was running on AWS, the viable solutions were:

  1. HDFS - Required us to install HDFS and run YARN as a cluster manager.
  2. Amazon S3 - Remote blob storage, pretty easy to set up and get running, low cost of maintenance.

Since we didn't need (or want) an entire HDFS cluster only for the sake of checkpointing, we decided on an S3 based solution that only required us to provide an endpoint and credentials, and we were on our way. It was quick, efficient and tied up the loose end for us quickly.

When things start to go wrong

Rather quickly we started noticing Spark tasks failing with a “directory not found” error due to S3's read-after-write semantics. Spark's checkpointing mechanism (to be precise, the underlying S3AFileSystem provided by Apache Hadoop) first writes all data to a temporary directory, only upon completion lists the directory written to, making sure the folder exists, and only then renames the checkpoint directory to its real name. Listing a directory after a PUT operation in S3 is eventually consistent per the S3 documentation, and this was the cause of sporadic failures which made the checkpointing task fail entirely.

This meant that:

  1. Checkpointing wasn’t 100% reliable, thus making driver recovery not reliable.
  2. Failed tasks accumulated in the Executor UI view, making it difficult to distinguish between random checkpoint failures and actual business logic failures.

And that was a problem. A better solution was needed that would give us reliable semantics of checkpointing.

Amazon Elastic File System

From the official AWS documentation:

Amazon Elastic File System (Amazon EFS) provides simple, scalable file storage for use with Amazon EC2. With Amazon EFS, storage capacity is elastic, growing and shrinking automatically as you add and remove files, so your applications have the storage they need, when they need it.

Amazon Elastic File System (EFS) is a distributed file system which mounts onto your EC2 instances. It resembles Elastic Block Storage, but extends it by allowing a single file system to be mounted on multiple EC2 instances. This is a classic use case for Spark, since we need the same mount available across all Worker instances.

Setting up EFS

The AWS docs lay out how one creates an EFS instance and attaches it to existing EC2 instances. The essential steps are:

  1. Step 2: Create Your Amazon EFS File System
  2. Step 3: Connect to Your Amazon EC2 Instance and Mount the Amazon EFS File System
  3. Network File System (NFS)–Level Users, Groups, and Permissions

One important thing I'd like to point out is permissions. Once you set up the EFS mount, all permissions automatically go to the root user. If your Spark application runs under a different user (which it usually does, under its own spark user), you have to remember to grant that user permissions on the checkpointing directory. You must make sure that user (be it spark or any other) has an identical userid and groupid on all EC2 instances. If you don't, you'll end up with permission denied errors while trying to write checkpoint data. If you have already set up an existing user and want to align all user and group ids to that user, read this tutorial on how that can be done (it's pretty easy and straightforward).

Checkpoint directory format for StreamingContext

After the EFS mount is up and running on all EC2 instances, we need to pass the mounted directory to our StreamingContext. We do this by passing the exact location of the directory we chose with a file:// prefix. Assume that our mount location is /mnt/spark/:

val sc = new SparkContext("spark://127.0.0.1:7077", "EFSTestApp")
val streamingContext = StreamingContext.getOrCreate("file:///mnt/spark/", () => {
  val ssc = new StreamingContext(sc, Seconds(4))
  ssc.checkpoint("file:///mnt/spark/") // freshly created contexts should checkpoint to the mount as well
  ssc
})

Spark will use LocalFileSystem from org.apache.hadoop.fs as the underlying file system for checkpointing.

Performance and stability gains

As you probably know, every Spark cluster has a different size, different workloads and different computations being processed. There usually isn't a “one size fits all” solution. I encourage you to take this paragraph with a grain of salt, as always with system performance.

Under AWS EFS, I've witnessed a 2x to 3x improvement in checkpointing times. Our streaming app checkpoints every 40 seconds, using 3 executors, each with 14GB of memory, and a constant message stream of ~2000-5000 messages/sec. Checkpointing took between 8-10 seconds on S3. Under EFS, that checkpointing time dropped to between 2-4 seconds for the same workload. Note that this will vary greatly depending on your cluster setup, the size and count of your executors, and the number of Worker nodes.

Additionally, we now no longer experience failing tasks due to checkpointing errors, which is extremely important for fault tolerance of the streaming application:

Spark Streaming With No Task Failures

Downsides (Updated 10.3.2017)

There is an “issue” when running checkpointing with EFS which I've been hit in the face with in production. The way AWS EFS works is that your throughput for reads/writes is determined by the size of your file system. Initially, AWS gives you enough burst credits to be able to write at 100 MiB/s, and starts studying your use of the file system (and taking away your credits). About a week later, it determines your pattern of use, checking how much you write and read and the size of the file system you have. If the file system is small, your throughput gets limited. More specifically, while using this solution only for checkpointing our streaming data (which varies between a couple of KBs to a couple of MBs), we were limited to roughly 50 KiB/s of writes, which is definitely not enough and can cause your processing delay to increase substantially.

The (rather pricey) workaround for this issue is to make the file system large enough so that you constantly get 50 MiB/s of write throughput. To get this kind of throughput, you need your file system to be at least 1TB in size. This can be an arbitrary “junk” file just sitting around in the directory to increase its size. Price wise, this will cost you $300 a month out of the box, without the additional price for the data you actually checkpoint. If this kind of price is a non-issue for you, then EFS will still be a nice solution.

Here is the size to throughput table AWS guarantees:

  • A 100 GiB file system can burst to 100 MiB/s for up to 72 minutes each day, or drive up to 5 MiB/s continuously.
  • A 1 TiB file system can burst to 100 MiB/s for 12 hours each day, or drive 50 MiB/s continuously.
  • A 10 TiB file system can burst to 1 GiB/s for 12 hours each day, or drive 500 MiB/s continuously.
  • Generally, a larger file system can burst to 100 MiB/s per TiB of storage for 12 hours each day, or drive 50 MiB/s per TiB of storage continuously.

I advise you to read through the performance section of the documentation before making a choice.

Wrapping up

Amazon Elastic File System is a relatively new but promising approach for Spark's checkpointing needs. It provides an elastic file system that can scale infinitely and be mounted onto multiple EC2 instances easily and quickly. Moving away from S3 gave us a stable file system to checkpoint our data to, removing the sporadic failures caused by S3's eventual consistency model with respect to read-after-write.

I definitely recommend giving it a try.

Why You Might Be Misusing Spark's Streaming API

Disclaimer: Yes, I know the topic is a bit controversial, and I know most of this information is conveyed in Spark's documentation for its Streaming API, yet I felt the urge to write this piece after seeing this mistake happen many times over.

More often than not I see a question on StackOverflow from people who are new to Spark Streaming which looks roughly like this:

Question: “I’m trying to do XYZ but it’s not working, what can I do? Here is my code:”

val sparkContext = new SparkContext(new SparkConf().setAppName("MyApp"))
val streamingContext = new StreamingContext(sparkContext, Seconds(4))

val dataStream = streamingContext.socketTextStream("127.0.0.1", 1337)
dataStream.foreachRDD { rdd => 
  // Process RDD Here
}

Uhm, ok, what’s wrong with that?

When I started learning Spark, my first landing point was an explanation of how RDDs (Resilient Distributed Datasets) work. The usual example was a word count where all the operations were performed on an RDD. I think it is safe to assume this is the entry point for many others who learn Spark (although today DataFrames/Datasets are becoming the go-to approach for beginners).

When one makes the leap to working with Spark Streaming, it may be a little unclear what the additional abstraction of a DStream means. This causes a lot of people to seek something they can grasp, and the most familiar method they encounter is foreachRDD, which takes an RDD as input and yields Unit (the result of a typical side-effecting method). Then, they can again work at the RDD level, which they already feel comfortable with and understand. That is missing the point of DStreams entirely, which is why I want to give a brief look into what we can do on the DStream itself without peeking into the underlying RDD.

Enter DStream

DStream is Spark's abstraction over micro-batches. It consumes streaming sources, be it a network socket, Kafka or Kinesis (and the like), providing us with a continuous flow of data that we read at every batch interval assigned to the StreamingContext.

In order to work with the DStream API, we must understand how the abstraction works. DStream is basically a sequence of RDDs. At a given batch interval a single RDD is consumed and passed through all the transformations we supply to the DStream. When we do:

val dataStream = streamingContext.socketTextStream("127.0.0.1", 1337)
dataStream
  .flatMap(_.split(" "))
  .filter(_.nonEmpty)
  .map(word => (word, 1L))
  .count

That means we apply flatMap, filter, map and count onto the underlying RDD itself as well! There are at least as many of these transformations on DStream as there are for RDD, and these are the transformations we should be working with in our streaming application. There is a comprehensive list of all the operations on the Spark Streaming documentation page under "Transformations on DStreams".

More operations on key value pairs

Similar to the PairRDDFunctions which brings in (implicitly) transformations on pairs inside an RDD, we have the equivalent PairDStreamFunctions with many such methods, primarily:

  • combineByKey - Combine elements of each key in DStream’s RDDs using custom functions.
  • groupByKey - Return a new DStream by applying groupByKey on each RDD
  • mapValues - Return a new DStream by applying a map function to the value of each key-value pairs in ‘this’ DStream without changing the key.
  • mapWithState - Return a MapWithStateDStream by applying a function to every key-value element of this stream, while maintaining some state data for each unique key.
  • reduceByKey - Return a new DStream by applying reduceByKey to each RDD. The values for each key are merged using the supplied reduce function. org.apache.spark.Partitioner is used to control the partitioning of each RDD.

And many more for you to enjoy and take advantage of.
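
As a hedged sketch (usageStream and its (userId, bytesDownloaded) pairs are made up for illustration), combining a couple of these looks like this:

// usageStream: DStream[(String, Long)] of (userId, bytesDownloaded) pairs
val bytesPerUser = usageStream.reduceByKey(_ + _)                 // total bytes per user in this batch
val report = bytesPerUser.mapValues(b => s"$b bytes this batch")  // keys untouched, values mapped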

That's awesome! So why do I need foreachRDD at all?

Similar to RDDs, when Spark builds its graph of execution we distinguish between regular transformations and output transformations. The former are lazily evaluated when building the graph while the latter play a role in the materialization of the graph. If our DStream graph had only regular transformations applied to it, we would get an exception at runtime saying there’s no output transformation defined.

foreachRDD is useful when we’ve finished extracting and transforming our dataset, and we now want to load it to an external source. Let’s say I want to send transformed messages to RabbitMQ as part of my flow, I’ll iterate the underlying RDD partitions and send each message:

transformedDataStream
  .foreachRDD { rdd: RDD[String] =>
    rdd.foreachPartition { partition: Iterator[String] =>
      // Create the client on the executor, once per partition, so it isn't serialized from the driver
      val rabbitClient = new RabbitMQClient()
      partition.foreach(msg => rabbitClient.send(msg))
    }
  }

transformedDataStream is an arbitrary DStream after we've performed all our transformation logic on it. The result of all these transformations is a DStream[String]. Inside foreachRDD, we get a single RDD[String], where we then iterate each of its partitions, creating a RabbitMQClient to send each message inside the partition iterator.

There are several more of these output transformations listed on the Spark Streaming documentation page which are very useful.

Wrapping up

Spark Streaming's DStream abstraction provides powerful transformations for processing data in a streaming fashion. When we do stream processing in Spark, we're processing many individual micro-batched RDDs which we can reason about in our system flowing one after the other. When we apply transformations on the DStream, they percolate all the way down to each RDD that is passed through, without us needing to apply the transformations to it ourselves. Finally, the use of foreachRDD should be kept for when we want to take our transformed data and perform some side effecting operation on it, mostly things like sending data over the wire to a database, pub-sub and the like. Use it wisely and only when you truly need to!

Diving Into Spark 2.1.0 Blacklisting Feature

Disclaimer: What I am about to talk about is an experimental Spark feature. We are about to dive into an implementation detail, which is subject to change at any point in time. It is an advantage if you have prior knowledge of how Spark's scheduling works, although if you don't, I will try to lay that out throughout this post, so don't be afraid :).

Introduction

Spark 2.1.0 comes with a new feature called “blacklist”. Blacklisting enables you to set thresholds on the number of failed tasks on each executor and node, such that a task set or even an entire stage will be blacklisted for those problematic units.

The basics - Jobs, stages and tasks

When we create a Spark graph, the DAGScheduler takes our logical plan (or RDD lineage), which is composed of transformations, and translates it into a physical plan. For example, let's take the classic word count:

val textFile = sc.parallelize(Seq("hello dear readers", "welcome to my blog", "hope you have a good time"))
val count = textFile.flatMap(line => line.split(" "))
                 .map(word => (word, 1))
                 .reduceByKey(_ + _)
                 .count

println(count)

We have three transformations operating on the text file: flatMap, map and reduceByKey (let's ignore parallelize for the moment). The first two are called “narrow transformations” and can be executed sequentially one after the other, as all the data is available locally for the given partition. reduceByKey is called a wide transformation because it requires a shuffle of the data to be able to reduce all elements with the same key together (if you want a lengthier explanation of narrow vs. wide transformations, see this blog post by Ricky Ho).

Spark creates the following physical plan from this code:

Spark Physical Plan

As you can see, the first stage contains three tasks: parallelize, flatMap and map, and the second stage has one task, reduceByKey. Stages are bounded by wide transformations, and that is why reduceByKey is started as part of stage 1, followed by count.

We have established that a job is:

  1. Composed of one or more stages.
  2. Each stage has a set of tasks, bounded by wide transformations.
  3. Each task is executed on a particular executor in the Spark cluster.

As an optimization, Spark takes several narrow transformations that can run sequentially and executes them together as a task set, saving us the need to send intermediate results back to the driver.

Creating a scenario

Lets imagine we’ve been assigned a task which requires us to fetch a decent amount of data over the network, do some transformations and then save the output of those transformations to the database. We design the code carefully and finally create a spark job, which does exactly what we want and publish that job to the Spark cluster to begin processing. Spark builds the graph and dispatches the individual tasks to the available executors to do the processing. We see that everything works well and we deploy our job to production.

We’ve happily deployed our job which is working great. After a few days running we start noticing a particular problem with one of the nodes in the cluster. We notice that every time that node has to execute the fetching of the data from our external service, the task fails with a network timeout, eventually causing the entire stage to fail. This job is mission critical to our company and we cannot afford to stop it and let everything wait. What do we do??

Enter Blacklisting

Spark 2.1.0 enables us to blacklist a problematic executor and even an entire node (which may contain one to N executors) from receiving a particular task, task set or whole stage. In our example, we saw that a faulty node was causing tasks to fail, and we want to do something about it. Let us see how this new feature can help.

If we take a look at the Spark Configuration section of the documentation, we see all the settings we can tune:

  • spark.blacklist.enabled (default: false) - If set to “true”, prevent Spark from scheduling tasks on executors that have been blacklisted due to too many task failures. The blacklisting algorithm can be further controlled by the other “spark.blacklist” configuration options.
  • spark.blacklist.task.maxTaskAttemptsPerExecutor (default: 1) - (Experimental) For a given task, how many times it can be retried on one executor before the executor is blacklisted for that task.
  • spark.blacklist.task.maxTaskAttemptsPerNode (default: 2) - (Experimental) For a given task, how many times it can be retried on one node, before the entire node is blacklisted for that task.
  • spark.blacklist.stage.maxFailedTasksPerExecutor (default: 2) - (Experimental) How many different tasks must fail on one executor, within one stage, before the executor is blacklisted for that stage.
  • spark.blacklist.stage.maxFailedExecutorsPerNode (default: 2) - (Experimental) How many different executors are marked as blacklisted for a given stage, before the entire node is marked as failed for the stage.

We can set the number of attempts made for each task, both per executor and per node, and more importantly we can mark how many times we want to allow a particular task to fail on a single executor before marking that executor as blacklisted, and how many executors can fail on a given node before that node is completely blacklisted. Marking a node as blacklisted means that the stage of the underlying task may never run again on that particular executor/node for the entire lifetime of our job.
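
As a hedged sketch (the values are illustrative, taken from the table above), enabling blacklisting for a job could look like this:

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("BlacklistingExample")
  .set("spark.blacklist.enabled", "true")                      // turn the feature on
  .set("spark.blacklist.task.maxTaskAttemptsPerExecutor", "1") // one attempt per executor per task
  .set("spark.blacklist.stage.maxFailedTasksPerExecutor", "2") // two failed tasks blacklist the executor for the stage
  .set("spark.blacklist.stage.maxFailedExecutorsPerNode", "2") // two bad executors blacklist the node for the stage

val sc = new SparkContext(conf)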

The algorithm

Now that we understand the basic configuration of blacklisting, let's look at the code. In order to do that, we need a little background on how task scheduling works in Spark. This is an interesting and quite complex topic, so I will try to skim through without going into too much detail (if you are interested in more detail you can find it in Jacek Laskowski's excellent “Mastering Spark” gitbook).

Spark's scheduling model is similar to that of Apache Mesos, where each executor offers its resources, and the scheduling algorithm selects which node gets to run each job and when.

Let us explore how a single job gets scheduled. We'll explore the operations starting at the DAGScheduler and below; there are more operations invoked by the SparkContext which are less relevant:

  1. Everything starts with the famous DAGScheduler which builds the DAG and emits stages.
  2. Each stage in turn is disassembled into tasks (or, to be more accurate, a set of tasks). A TaskSetManager is created for each task set. This manager is going to be in charge of the set of tasks throughout the lifetime of its execution, including re-scheduling on failure or aborting in case something bad happens.
  3. After the TaskSetManager is created, the backend creates work offers for all the executors and calls the TaskScheduler with these offers.
  4. TaskScheduler receives the work offers of the executors and iterates its TaskSetManagers and attempts to schedule the task sets on each of the available executor resources (remember this, we’ll come back to this soon).
  5. After all work has been assigned to the executors, the TaskScheduler notifies the CoarseGrainedSchedulerBackend, which then serializes each task and sends them off to the executors.

If you want to follow this flow in the code base that is a little complicated, the hierarchy looks roughly like this:

DAGScheduler.handleJobSubmitted -> DAGScheduler.submitStage -> DAGScheduler.submitMissingTasks -> TaskScheduler.submitTasks -> CoarseGrainedSchedulerBackend.reviveOffers -> CoarseGrainedSchedulerBackend.makeOffers -> TaskScheduler.resourceOffers -> TaskScheduler.resourceOfferSingleTaskSet -> TaskSetManager.resourceOffer -> CoarseGrainedSchedulerBackend.launchTasks.

There is a lot of bookkeeping going on between these calls, but I just wanted to outline the high level orchestration of how tasks get scheduled; I hope you've managed to follow. If not, don't worry, just remember that when tasks are scheduled their owner is an object named TaskSetManager, which controls everything related to the set of tasks. This object has an important method called resourceOffer, which we are going to look into.

TaskSetManager, TaskSetBlacklist and the blacklist bookkeeping

A new class called TaskSetBlacklist was introduced in 2.1.0. It is in charge of the bookkeeping of failed tasks. This class reads the limits set by the user via the spark.blacklist.* flags, and is consumed internally by each TaskSetManager to keep track of the failure information:

private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR)
private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE)
private val MAX_FAILURES_PER_EXEC_STAGE = conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)
private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE)

The mappings that keep track of failures look like this:

private val nodeToExecsWithFailures = new HashMap[String, HashSet[String]]()
private val nodeToBlacklistedTaskIndexes = new HashMap[String, HashSet[Int]]()
val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]()
private val blacklistedExecs = new HashSet[String]()
private val blacklistedNodes = new HashSet[String]()

As you can see, there are three maps and two hash sets:

  1. Node -> Executor Failures: maps a node to all of its executors that have failed a task.
  2. Node -> Blacklisted Task Indexes: the task indexes that have been blacklisted on that particular node.
  3. Executor -> Task Set Failures: maps an executor to all of its failures for a single task set (see the sketch right after this list).
  4. Blacklisted executors.
  5. Blacklisted nodes.
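The per-executor value, ExecutorFailuresInTaskSet, is conceptually just a map from task index to the number of times that task has failed on the executor. A simplified sketch (not the exact Spark source) might look like this:

// Simplified sketch of the per-executor failure bookkeeping (not the exact Spark source):
class ExecutorFailuresInTaskSet(val node: String) {
  // task index within the task set -> number of times it failed on this executor
  private val taskToFailureCount = scala.collection.mutable.HashMap[Int, Int]()

  def updateWithFailure(taskIndex: Int): Unit =
    taskToFailureCount(taskIndex) = getNumTaskFailures(taskIndex) + 1

  def getNumTaskFailures(index: Int): Int = taskToFailureCount.getOrElse(index, 0)

  def numUniqueTasksWithFailures: Int = taskToFailureCount.size
}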

TaskSetBlacklist exposes a bunch of utility methods for consumption by the TaskSetManager holding it. For example:

/**
* Return true if this executor is blacklisted for the given stage.  Completely ignores
* anything to do with the node the executor is on.  That
* is to keep this method as fast as possible in the inner-loop of the scheduler, where those
* filters will already have been applied.
*/
def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = {
    blacklistedExecs.contains(executorId)
}

The really interesting method is the one in charge of updating the blacklist state when a task fails, updateBlacklistForFailedTask. It is invoked by the TaskSetManager when the TaskScheduler signals a failed task:

private[scheduler] def updateBlacklistForFailedTask(
    host: String,
    exec: String,
    index: Int): Unit = {
  val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host))
  execFailures.updateWithFailure(index)

  // Check if this task has also failed on other executors on the same host -- if it's gone
  // over the limit, blacklist this task from the entire host.
  val execsWithFailuresOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet())
  execsWithFailuresOnNode += exec
  val failuresOnHost = execsWithFailuresOnNode.toIterator.flatMap { exec =>
    execToFailures.get(exec).map { failures =>
      // We count task attempts here, not the number of unique executors with failures. This is
      // because jobs are aborted based on the number of task attempts; if we counted unique
      // executors, it would be hard to configure this to ensure that you try another
      // node before hitting the max number of task failures.
      failures.getNumTaskFailures(index)
    }
  }.sum
  if (failuresOnHost >= MAX_TASK_ATTEMPTS_PER_NODE) {
    nodeToBlacklistedTaskIndexes.getOrElseUpdate(host, new HashSet()) += index
  }

  // Check if enough tasks have failed on the executor to blacklist it for the entire stage.
  if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) {
    if (blacklistedExecs.add(exec)) {
      logInfo(s"Blacklisting executor ${exec} for stage $stageId")
      // This executor has been pushed into the blacklist for this stage.  Let's check if it
      // pushes the whole node into the blacklist.
      val blacklistedExecutorsOnNode =
        execsWithFailuresOnNode.filter(blacklistedExecs.contains(_))
      if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) {
        if (blacklistedNodes.add(host)) {
          logInfo(s"Blacklisting ${host} for stage $stageId")
        }
      }
    }
  }
}

Breaking down the execution flow:

  1. Update the failure count for this executor and the given task index.
  2. Check whether this task has also failed on other executors on the same node; if it has, and we've exceeded MAX_TASK_ATTEMPTS_PER_NODE, the entire node is blacklisted for this particular task index.
  3. Check whether this failure means we've exceeded the allowed number of failed tasks for the entire stage on the given executor. If we have, the executor is blacklisted for the entire stage.
  4. Check whether this failure means we've exceeded the number of allowed executor failures for the node. If we have, the node is blacklisted for the entire stage.

The execution flow is pretty clear: we start from the smallest unit, a single task on a single executor, and work our way up to blacklisting the executor and eventually the whole node for the entire stage.

We've now seen where the TaskSetManager updates its internal state regarding task execution, but when (and where) does it check whether a particular task can be scheduled on a given executor/node? It does so exactly when the TaskScheduler asks it to assign a WorkerOffer to an executor, inside the resourceOffer method:

/**
* Respond to an offer of a single executor from the scheduler by finding a task
*
* NOTE: this function is either called with a maxLocality which
* would be adjusted by delay scheduling algorithm or it will be with a special
* NO_PREF locality which will be not modified
*
* @param execId the executor Id of the offered resource
* @param host  the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
@throws[TaskNotSerializableException]
def resourceOffer(
    execId: String,
    host: String,
    maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] = {
  val offerBlacklisted = taskSetBlacklistHelperOpt.exists { blacklist =>
    blacklist.isNodeBlacklistedForTaskSet(host) ||
      blacklist.isExecutorBlacklistedForTaskSet(execId)
  }

  if (!isZombie && !offerBlacklisted) {
    var allowedLocality = maxLocality

    if (maxLocality != TaskLocality.NO_PREF) {
      allowedLocality = getAllowedLocalityLevel(curTime)
      if (allowedLocality > maxLocality) {
        // We're not allowed to search for farther-away tasks
        allowedLocality = maxLocality
      }
    }

    dequeueTask(execId, host, allowedLocality).map { case ((index, taskLocality, speculative)) =>
      // Do task queueing stuff (elided).
    }
  } else {
    None
  }
}

taskSetBlacklistHelperOpt is an Option[TaskSetBlacklist] instance which is only set to Some[TaskSetBlacklist] if blacklisting was enabled in the configuration. Prior to assigning an offer to an executor, the TaskSetManager checks whether the host or executor is blacklisted for the task set; if it isn't, it tries to dequeue a task for the offer and returns an Option[TaskDescription], assigning that task to the executor.
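Roughly, the helper is only constructed when the feature is switched on. A simplified sketch of how the TaskSetManager builds it (abridged, so treat the exact names as an approximation):

// Simplified: only pay the bookkeeping cost when spark.blacklist.enabled is true.
private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = {
  if (BlacklistTracker.isBlacklistEnabled(conf)) {
    Some(new TaskSetBlacklist(conf, stageId, clock))
  } else {
    None
  }
}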

However, there is one more check needed. As we've seen, we don't only blacklist an entire task set for an executor or node, we also blacklist individual tasks from running on an executor/node. For that, dequeueTask makes an additional call to isTaskBlacklistedOnExecOrNode, which checks whether the task is blacklisted for the executor or node. If it is, Spark skips it and attempts to schedule the task on a different executor.
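Conceptually, that task-level check is built from the maps we saw earlier. A simplified sketch (not the exact Spark source) could look like this:

// Simplified sketch: a task index is blocked on this offer if the node has blacklisted it,
// or if its failure count on this executor has reached the per-executor limit.
def isTaskBlacklistedOnExecOrNode(index: Int, execId: String, node: String): Boolean = {
  nodeToBlacklistedTaskIndexes.get(node).exists(_.contains(index)) ||
    execToFailures.get(execId).exists(_.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR)
}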

What happens if all nodes are blacklisted?

Good question! Spark has an additional validation method inside TaskScheduler.resourceOffers which kicks in if none of the tasks have been scheduled to run. This can happen when all nodes are blacklisted:

if (!launchedAnyTask) {
  taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
}

I won't go into the implementation of the method, which you can find here. It validates that the remaining tasks really can't run anywhere due to blacklisting, and if that is the case, it aborts the task set.

Wrapping up

Spark 2.1.0 brings a new ability to configure blacklisting of problematic executors and even entire nodes. This feature can be useful when you are experiencing problems with a particular executor, such as network issues. If you are having failures at the node level, such as a full disk failing your tasks, it is possible to block entire nodes from receiving task sets and entire stages.

The main players in this blacklisting game are the TaskSetManager, responsible for a given task set, and its TaskSetBlacklist instance, which handles all the bookkeeping of who failed where. Together they are consumed by the TaskScheduler, which is in charge of the actual scheduling of the tasks.

I hope I've managed to explain the gory details of the implementation in an approachable way, and that you now have a basic understanding of what is going on under the covers.