What Are Window Functions?
Window functions in Apache Spark are SQL functions that perform calculations across a set of rows related to the current row, called a "window". Unlike aggregate functions, which return a single result per group, window functions return a result for every row while still taking the other rows in the window into account.
Why Use Window Functions?
- To calculate running totals
- To find ranks within partitions
- To access previous or next values (lead/lag)
- To perform grouped operations without collapsing rows
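All of these follow the same two-step pattern: define a window specification describing how rows are grouped and ordered, then apply a function over it. Here is a minimal sketch of that pattern; the names group_col, order_col, and value_col are placeholders, not columns from the example that follows.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, sum
# 1. Describe the window: which rows belong together and how they are ordered
w = Window.partitionBy("group_col").orderBy("order_col")
# 2. Apply a function "over" the window; every row keeps its own result
# df.withColumn("rank_in_group", rank().over(w))
# df.withColumn("running_sum", sum("value_col").over(w))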
Use Case Example: Sales by Employee
Imagine you have a dataset that contains daily sales data for employees across departments. You want to:
- Rank employees within each department based on their sales
- Calculate the running total of sales for each employee
- Compare each employee’s sales with the previous day
Step-by-Step with PySpark
Let’s create and process this data using PySpark’s window functions.
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank, sum, lag
# Create Spark session
spark = SparkSession.builder.master("local").appName("Window Functions").getOrCreate()
# Sample sales data
data = [
("Alice", "Electronics", "2024-01-01", 200),
("Alice", "Electronics", "2024-01-02", 180),
("Bob", "Electronics", "2024-01-01", 300),
("Bob", "Electronics", "2024-01-02", 260),
("Charlie", "Clothing", "2024-01-01", 150),
("Charlie", "Clothing", "2024-01-02", 170),
]
columns = ["employee", "department", "date", "sales"]
df = spark.createDataFrame(data, columns)
df.show()
+--------+-----------+----------+-----+
|employee| department|      date|sales|
+--------+-----------+----------+-----+
|   Alice|Electronics|2024-01-01|  200|
|   Alice|Electronics|2024-01-02|  180|
|     Bob|Electronics|2024-01-01|  300|
|     Bob|Electronics|2024-01-02|  260|
| Charlie|   Clothing|2024-01-01|  150|
| Charlie|   Clothing|2024-01-02|  170|
+--------+-----------+----------+-----+
1. Ranking Within Each Department
Let’s rank each daily sales record within its department, from highest to lowest sales:
# Window partitioned by department and ordered by sales descending
windowDept = Window.partitionBy("department").orderBy(col("sales").desc())
df.withColumn("rank", rank().over(windowDept)).show()
+--------+-----------+----------+-----+----+
|employee| department|      date|sales|rank|
+--------+-----------+----------+-----+----+
|     Bob|Electronics|2024-01-01|  300|   1|
|     Bob|Electronics|2024-01-02|  260|   2|
|   Alice|Electronics|2024-01-01|  200|   3|
|   Alice|Electronics|2024-01-02|  180|   4|
| Charlie|   Clothing|2024-01-02|  170|   1|
| Charlie|   Clothing|2024-01-01|  150|   2|
+--------+-----------+----------+-----+----+
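The rank above is assigned per daily record. If you instead wanted to rank employees by their total sales within each department, as in the use-case list, one approach, sketched here with a hypothetical windowTotals spec, is to aggregate first and then rank the aggregated totals:
# Sketch: rank employees by total sales within each department
totals = df.groupBy("employee", "department").agg(sum("sales").alias("total_sales"))
windowTotals = Window.partitionBy("department").orderBy(col("total_sales").desc())
totals.withColumn("rank", rank().over(windowTotals)).show()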
Question:
How is this different from GROUP BY?
Answer:
GROUP BY collapses rows into one summary row per group. Window functions keep all rows intact while adding calculated values as new columns.
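To make the contrast concrete, here is a small sketch (windowDeptAll and dept_sales are illustrative names): the groupBy version returns one row per department, while the window version keeps all six rows and attaches the department total to each.
# GROUP BY: collapses each department into a single summary row
df.groupBy("department").agg(sum("sales").alias("dept_sales")).show()
# Window function: every original row survives, with the department total attached
windowDeptAll = Window.partitionBy("department")
df.withColumn("dept_sales", sum("sales").over(windowDeptAll)).show()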
2. Running Total of Sales per Employee
Now let’s calculate a cumulative sum (running total) of sales for each employee ordered by date.
windowEmp = Window.partitionBy("employee").orderBy("date")
df.withColumn("running_total", sum("sales").over(windowEmp)).show()
+--------+-----------+----------+-----+-------------+
|employee| department|      date|sales|running_total|
+--------+-----------+----------+-----+-------------+
|   Alice|Electronics|2024-01-01|  200|          200|
|   Alice|Electronics|2024-01-02|  180|          380|
|     Bob|Electronics|2024-01-01|  300|          300|
|     Bob|Electronics|2024-01-02|  260|          560|
| Charlie|   Clothing|2024-01-01|  150|          150|
| Charlie|   Clothing|2024-01-02|  170|          320|
+--------+-----------+----------+-----+-------------+
Question:
Why do we use partitionBy("employee") here?
Answer:
Because we want to reset the running total for each employee. If we didn’t, Spark would calculate a running total across all rows, mixing different employees.
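To see the difference, here is a sketch of the same running total with no partitioning (windowAll and running_total_all are illustrative names): the window is ordered by date across the whole DataFrame, so sales from different employees accumulate together, and Spark will also warn about moving all data to a single partition.
# Sketch: no partitioning, so the running total spans every employee
windowAll = Window.orderBy("date")
df.withColumn("running_total_all", sum("sales").over(windowAll)).show()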
3. Comparing With Previous Day’s Sales (Lag Function)
Let’s pull in the previous day’s sales for each employee so we can compare day over day.
from pyspark.sql.functions import lag
df.withColumn("prev_day_sales", lag("sales", 1).over(windowEmp)).show()
+--------+-----------+----------+-----+--------------+
|employee| department|      date|sales|prev_day_sales|
+--------+-----------+----------+-----+--------------+
|   Alice|Electronics|2024-01-01|  200|          null|
|   Alice|Electronics|2024-01-02|  180|           200|
|     Bob|Electronics|2024-01-01|  300|          null|
|     Bob|Electronics|2024-01-02|  260|           300|
| Charlie|   Clothing|2024-01-01|  150|          null|
| Charlie|   Clothing|2024-01-02|  170|           150|
+--------+-----------+----------+-----+--------------+
The lag function lets you look back at previous rows within the window; similarly, lead lets you access future rows.
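For example, here is a sketch that subtracts the lagged value to get the day-over-day change, and uses lead to look one row ahead instead (sales_change and next_day_sales are illustrative column names):
from pyspark.sql.functions import lead
# Day-over-day change: current sales minus the previous day's sales
df.withColumn("prev_day_sales", lag("sales", 1).over(windowEmp)) \
  .withColumn("sales_change", col("sales") - col("prev_day_sales")) \
  .show()
# lead() looks forward instead of backward
df.withColumn("next_day_sales", lead("sales", 1).over(windowEmp)).show()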
Summary
Window functions are powerful tools in Spark that allow you to perform advanced operations like ranking, cumulative sums, and row comparisons — all while retaining row-level detail. They are essential for data cleaning, analysis, and preparation in real-world pipelines.