We all know unit testing our code is a good idea. But there is some kind of inertia that keeps most of us from doing it. Part of this inertia, I believe, is not knowing how to start. So I am sharing how I write tests for pyspark in the hopes that it will give you the push you need.
Testing in python
First things first: if you’ve never written a test for a normal python function, let’s get that out of the way. Here is a minimal example:
def add(x: int, y: int) -> int:
    return x + y


def test_add():
    inputs = {"x": 3, "y": 4}
    expected = 7
    got = add(**inputs)
    assert got == expected
Save this to my_test.py, then run pytest my_test.py, and there you go! Now back to pyspark.
Basic script pattern
Let’s look at an example pyspark script that generates a denormalized sales table; that is, we’ll add some user information to each transaction line.
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local").getOrCreate()
sales = spark.read.csv("./sales.csv", sep="\t", header=True)
users = spark.read.csv("./user.csv", sep="\t", header=True)
df = sales.join(users, on=["user_id"], how="left")
df.write.mode("overwrite").csv("./full_sales.csv", sep="\t", header=True)
I believe that 50% of writing good tests is actually writing easily testable code, so let’s restructure the code above a bit.
The first thing I want to know when trying to understand a new spark script is what the inputs are and what the output is, so I like to keep that information right at the top.
TARGET = "./full_sales.csv"
SOURCES = {
    "sales": "./sales.csv",
    "users": "./user.csv",
}
Next, although getting the correct read and write parameters can be a bit tricky at first, it’s probably not where the bulk of the complexity of your spark job will be. For this reason I like to have a transform(...) function that is exclusively responsible for taking dataframes in and returning a dataframe out, without worrying about the format on disk.
def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    df = sales.join(users, on=["user_id"], how="left")
    return df
In this case our transform function is quite trivial, but we’ll discuss later on why I believe it’s still worth testing.
The full restructured script would look something like this:
from pyspark.sql import DataFrame, SparkSession

TARGET = "./full_sales.csv"
SOURCES = {
    "sales": "./sales.csv",
    "users": "./user.csv",
}


def main(spark, sources, target):
    inputs = {
        k: spark.read.csv(v, sep="\t", header=True)
        for k, v in sources.items()
    }
    df = transform(**inputs)
    df.write.mode("overwrite").csv(target, sep="\t", header=True)


def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    df = sales.join(users, on=["user_id"], how="left")
    return df


if __name__ == "__main__":
    spark = SparkSession.builder.master("local").getOrCreate()
    main(spark, SOURCES, TARGET)
With this we can easily swap inputs and outputs, and we can also test the transformation without having to worry about file formats.
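For example, a dry run of the same job against a different set of files is just a call to main with different arguments (the sample paths below are made up for illustration):
# same job, different inputs and output (illustrative paths, not from the real project)
main(
    spark,
    sources={"sales": "./samples/sales.csv", "users": "./samples/user.csv"},
    target="./samples/full_sales.csv",
)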
Testing in this pattern
So now that we’ve refactored the code, how can we write a unit test for it? We’ll rely on the fact that we can create a dataframe from a tuple of tuples, like this:
data = (
    ("transaction1", 10.0, "user1"),
    ("transaction2", 20.0, "user2"),
)
schema = "transaction_id: string, price: double, user_id: string"
df = spark.createDataFrame(data, schema=schema)
Now that we can create example dataframes, testing pyspark becomes just like testing any other function! We create a sample input and an expected output and compare the actual output with what we expect.
def test_transform():
    spark = SparkSession.builder.master("local").getOrCreate()

    sales = (
        ("transaction1", 10.0, "user1"),
        ("transaction2", 20.0, "user2"),
    )
    schema = "transaction_id: string, price: double, user_id: string"
    sales = spark.createDataFrame(sales, schema=schema)

    users = (
        ("user1", "John"),
        ("user2", "Rebeca"),
    )
    schema = "user_id: string, name: string"
    users = spark.createDataFrame(users, schema=schema)

    expected = (
        ("transaction1", 10.0, "user1", "John"),
        ("transaction2", 20.0, "user2", "Rebeca"),
    )
    schema = "transaction_id: string, price: double, user_id: string, name: string"
    expected = spark.createDataFrame(expected, schema=schema)

    got = transform(sales, users)
    assert_dataframes_equal(got, expected)
I hid a bit of dark magic inside assert_dataframes_equal(...) to keep the test simpler to read. We will come back to it after we cover a few other points.
Things to keep in mind
From the description above it may sound like testing pyspark is just like testing python. But if you’ve tried this before, you’ve probably realized there are a few caveats to pay attention to. Here are a couple of them.
Shuffle partitions
If you’ve been using spark for a while, you’ll know that every time you do a join or a group by, spark has to do a “shuffle” to place data with the same keys on the same machine. This is the famous “reduce” of the MapReduce paradigm. If you have a very big dataset and use the default of 200 shuffle partitions, each of the 200 pieces will still be very big and your data will spill to disk, which makes your job much slower. So when you change the number of shuffle partitions, it is usually to make it bigger.
In the case of unit tests it’s the opposite: your dataset has just a few lines. If you use the default, most of the partitions will be empty, but they will still be created and require coordination from the scheduler. This makes our tests run slower than they have to. To avoid this, we set the number of shuffle partitions to 1 for the tests.
spark.conf.set("spark.sql.shuffle.partitions", 1)
Even with this optimization, starting up a spark session still takes a few seconds, so your pyspark tests will take a bit longer than your normal python tests. :/
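One way to pay that startup cost only once is to share a single session across the whole test run, for example with a session-scoped pytest fixture. Below is a minimal sketch of that idea; the conftest.py file and the spark fixture name are my own choices rather than something from the scripts above:
# conftest.py (sketch): one local session, with a single shuffle partition,
# shared by every test in the run
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    session = (
        SparkSession.builder
        .master("local")
        .config("spark.sql.shuffle.partitions", "1")
        .getOrCreate()
    )
    yield session
    session.stop()
Any test that declares a spark argument then reuses the same session instead of building its own.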
Minimal test schema
Another point of friction is that datalake tables tend to be very wide to avoid joins. Writing out all of the columns for a test case can be very demotivating.
# sales schema
transaction_id sku price user_id date country
# users schema
user_id name e-mail favorite-color nick-name
# output schema
transaction_id sku price user_id date country name e-mail favorite-color nick-name
Our test case actually only uses the user_id column, so we can write our tests using only that column. Well, almost: we need one extra column from each table to be able to verify the join.
# sales schema
transaction_id user_id
# users schema
user_id name
# output schema
transaction_id user_id name
The output schema can be restricted to only the columns you want to check. So if another test checks whether the price was correctly converted into dollars, it no longer needs to include the user_id.
# output schema
transaction_id usd_price
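Putting this together, a minimal-column version of the join test could look like the sketch below. It assumes the spark fixture idea from above and the comparison helper discussed in the next section; the function name is just illustrative:
def test_transform_minimal(spark):
    # only the columns needed to exercise and verify the join
    sales = spark.createDataFrame(
        (("transaction1", "user1"),),
        schema="transaction_id: string, user_id: string",
    )
    users = spark.createDataFrame(
        (("user1", "John"),),
        schema="user_id: string, name: string",
    )
    expected = spark.createDataFrame(
        (("transaction1", "user1", "John"),),
        schema="transaction_id: string, user_id: string, name: string",
    )
    got = transform(sales, users)
    assert_dataframes_equal(got, expected)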
Compare function
Comparing spark dataframes can be a bit tricky. There are many ways two dataframes can differ: they can have different columns, different column types, a different column order, or they can simply differ in values. At the very least, we need any one of these differences to trigger an assertion failure. We can get that by ordering both dataframes and then collecting them.
def assert_dataframes_equal(df, df_ref):
    cols = df_ref.columns
    # order datasets
    actual_df = df.select(cols).orderBy(cols)
    expected_df = df_ref.orderBy(cols)
    if actual_df.collect() != expected_df.collect():
        assert False
    assert True
This function will serve its purpose, but when the assertion fails you will be left with a very unhelpful error message. That is bad for testing: if you need to push a new feature and you can’t figure out what is failing in the tests, it gets very frustrating.
I have slowly built up an assert function that gives helpful messages on failure. The full code is on github; I intend to turn it into a small library when I have the time. Here is a rough sketch of the extra steps for better messages.
def assert_dataframes_equal(df, df_ref, order_cols=None):
    cols = df_ref.columns
    missing_cols = set(cols) - set(df.columns)
    if missing_cols:
        print_diff_cols(df_ref, df)
        assert False
    assert_column_types(df_ref, df)
    if not order_cols:
        order_cols = cols
    # order datasets
    actual_df = df.select(cols).orderBy(order_cols)
    expected_df = df_ref.orderBy(order_cols)
    actual = actual_df.collect()
    expected = expected_df.collect()
    if actual != expected:
        print_full_diff(expected_df, actual_df)
        assert False
    assert True
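The print_diff_cols, assert_column_types and print_full_diff helpers live in the full code on github. Just to give an idea of their role, here is one possible shape for the first two; these are sketches, not the actual implementation:
def print_diff_cols(df_ref, df):
    # show which expected columns are missing from the actual dataframe
    print("expected columns:", df_ref.columns)
    print("actual columns:  ", df.columns)
    print("missing columns: ", set(df_ref.columns) - set(df.columns))


def assert_column_types(df_ref, df):
    # compare the declared type of every expected column
    expected_types = dict(df_ref.dtypes)
    actual_types = dict(df.dtypes)
    for col, expected_type in expected_types.items():
        got_type = actual_types.get(col)
        assert got_type == expected_type, (
            f"column {col!r}: expected type {expected_type}, got {got_type}"
        )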
Excuses for not writing pyspark tests
Finally, let’s go over a few excuses that might keep you from testing your code and try to debunk them real quick.
“I prefer sql, pyspark is too verbose”
For those of you who prefer to write in sql, here is a way to still do it:
def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    spark = SparkSession.getActiveSession()
    sales.createOrReplaceTempView("sales")
    users.createOrReplaceTempView("users")
    query = """
        select * from sales s left join users u on s.user_id = u.user_id
    """
    df = spark.sql(query)
    return df
I realize the irony that this is more verbose than the pure pyspark version. But hey, it’s in sql and now you can test it! :)
“It’s just a join”
I’ve heard this a lot: “it’s too simple to test”. Even in the simplest cases you’d be surprised how often there can be a bug. But a better argument for testing the simple cases is that they inevitably become more complex over time, and when they do, your test is already set up for you to extend.
In our sales example, we will soon get a request to convert different currencies into dollars or to discount country-specific taxes.
def transform(sales, users, exchange_rates, tax_rates):
    # exchange rates
    df = sales.join(exchange_rates, on=["date", "currency"], how="left")
    df = df.withColumn("usd_price", df.price * df.exchange_rate)
    # taxes
    df = df.join(tax_rates, on=["country"], how="left")
    df = df.withColumn("taxes", df.usd_price * df.tax_rate)
    # users
    df = df.join(users, on=["user_id"], how="left")
    return df
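And the test you already have grows with it. Here is a sketch of what a usd_price check might look like, with made-up values and only the columns the computation needs; it again assumes the spark fixture and comparison helper from earlier:
def test_transform_usd_price(spark):
    sales = spark.createDataFrame(
        (("transaction1", 10.0, "EUR", "2021-01-01", "FR", "user1"),),
        schema="transaction_id: string, price: double, currency: string,"
               " date: string, country: string, user_id: string",
    )
    users = spark.createDataFrame(
        (("user1", "John"),), schema="user_id: string, name: string"
    )
    exchange_rates = spark.createDataFrame(
        (("2021-01-01", "EUR", 2.0),),
        schema="date: string, currency: string, exchange_rate: double",
    )
    tax_rates = spark.createDataFrame(
        (("FR", 0.2),), schema="country: string, tax_rate: double"
    )
    # only the columns we care about in this test
    expected = spark.createDataFrame(
        (("transaction1", 20.0),),
        schema="transaction_id: string, usd_price: double",
    )
    got = transform(sales, users, exchange_rates, tax_rates)
    assert_dataframes_equal(got, expected)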
That’s it, happy testing!!!
Full code at: testing-pyspark