Persistence and Caching in Apache Spark
When you run transformations on RDDs in Apache Spark, they are not immediately computed. Instead, Spark builds a Directed Acyclic Graph (DAG) of transformations. Only when an action (like count(), collect(), etc.) is called does Spark execute the transformations and return a result.
But what happens when you need to reuse an RDD multiple times? By default, Spark will recompute the RDD from scratch every time it's needed — which can be inefficient for large datasets. This is where persistence and caching come in.
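To make that recomputation cost concrete, here is a minimal sketch (the variable names are illustrative, not from any particular codebase): without caching, every action replays the full lineage.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("RecomputeDemo").getOrCreate()

# A transformation: nothing runs yet, Spark only records the lineage in the DAG.
doubled = spark.sparkContext.parallelize(range(10)).map(lambda x: x * 2)

# Each action below walks the whole lineage again, because nothing is cached.
print(doubled.count())    # triggers one full computation
print(doubled.collect())  # triggers another full computation

spark.stop()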
What is Caching?
Caching tells Spark to keep the data in memory after the first time it is computed. If the same RDD is used again, Spark will serve it from memory rather than recomputing it.
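As a small, hedged sketch of this idea (variable names are illustrative): cache() only marks the RDD; the data is actually stored the first time an action materializes it.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("CacheIsLazy").getOrCreate()

nums = spark.sparkContext.parallelize(range(100)).map(lambda x: x + 1)
nums.cache()              # marks the RDD for caching; nothing is computed yet
print(nums.is_cached)     # True - the mark is set even before any action runs
nums.count()              # first action computes the data and stores it in memory
nums.count()              # second action is answered from the cached partitions

spark.stop()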
What is Persistence?
Persistence is a more flexible version of caching. It allows you to choose where and how the RDD should be stored (memory, disk, both, or even off-heap).
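A minimal sketch of the difference (again with illustrative names): on RDDs, cache() is simply shorthand for persist(StorageLevel.MEMORY_ONLY), while persist() lets you pick any level explicitly and unpersist() releases it.

from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("PersistChoice").getOrCreate()

lines = spark.sparkContext.parallelize(range(100)).map(str)

lines.persist(StorageLevel.DISK_ONLY)   # choose where the data should live
print(lines.getStorageLevel())          # shows the level that was set
lines.unpersist()                       # release the storage when done

spark.stop()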
Why Use Caching or Persistence?
- To avoid recomputing expensive transformations
- To speed up iterative algorithms like PageRank or ML training (see the loop sketch after this list)
- To reduce load on data sources (e.g., if the data is read from a database)
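To illustrate the iterative case, here is a hedged toy example (a repeated filter-and-count, not PageRank itself): because the base RDD is cached, the loop body does not rebuild it from scratch on every pass.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("IterativeCaching").getOrCreate()

# Base data that every iteration reuses.
base = spark.sparkContext.parallelize(range(1, 100001)).map(lambda x: x % 97)
base.cache()   # computed once on the first action, then reused below

total = 0
for threshold in range(5):
    # Each pass filters and counts the cached data instead of recomputing the map.
    total += base.filter(lambda x: x > threshold).count()

print("Accumulated count:", total)
spark.stop()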
PySpark Example: Caching an RDD
Let’s see an example of caching in PySpark. We'll create an RDD, cache it, and perform actions on it.
from pyspark.sql import SparkSession
# Create a SparkSession
spark = SparkSession.builder.master("local").appName("CachingExample").getOrCreate()
# Create an RDD from a range
rdd = spark.sparkContext.parallelize(range(1, 1000000))
# Define an expensive transformation
squared_rdd = rdd.map(lambda x: x * x)
# Cache the RDD
squared_rdd.cache()
# Perform two actions on the cached RDD
print("First count (triggers computation):", squared_rdd.count())
print("Second count (uses cached result):", squared_rdd.count())
spark.stop()
First count (triggers computation): 999999
Second count (uses cached result): 999999
How Does This Work?
The first time squared_rdd.count() is executed, Spark computes the squares and stores the resulting partitions in memory. When the second count() runs, Spark skips the recomputation and serves the cached partitions directly from memory.
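One way to see this for yourself (hedged: the exact debug output varies between Spark versions) is to add the two lines below just before spark.stop() in the example above; the storage level and the lineage annotations show that the RDD is cached.

print(squared_rdd.getStorageLevel())         # the level set by cache()
print(squared_rdd.toDebugString().decode())  # lineage; cached RDDs are marked here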
Question:
What if your RDD is too large to fit in memory?
Answer:
In that case, Spark will recompute the partitions that don't fit each time they are needed, unless you use persist(StorageLevel.MEMORY_AND_DISK), which lets Spark spill data to disk when memory is full.
Persistence Storage Levels
Spark offers several storage levels via the persist() method (see the inspection sketch after this list):
- MEMORY_ONLY – Store the RDD as deserialized objects in memory (the default for cache())
- MEMORY_AND_DISK – Keep what fits in memory and spill the rest to disk
- DISK_ONLY – Store only on disk (useful for very large RDDs)
- MEMORY_ONLY_SER – Store as serialized Java objects (uses less memory at the cost of extra CPU)
Note that in PySpark, RDD data is always stored in serialized (pickled) form, so the serialization-specific levels mainly matter for the Scala/Java API.
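As a small inspection sketch (the printed format may differ slightly between Spark versions), each StorageLevel is just a bundle of flags, so you can print the levels to see exactly what they enable:

from pyspark import StorageLevel

# Flags, in order: useDisk, useMemory, useOffHeap, deserialized, replication
print(StorageLevel.MEMORY_ONLY)
print(StorageLevel.MEMORY_AND_DISK)
print(StorageLevel.DISK_ONLY)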
PySpark Example: Persistence
Here's how to use persist() to store an RDD in memory and spill to disk when needed:
from pyspark import StorageLevel
from pyspark.sql import SparkSession
# Recreate Spark session
spark = SparkSession.builder.master("local").appName("PersistenceExample").getOrCreate()
rdd = spark.sparkContext.parallelize(range(1, 1000000))
cubic_rdd = rdd.map(lambda x: x ** 3)
# Persist with memory and disk
cubic_rdd.persist(StorageLevel.MEMORY_AND_DISK)
# Trigger computation
print("Sum of cubes:", cubic_rdd.sum())
# Second action uses persisted result
print("Max cube:", cubic_rdd.max())
spark.stop()
Sum of cubes: 249999500000250000000000
Max cube: 999997000002999999
Important Notes
- Always call unpersist() on RDDs that are no longer needed, to free up memory (see the sketch after this list).
- Use caching only when the RDD will be reused multiple times in a job.
- If data fits in memory, caching gives the best performance. Otherwise, use persistence with disk spillover.
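A minimal sketch of the unpersist pattern mentioned above (names are illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("UnpersistExample").getOrCreate()

data = spark.sparkContext.parallelize(range(1000)).map(lambda x: x * 2)
data.cache()
print(data.count())    # computes and caches the partitions

# Done with the cached copy: release the memory explicitly.
data.unpersist()
print(data.is_cached)  # False

spark.stop()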
Summary
Persistence and caching are optimization techniques in Spark that reduce the time and resources needed for repeated computations. For data pipelines or iterative jobs, using them wisely can make your Spark applications faster and more efficient.