We all know unit testing our code is a good idea, but there is some kind of inertia that keeps most of us from doing it. Part of this inertia, I believe, is not knowing how to start. So I am sharing how I write tests for pyspark, in the hope that it will give you the push you need.

Testing in python

First things first: if you’ve never written a test for a normal python function, we need to get that out of the way. Here is a minimal example:

def add(x: int, y: int) -> int:
    return x + y


def test_add():
    inputs = {"x": 3, "y": 4}
    expected = 7
    got = add(**inputs)
    assert got == expected

Save it to my_test.py, then run pytest my_test.py. There you go! Now back to pyspark.

Basic Script pattern

Let’s see an example pyspark script for generating a denormalized sales table, that is, we’ll add some user information to each transaction line.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").getOrCreate()

sales = spark.read.csv("./sales.csv", sep="\t", header=True)
users = spark.read.csv("./user.csv", sep="\t", header=True)

df = sales.join(users, on=["user_id"], how="left")

df.write.mode("overwrite").csv("./full_sales.csv", sep="\t", header=True)

I believe that 50% of writing good tests is actually writing easily testable code. So let’s restructure the code above a bit.

The first thing I want to know when trying to understand a new spark script is what the inputs are and what the output is, so I like to keep that information right at the top.

TARGET = "./full_sales.csv"

SOURCES = {
  "sales": "./sales.csv",
  "users": "./user.csv",
}

Next, although getting the correct read and write parameters can be a bit tricky at first, it’s probably not where the bulk of the complexity of your spark job will be. For this reason I like to have a transform(...) function exclusively responsible for taking dataframes in and returning a dataframe out, without worrying about the format on disk.

def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    df = sales.join(users, on=["user_id"], how="left")
    return df

In this case our transform function is quite trivial, but we’ll discuss later on why I believe it’s still worth testing.

The full restructured script would look something like this:

from pyspark.sql import DataFrame, SparkSession

TARGET = "./full_sales.csv"

SOURCES = {
  "sales": "./sales.csv",
  "users": "./user.csv",
}

def main(spark, sources, target):
    inputs = {
        k: spark.read.csv(v, sep="\t", header=True)
        for k, v in sources.items()
    }

    df = transform(**inputs)

    df.write.mode("overwrite").csv(target, sep="\t", header=True)
    

def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    df = sales.join(users, on=["user_id"], how="left")
    return df

if __name__ == "__main__":
    spark = SparkSession.builder.master("local").getOrCreate()
    main(spark, SOURCES, TARGET)

With this we can easily swap inputs and outputs, and we can also test the transformation without having to worry about file formats.
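For example, a quick local smoke run (inside the __main__ block, say) only needs a different sources dict and target path. The sample file paths below are hypothetical:

sample_sources = {
  "sales": "./samples/sales_sample.csv",
  "users": "./samples/user_sample.csv",
}
main(spark, sample_sources, "./samples/full_sales_sample.csv")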

Testing in this pattern

So now that we’ve refactored the code, how can we write a unit test for it? We’ll rely on the fact that we can create a dataframe from a tuple of tuples, like this:

data = (
    ("transaction1", 10.0, "user1"),
    ("transaction2", 20.0, "user2"),
)
schema = "transaction_id: string,  price: double, user_id: string"
df = spark.createDataFrame(data, schema=schema)

Now that we can create example dataframes, testing pyspark becomes just like testing any other function! We create a sample input and an expected output and compare the actual output with what we expect.

def test_transform():
    spark = SparkSession.builder.master("local").getOrCreate()

    sales = (
        ("transaction1", 10.0, "user1"),
        ("transaction2", 20.0, "user2"),
    )
    schema = "transaction_id: string,  price: double, user_id: string"
    sales = spark.createDataFrame(sales, schema=schema)

    users = (
        ("user1", "Jonh"),
        ("user2", "Rebeca"),
    )
    schema = "user_id: string, name: string"
    users = spark.createDataFrame(users, schema=schema)

    expected = (
        ("transaction1", 10.0, "user1", "John"),
        ("transaction2", 20.0, "user2", "Rebeca"),
    )
    schema = "transaction_id: string, name: string"
    expected = spark.createDataFrame(expected, schema=schema)

    got = transform(sales, users)

    assert_dataframes_equal(got, expected)

I hid a bit of dark magic inside assert_dataframes_equal(...) to keep the test simple to read. We will come back to it after we cover a few other points.

Things to keep in mind

From the description above it might sound like testing pyspark is just like testing python. But if you’ve tried this before, you’ve probably realized there are a few caveats you should pay attention to. Here are a couple of them.

Shuffle partitions

If you’ve been using spark for a while, you’ll know that every time you do a join or a group by, spark has to do a “shuffle” to place rows with the same keys on the same machine. This is the famous “reduce” in the MapReduce paradigm. If you have a very big dataset and keep the default of 200 shuffle partitions, each of the 200 pieces will still be very big, your data will spill to disk, and your job will get much slower. So usually, when you change the number of shuffle partitions, you make it bigger.

In the case of unit tests it’s the opposite: your dataset has just a few lines. If you use the default, most of the partitions will be empty, but they will still be created and require coordination from the scheduler. This makes our tests run slower than they have to. To avoid this, we set the number of shuffle partitions to 1 for the tests.

spark.conf.set("spark.sql.shuffle.partitions", 1)

Even with this optimization, starting up a spark session still takes a few seconds, so your pyspark tests will take a bit longer than your normal python tests. :/
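One way to pay that startup cost only once per test run, and to set the shuffle partitions in a single place, is a session-scoped pytest fixture. This is just a sketch of how I would wire it up in a conftest.py, not something from the script above:

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    # one spark session shared by every test, with tiny shuffle partitions
    spark = (
        SparkSession.builder
        .master("local")
        .config("spark.sql.shuffle.partitions", 1)
        .getOrCreate()
    )
    yield spark
    spark.stop()

With this in place, test_transform can simply take spark as an argument instead of building its own session.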

Minimal test schema

Another point of friction is that datalake tables tend to be very wide to avoid joins. Writing out all of the columns for a test case can be very demotivating.

# sales schema
transaction_id	sku	price	user_id   date  country

# users schema
user_id	name  e-mail  favorite-color  nick-name

# output schema
transaction_id	sku	price	user_id   date  country name  e-mail  favorite-color  nick-name

Our join actually only uses the user_id column, so we could write our tests with just that column. Well, in practice we need one extra column from each table to be able to verify that the join worked.

# sales schema
transaction_id	user_id

# users schema
user_id	name

# output schema
transaction_id	user_id name

The output schema can be restricted to only the columns you want to check. So if, in another test, we want to check that the price was correctly converted into dollars, we no longer need to check the user_id.

# output schema
transaction_id	usd_price
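As a sketch (the dollar conversion itself only shows up later in this post, so the numbers and column names here are made up), the expected dataframe for that test would carry just those two columns, and the compare function in the next section only looks at the columns of the expected dataframe:

expected = (
    ("transaction1", 11.0),
    ("transaction2", 22.0),
)
schema = "transaction_id: string, usd_price: double"
expected = spark.createDataFrame(expected, schema=schema)

# got is whatever the transform under test returned
assert_dataframes_equal(got, expected)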

Compare function

Comparing spark dataframes can be a bit tricky. There are many ways two dataframes can differ: they can have different columns, different column types, or the same columns with different values. The very least we need is for any one of these differences to trigger an assert failure. We can get that by ordering both dataframes and then collecting them.

def assert_dataframes_equal(df, df_ref):
    cols = df_ref.columns

    # order both dataframes the same way, keeping only the reference columns
    actual_df = df.select(cols).orderBy(cols)
    expected_df = df_ref.orderBy(cols)

    assert actual_df.collect() == expected_df.collect()

This function will serve its purpose, but when the assertion fails you will be left with a very unhelpful error message. That is bad for testing: if you need to push a new feature and can’t figure out what is failing in the tests, it gets very frustrating.

I have slowly built up an assert function that gives helpful messages upon failure. The full code is on github; I intend to turn it into a small library when I have the time. Here is a rough sketch of the extra steps for better messages.

def assert_dataframes_equal(df, df_ref, order_cols=None):
    cols = df_ref.columns

    # print_diff_cols, assert_column_types and print_full_diff are
    # helpers from the full code linked at the end of the post
    missing_cols = set(cols) - set(df.columns)
    if missing_cols:
        print_diff_cols(df_ref, df)
        assert False

    assert_column_types(df_ref, df)

    if not order_cols:
        order_cols = cols

    # order datasets
    actual_df = df.select(cols).orderBy(order_cols)
    expected_df = df_ref.orderBy(order_cols)

    actual = actual_df.collect()
    expected = expected_df.collect()
    if actual != expected:
        print_full_diff(expected_df, actual_df)
        assert False
    assert True
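Just to give an idea of what those helpers do (the real implementations live in the repo linked at the end), print_diff_cols could be as simple as this sketch:

def print_diff_cols(df_ref, df):
    # show which expected columns are missing and which extra ones appeared
    expected_cols = set(df_ref.columns)
    actual_cols = set(df.columns)
    print("missing columns:", sorted(expected_cols - actual_cols))
    print("unexpected columns:", sorted(actual_cols - expected_cols))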

Excuses for not writing pyspark tests

Finally, let’s go over a few excuses that might keep you from testing your code and try to debunk them real quick.

“I prefer sql, pyspark is too verbose”

For those of you who prefer to write in sql, here is a way to still do it:

def transform(sales: DataFrame, users: DataFrame) -> DataFrame:
    sales.createOrReplaceTempView("sales")
    users.createOrReplaceTempView("users")
    query = """
      select * from sales s left join users u on s.user_id = u.user_id
    """
    # getOrCreate() hands back the session that is already running
    spark = SparkSession.builder.getOrCreate()
    df = spark.sql(query)
    return df

I realize the irony that this is more verbose than the pure pyspark version. But hey, it’s in sql and now you can test it! :)

“It’s just a join”

I’ve heard this a lot: “it’s too simple to test”. Even in the simplest cases you’d be surprised how often a bug sneaks in. But a better argument for testing the simple cases is that they inevitably become more complex over time, and when they do, your test is already set up for you to extend.

In our sales example, we will soon get a request to convert the different currencies into dollars or to deduct each country’s taxes.

def transform(sales, users, exchange_rates, tax_rates):
    # exchange rates
    df = sales.join(exchange_rates, on=['date', 'currency'], how="left")
    df = df.withColumn("usd_price", df.price * df.exchange_rate)

    # taxes
    df = df.join(tax_rates, on=['country'], how="left")
    df = df.withColumn("taxes", df.usd_price * df.tax_rate)

    # user information
    df = df.join(users, on=["user_id"], how="left")
    return df
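Extending the existing test is then mostly a matter of building two more tiny input dataframes. The columns below are assumptions matching the join keys above, and the sales fixture would also need date, currency and country columns now:

exchange_rates = (
    ("2024-01-01", "EUR", 1.1),
)
schema = "date: string, currency: string, exchange_rate: double"
exchange_rates = spark.createDataFrame(exchange_rates, schema=schema)

tax_rates = (
    ("BR", 0.2),
)
schema = "country: string, tax_rate: double"
tax_rates = spark.createDataFrame(tax_rates, schema=schema)

got = transform(sales, users, exchange_rates, tax_rates)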

That’s it, Happy testing !!!

Full code at: testing-pyspark