Creating a Randomly Sampled Working Dataset in Spark and Python from the Original Dataset

Arup Nanda
Published in Dev Genius · Jan 3, 2022

Learn how to create a statistically representative random sample from a large dataset to accelerate iterative development, reduce bias in ML training and more.

[Image: A slice of a pie. Food photo created by freepik, www.freepik.com (https://www.freepik.com/photos/food)]

Why PySpark

Spark with Python (PySpark) has become a very popular approach for data engineering systems. It combines the best of both worlds: Spark for fast in-memory processing (assuming you have the memory, of course), and the ubiquity of useful Python libraries along with the availability of developers well versed in the language. Add SQL, the de facto language of data processing tasks such as analysis, scrubbing and wrangling, and you have a decent data engineering system.

The Challenge of Large Datasets

While developing applications, you often need to work with real datasets from your organization, not made-up or irrelevant ones like those you might use while learning. But there is a catch: real datasets are often very large, and while you are writing programs you want a much smaller subset to make development faster. Perhaps you are developing on a much smaller server with less memory, in a cloud container, or even on your laptop, where memory is severely limited. A full-blown dataset will probably cause out-of-memory errors, especially with operations like collect() that bring all the data back to the driver. At the very least, it will slow down your iterative development.

What can you do? You could take a slice of the actual data file using a simple Unix tool like head, but that will make your data suspiciously skewed. The first n lines of the file may all belong to a single customer. When you then use them for tasks like identifying distribution patterns or training an ML model, you get misleading development results and an ineffective model, and potential bugs stay hidden. You need a smaller subset of the file, but it has to be randomly selected from the original, i.e. a statistically representative sample. In this article, you will learn how to create samples that represent a specific percentage of the actual data, selected randomly. You will also learn how to vary that percentage based on the actual distribution of the data to reduce skew even more.
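To make the head-versus-random difference concrete, here is a minimal plain-Python sketch (no Spark involved). It assumes a hypothetical export whose rows are grouped by customer, as many extracts are. The data and the helper names are made up for illustration.

```python
import random

# Hypothetical export grouped by customer: (customer_id, amount).
# Customer "C1" owns the first 500 rows; 500 other customers follow, one row each.
rows = [("C1", i) for i in range(500)] + [(f"C{n}", n) for n in range(2, 502)]

def head(data, n):
    """The equivalent of `head -n`: just the first n rows."""
    return data[:n]

def random_sample(data, n, seed=None):
    """n rows drawn uniformly at random, without replacement."""
    return random.Random(seed).sample(data, n)

first = head(rows, 100)
rand = random_sample(rows, 100, seed=42)

print(len({cust for cust, _ in first}))  # 1: the head sees a single customer
print(len({cust for cust, _ in rand}))   # many distinct customers
```

The head of the file describes exactly one customer, while the random pick spans many customers. The random pick is the kind of subset the rest of this article builds with Spark.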

Set up

You will need an actual data file. For the purpose of this article, let’s use a real, publicly available dataset: the New York Times COVID-19 case and death counts for every county in every state of the United States. You can download the file via the DOI in the citation below. Here is the citation for the dataset:

New York Times. Coronavirus (Covid-19) Data in the United States: Live Data: us counties.txt. Ann Arbor, MI: Inter-university Consortium for Political and Social Research [distributor], 2020-12-08. https://doi.org/10.3886/E128303V1-105700

Download the file and keep it in the folder you want to work in. At this point, I assume that:

  • You have a Spark cluster running (or even a local Spark installation on your laptop). I plan to write another article on how to install and use a local Spark installation later.
  • You have installed the PySpark package.

To make sure the installation is good, enter this on the command line:

pyspark

You should see output similar to this (your Spark and Python versions may, of course, be different):

Python 3.6.4 (v3.6.4:d48eceb, Dec 19 2017, 06:04:45) [MSC v.1900 32 bit (Intel)] on win32
Type "help", "copyright", "credits" or "license" for more information.
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.4.5
      /_/

Using Python version 3.6.4 (v3.6.4:d48eceb, Dec 19 2017 06:04:45)
SparkSession available as 'spark'.
>>>

The shell then waits at the familiar Python prompt (>>>). As the banner says, the variables we need are already defined:

sc: this is the SparkContext. We will not need to use it.

spark: this is the SparkSession, the entry point to Spark SQL and DataFrames. We will use it heavily. If you don’t see it in the output above, you can create it in the PySpark shell by executing:

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Arup').getOrCreate()

That’s it. Let’s get down to the meat of today’s objective.

Read the Data

First, we need to create a DataFrame from the CSV file we downloaded. The file has a header row, which Spark can map to the column names of the DataFrame. That takes care of the column names, but what about the data types? We could supply them explicitly, but we can also use a neat trick called schema inference, where Spark makes an extra pass over the data and guesses each column’s type from the actual values. Here is the code that reads the file into a DataFrame called df. We specify that the file has a header (which supplies the column names) and that the schema should be inferred.

>>> df = spark.read.options(header=True, inferSchema=True).csv("us-counties.txt")

Let’s confirm that we got all the records:

>>> df.count()
758243

Did it import all the column names and infer the data types? We can check the schema of the DataFrame.

>>> df.printSchema()
root
 |-- date: timestamp (nullable = true)
 |-- county: string (nullable = true)
 |-- state: string (nullable = true)
 |-- fips: integer (nullable = true)
 |-- cases: integer (nullable = true)
 |-- deaths: integer (nullable = true)

Looks good so far. Let’s check one record:

>>> df.first()
Row(date=datetime.datetime(2020, 1, 21, 0, 0), county='Snohomish', state='Washington', fips=53061, cases=1, deaths=0)

All looks good during the import process.
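Spark’s inference logic is considerably richer than this, but a toy version helps show the idea. For each column, it looks at the values and picks the narrowest type that fits all of them, skipping empty values (a column with empty values is nullable). The sketch below is plain Python, not Spark’s algorithm, and its type names simply mimic the printSchema() output above. The first data row mirrors the record returned by df.first(). The second row is made up to include an empty fips value.

```python
import csv
import io
from datetime import datetime

def is_int(value):
    try:
        int(value)
        return True
    except ValueError:
        return False

def is_timestamp(value):
    try:
        datetime.strptime(value, "%Y-%m-%d")
        return True
    except ValueError:
        return False

def infer_type(values):
    """Pick the narrowest toy type that fits every non-empty value.
    Empty values are skipped; in Spark they would simply be nulls."""
    present = [v for v in values if v != ""]
    if present and all(is_int(v) for v in present):
        return "integer"
    if present and all(is_timestamp(v) for v in present):
        return "timestamp"
    return "string"

# Header plus two rows in the layout of us-counties.txt.
# Row 1 mirrors df.first() above; row 2 is made up (note the empty fips).
sample_csv = (
    "date,county,state,fips,cases,deaths\n"
    "2020-01-21,Snohomish,Washington,53061,1,0\n"
    "2020-03-01,Unknown,Rhode Island,,2,0\n"
)

reader = csv.DictReader(io.StringIO(sample_csv))
rows = list(reader)
schema = {col: infer_type([row[col] for row in rows]) for col in reader.fieldnames}
print(schema)
# {'date': 'timestamp', 'county': 'string', 'state': 'string', 'fips': 'integer', 'cases': 'integer', 'deaths': 'integer'}
```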

Fixed Sampling

Now to create a sample from this DataFrame. Spark provides a function called sample() whose main argument is the fraction of the overall data to sample, a number between 0 and 1. Let’s use it. For now, let’s sample 0.001%, i.e. a fraction of 0.00001. Also, since we are just experimenting, we will not save the result to another DataFrame; we will just examine it.

>>> df.sample(.00001).show()
+-------------------+----------+--------------+-----+-----+------+
|               date|    county|         state| fips|cases|deaths|
+-------------------+----------+--------------+-----+-----+------+
|2020-04-13 00:00:00|       Bay|      Michigan|26017|   59|     2|
|2020-04-27 00:00:00|   Lubbock|         Texas|48303|  504|    43|
|2020-05-12 00:00:00|    Kimble|         Texas|48267|    1|     0|
|2020-05-24 00:00:00|    Thomas|      Nebraska|31171|    1|     0|
|2020-06-21 00:00:00|     Cedar|      Nebraska|31027|    9|     0|
|2020-07-02 00:00:00|Los Alamos|    New Mexico|35028|    8|     0|
|2020-07-13 00:00:00|  Montrose|      Colorado| 8085|  218|    12|
|2020-09-13 00:00:00|   Mathews|      Virginia|51115|   23|     0|
|2020-11-02 00:00:00|     Surry|North Carolina|37171| 2009|    33|
+-------------------+----------+--------------+-----+-----+------+

We got 9 records, which is roughly 0.001% of the 758,243 records in the file (exactly 0.001% would be about 7.6 records). Let’s run that command one more time:

>>> df.sample(.00001).show()
+-------------------+-------------------+-------------+-----+-----+------+
|               date|             county|        state| fips|cases|deaths|
+-------------------+-------------------+-------------+-----+-----+------+
|2020-03-30 00:00:00|              Vilas|    Wisconsin|55125|    3|     0|
|2020-04-20 00:00:00|             Coffee|      Alabama| 1031|   64|     0|
|2020-05-16 00:00:00|               Clay|West Virginia|54015|    2|     0|
|2020-06-12 00:00:00|          Christian|     Missouri|29043|   29|     0|
|2020-07-22 00:00:00|               Wood|West Virginia|54107|  216|     3|
|2020-07-28 00:00:00|            El Paso|        Texas|48141|13552|   239|
|2020-08-16 00:00:00|             Zapata|        Texas|48505|  203|     3|
|2020-09-11 00:00:00|            Chester| Pennsylvania|42029| 6187|   362|
|2020-09-30 00:00:00|        Roger Mills|     Oklahoma|40129|   59|     1|
|2020-10-16 00:00:00|           McLennan|        Texas|48309| 9290|   134|
|2020-11-02 00:00:00|Virginia Beach city|     Virginia|51810| 8248|   107|
+-------------------+-------------------+-------------+-----+-----+------+

We got 11 records, and also different records, which was expected. Sampling is meant to be random; so each pull will select different records. Let’s give that another go:

>>> df.sample(.00001).show()
+-------------------+---------+------------+-----+-----+------+
|               date|   county|       state| fips|cases|deaths|
+-------------------+---------+------------+-----+-----+------+
|2020-07-02 00:00:00| Paulding|     Georgia|13223|  638|    16|
|2020-07-17 00:00:00| Doña Ana|  New Mexico|35013| 1656|    12|
|2020-08-22 00:00:00|Mendocino|  California| 6045|  624|    16|
|2020-09-01 00:00:00| Brantley|     Georgia|13025|  310|     8|
|2020-10-04 00:00:00|Lancaster|Pennsylvania|42071| 8195|   459|
|2020-11-21 00:00:00|    Grant|    Kentucky|21081|  540|     6|
+-------------------+---------+------------+-----+-----+------+

We got only 6 records this time, and again different ones from the last two runs. The counts differ because the fraction is not an exact quota. Each row is included independently with a probability equal to the fraction, so the number of rows returned is only in the vicinity of 0.001% of the original, not exactly that.
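Here is a rough way to see why the counts wobble. Suppose each row is kept independently with probability p (Bernoulli sampling, the idea behind sample() without replacement). Then the number of rows returned follows a binomial distribution with mean n × p and standard deviation sqrt(n × p × (1 - p)). The plain-Python sketch below computes both for this file and simulates a few unseeded draws. The simulation is an illustration, not Spark’s implementation.

```python
import math
import random

def bernoulli_sample(rows, fraction, rng):
    """Keep each row independently with probability `fraction`.
    Models the idea behind sample(withReplacement=False); not Spark's code."""
    return [r for r in rows if rng.random() < fraction]

n = 758243          # rows in us-counties.txt, from df.count()
fraction = 0.00001  # 0.001%

# Expected size and spread of the sample (binomial distribution).
expected = n * fraction
std_dev = math.sqrt(n * fraction * (1 - fraction))
print(f"expected {expected:.2f} rows, std dev {std_dev:.2f}")  # expected 7.58 rows, std dev 2.75

# Three unseeded runs: three (usually) different sizes scattered around the mean.
rng = random.Random()
sizes = [len(bernoulli_sample(range(n), fraction, rng)) for _ in range(3)]
print(sizes)
```

With a mean of about 7.6 and a standard deviation of about 2.75, the 9, 11 and 6 records returned above are all well within the normal spread.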

Seeding

Getting different records, and a different number of them, on each call is a natural property of randomness, and it is probably fine for many development cases. In other cases, however, you want a sample that is random the first time but does not change with subsequent calls, i.e. a predictable sample. How can you achieve that?

Fortunately, the function accepts an optional parameter called seed that makes the output reproducible. This is how you use it. In this example, I have used 9876543210 as the seed.

>>> df.sample(.00001, 9876543210).show()
+-------------------+----------+-----------+-----+-----+------+
|               date|    county|      state| fips|cases|deaths|
+-------------------+----------+-----------+-----+-----+------+
|2020-07-24 00:00:00|  Delaware|    Indiana|18035|  541|    53|
|2020-08-04 00:00:00|    Blaine|      Idaho|16013|  571|     6|
|2020-08-10 00:00:00|  Auglaize|       Ohio|39011|  256|     6|
|2020-09-14 00:00:00|Deer Lodge|    Montana|30023|   87|     0|
|2020-09-28 00:00:00|  Campbell|   Virginia|51031|  453|     4|
|2020-10-19 00:00:00|      Clay|Mississippi|28025|  672|    21|
+-------------------+----------+-----------+-----+-----+------+

Running it another time:

>>> df.sample(.00001, 9876543210).show()
+-------------------+----------+-----------+-----+-----+------+
|               date|    county|      state| fips|cases|deaths|
+-------------------+----------+-----------+-----+-----+------+
|2020-07-24 00:00:00|  Delaware|    Indiana|18035|  541|    53|
|2020-08-04 00:00:00|    Blaine|      Idaho|16013|  571|     6|
|2020-08-10 00:00:00|  Auglaize|       Ohio|39011|  256|     6|
|2020-09-14 00:00:00|Deer Lodge|    Montana|30023|   87|     0|
|2020-09-28 00:00:00|  Campbell|   Virginia|51031|  453|     4|
|2020-10-19 00:00:00|      Clay|Mississippi|28025|  672|    21|
+-------------------+----------+-----------+-----+-----+------+

Each time the same sample came out. What if you want a different sample? Simply use a different seed, or omit the seed altogether.

>>> df.sample(.00001, 8765432109).show()
+-------------------+--------+------------+-----+-----+------+
|               date|  county|       state| fips|cases|deaths|
+-------------------+--------+------------+-----+-----+------+
|2020-03-27 00:00:00|Franklin|  Washington|53021|   12|     0|
|2020-05-08 00:00:00|   Pasco|     Florida|12101|  291|     9|
|2020-05-10 00:00:00|Mariposa|  California| 6043|   15|     0|
|2020-06-28 00:00:00|   Gregg|       Texas|48183|  340|    14|
|2020-09-02 00:00:00|  Benson|North Dakota|38005|  239|     4|
|2020-10-24 00:00:00| Steuben|    New York|36101|  951|    69|
+-------------------+--------+------------+-----+-----+------+
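The mechanism behind the seed is an ordinary seeded random number generator: feed it the same seed and it produces the same sequence, so the same rows pass the test. A minimal plain-Python sketch of that idea follows (the function name is made up). One caution goes beyond this sketch. In Spark, the picks also depend on how the data is split into partitions, so you should only count on the same seed reproducing the same sample while the input data and its partitioning stay the same.

```python
import random

def seeded_sample(rows, fraction, seed):
    """Bernoulli-sample `rows` with a private generator seeded once per call,
    so the same seed, data and row order always yield the same picks."""
    rng = random.Random(seed)
    return [r for r in rows if rng.random() < fraction]

rows = list(range(100_000))
a = seeded_sample(rows, 0.001, seed=9876543210)
b = seeded_sample(rows, 0.001, seed=9876543210)
c = seeded_sample(rows, 0.001, seed=8765432109)

print(a == b)  # True: same seed, same data, same order
print(a == c)  # different seed, so a different sample
```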

Duplicates Allowed?

While sampling, sample() considers each row once and either keeps it or drops it, so by default the same row never appears twice in the result. That might not be what you want, and the function has another parameter to change it. Note that this is about rows, not values. Even without replacement, different rows that share a value, e.g. “Texas” in the state column, can all be picked, so a frequent value still comes up often.

To allow the same row to be picked more than once, set the first parameter, withReplacement, to True (it is False by default). Spark then draws the number of copies of each row from a Poisson distribution, so a row can appear zero, one or more times. Here is an example. I used 1 as a seed so that I get the same sample each time.

>>> df.select("state").sample(True, .00001, seed=1).show()
+----------+
|     state|
+----------+
| Wisconsin|
|   Vermont|
|     Texas|
|Washington|
|   Georgia|
|New Mexico|
| Tennessee|
| Louisiana|
|    Hawaii|
|   Georgia|
| Wisconsin|
|   Wyoming|
+----------+

You can see that some values, e.g. “Wisconsin” and “Georgia”, appear twice. At a fraction this small, these are almost certainly different rows that happen to share a state, which can happen with or without replacement. The chance of the very same row being drawn twice here is negligible. Now let’s reissue the same command with the parameter set to False, or omit it altogether, since False is the default.

>>> df.select("state").sample(False, .00001, seed=1).show()
+----------+
|     state|
+----------+
|  Illinois|
|   Georgia|
|Washington|
|  Virginia|
| Minnesota|
+----------+

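This time the five states happen to be unique, but that is a coincidence of this particular sample. Without replacement, no row is picked twice, yet different rows with the same state can still both be picked.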
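The difference between the two modes is easier to see with a large fraction. The plain-Python sketch below models sampling without replacement as a keep-or-drop coin flip per row. It models sampling with replacement as drawing a Poisson-distributed number of copies per row, which is the approach Spark takes; the code itself is an illustration, not Spark’s. The rows are hypothetical (row_id, state) pairs, and a fraction of 0.5 is used so that repeats actually occur.

```python
import math
import random

def poisson(lam, rng):
    """Draw from Poisson(lam) using Knuth's method (fine for small lam)."""
    limit = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1

def sample_without_replacement(rows, fraction, rng):
    """Each row is kept at most once, with probability `fraction` (Bernoulli)."""
    return [r for r in rows if rng.random() < fraction]

def sample_with_replacement(rows, fraction, rng):
    """Each row appears k times, k ~ Poisson(fraction), so a row can repeat."""
    out = []
    for r in rows:
        out.extend([r] * poisson(fraction, rng))
    return out

# Hypothetical rows: (row_id, state). Many distinct rows share the value "Texas".
rows = [(i, "Texas") for i in range(900)] + [(i, "Guam") for i in range(900, 1000)]
rng = random.Random(1)

without = sample_without_replacement(rows, 0.5, rng)
with_repl = sample_with_replacement(rows, 0.5, rng)

ids_without = [row_id for row_id, _ in without]
ids_with = [row_id for row_id, _ in with_repl]
print(len(ids_without) == len(set(ids_without)))  # True: no row is picked twice
print(len(ids_with) - len(set(ids_with)))          # number of extra copies of repeated rows
print(sum(1 for _, s in without if s == "Texas"))  # "Texas" still appears many times
```

Without replacement, row IDs never repeat, yet the value “Texas” shows up many times because many different rows carry it. With replacement, some row IDs repeat. At the fraction of 0.00001 used earlier, the same row being drawn twice is vanishingly unlikely, so the repeated states in that output come from different rows.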

Reducing Skew in Sampling

Now that you understand how sampling works, you probably realize that the sampling percentage I used (0.001%) will be applied to all the cases. It means the values with a higher frequency of occurrence will be picked up more often. If that is what you want, you don’t need to go any further; but sometimes you may want to reduce the skewing effect of often repeated data elements. Allow me to explain with an example.

First, let’s get a count of records for each state. Note that show() prints only the first 20 rows by default, so pass a larger number to see every state:

>>> from pyspark.sql.functions import desc
>>> df.groupBy("state").count().alias("cnt").sort(desc("cnt.count")).show(100)
+--------------------+-----+
|               state|count|
+--------------------+-----+
|               Texas|57096|
|             Georgia|38969|
|            Virginia|31699|
|            Kentucky|28132|
|            Missouri|26770|
|            Illinois|24258|
|      North Carolina|24108|
... output truncated for brevity ...
|       Massachusetts| 3859|
|             Arizona| 3813|
|             Vermont| 3703|
|              Nevada| 3678|
|       New Hampshire| 2705|
|         Connecticut| 2258|
|        Rhode Island| 1482|
|              Hawaii| 1050|
|            Delaware|  983|
|      Virgin Islands|  824|
|Northern Mariana ...|  410|
|District of Columbia|  261|
|                Guam|  253|
+--------------------+-----+

As you can see, some states have a lot of records, e.g. Texas with 57,096, while Guam has only 253. If you sample at a fixed 0.001% rate, the chances of Guam being picked are slim; in fact, it will most likely not come up at all. Conversely, states with many records, like Texas and Georgia, will come up much more often. This may skew your test results or the training data for your ML models.

To address this, you may want to use stratified sampling. Instead of a single sampling fraction applied to all records, you apply different fractions to different records based on the state. For the sake of simplicity, let’s plan this:

  • For states with more than 16,000 records, the rate is 0.001%.
  • For states with 16,000 records or fewer, the rate is 0.01%.
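To put numbers on the skew, here is a small plain-Python calculation (no Spark needed) using four counts from the table above. Under Bernoulli sampling, a state with n rows sampled at fraction p contributes n × p rows on average. It shows up at least once with probability 1 - (1 - p)^n. The two_tier function is simply the plan above written as an ordinary function; its name is made up.

```python
# Row counts copied from the groupBy output above (a subset of the states).
counts = {"Texas": 57096, "Georgia": 38969, "Rhode Island": 1482, "Guam": 253}

UNIFORM = 0.00001  # 0.001% for every state

def two_tier(count):
    """The plan above: 0.001% above 16,000 rows, 0.01% otherwise."""
    return 0.00001 if count > 16000 else 0.0001

def expected_rows(count, fraction):
    """Average number of this state's rows in a Bernoulli sample."""
    return count * fraction

def p_at_least_one(count, fraction):
    """Probability that the state appears in the sample at all."""
    return 1 - (1 - fraction) ** count

for state, n in counts.items():
    u, t = UNIFORM, two_tier(n)
    print(f"{state:13} uniform: {expected_rows(n, u):.4f} rows, "
          f"P(seen)={p_at_least_one(n, u):.4f} | "
          f"two-tier: {expected_rows(n, t):.4f} rows, "
          f"P(seen)={p_at_least_one(n, t):.4f}")
```

Guam’s expected share rises tenfold, from about 0.0025 to about 0.025 rows per sample, while Texas stays at about 0.57. The two-tier plan narrows the gap rather than closing it, and you can tune the rates accordingly.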

Fortunately, Spark provides another function, sampleBy(), that allows us to do this. It takes two main parameters:

  • Column name: the column whose value determines the sampling fraction. In this case, it is “state”.
  • Fractions: a dict that maps each value of that column to the fraction to apply. For instance, if it is set to this:
{
  'Texas': 0.00001,
  'Guam': 0.0001,
  ... and so on ...
}

then each row is sampled at the fraction that corresponds to its value in the state column. Values missing from the dict are treated as having a fraction of zero, so rows with those values never make it into the sample.

There is also a third, optional parameter, seed, which works just like the seed parameter of sample().
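Conceptually, sampleBy() is the same keep-or-drop coin flip as sample(), except that the probability comes from the fractions dict, looked up by the row’s value in the chosen column. Here is a plain-Python sketch of that behavior, including the rule that values missing from the dict get a fraction of zero. The function name sample_by and the rows are hypothetical, made up for illustration.

```python
import random

def sample_by(rows, key, fractions, seed=None):
    """Stratified Bernoulli sampling: keep each row with probability
    fractions[row[key]]; values missing from `fractions` get 0.0,
    mirroring how sampleBy() treats unlisted values."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions.get(row[key], 0.0)]

# Hypothetical rows, each a dict standing in for a DataFrame Row.
rows = ([{"state": "Texas", "cases": i} for i in range(10_000)]
        + [{"state": "Guam", "cases": i} for i in range(100)]
        + [{"state": "Ohio", "cases": i} for i in range(1_000)])

fractions = {"Texas": 0.01, "Guam": 0.5}  # Ohio deliberately left out
picked = sample_by(rows, "state", fractions, seed=1)

for state in ("Texas", "Guam", "Ohio"):
    print(state, sum(1 for r in picked if r["state"] == state))
```

Texas contributes roughly 1% of its rows and Guam roughly half, while Ohio contributes nothing because it is missing from the dict.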

However, constructing the dict for the fractions parameter by hand is impractical. There are dozens of states and territories, and the distribution may change as new data arrives, so the fractions should be derived from the actual data. Here is how we do it. First, let’s get the counts into a new DataFrame:

df1 = df.groupBy("state").count().alias("cnt").sort(desc("cnt.count"))

Next, we will use lit(), which creates a column holding a literal (constant) value, together with when() and otherwise(), which choose a value based on a condition. Let’s import them along with col():

from pyspark.sql.functions import lit, when, col

Using these, we add the sampling fraction for each state and store the result in another DataFrame. The fraction is stored in a column named “stratum”.

df2 = df1.withColumn("stratum",
    when(col("count") > 16000, lit(0.00001))
    .otherwise(lit(0.0001))
)

Now, we create a dict object called fractions that maps each state to its sampling fraction:

fractions = df2.select("state", "stratum").rdd.collectAsMap()
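If it helps to see the logic without Spark, the when()/otherwise() rule plus collectAsMap() boils down to a dict comprehension over (state, count) pairs. Here is a plain-Python sketch, using a handful of counts copied from the table above rather than the full list.

```python
# (state, count) pairs as df.groupBy("state").count() would return them
# (a few values copied from the output above).
state_counts = [("Texas", 57096), ("Georgia", 38969), ("Massachusetts", 3859), ("Guam", 253)]

# Same rule as the when()/otherwise() expression: more than 16,000 rows -> 0.00001.
fractions = {state: (0.00001 if count > 16000 else 0.0001)
             for state, count in state_counts}
print(fractions)
# {'Texas': 1e-05, 'Georgia': 1e-05, 'Massachusetts': 0.0001, 'Guam': 0.0001}
```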

Let’s peek into the dict variable:

>>> print(fractions)
{'Texas': 1e-05, 'Georgia': 1e-05, 'Virginia': 1e-05, 'Kentucky': 1e-05, 'Missouri': 1e-05, 'Illinois': 1e-05, 'North Carolina': 1e-05, 'Iowa': 1e-05, 'Tennessee': 1e-05, 'Kansas': 1e-05, 'Indiana': 1e-05, 'Ohio': 1e-05, 'Minnesota': 1e-05, 'Michigan': 1e-05, 'Mississippi': 1e-05, 'Nebraska': 1e-05, 'Arkansas': 1e-05, 'Oklahoma': 1e-05, 'Wisconsin': 1e-05, 'Florida': 1e-05, 'Pennsylvania': 1e-05, 'Alabama': 1e-05, 'Louisiana': 0.0001, 'Puerto Rico': 0.0001, 'Colorado': 0.0001, 'New York': 0.0001, 'California': 0.0001, 'South Dakota': 0.0001, 'West Virginia': 0.0001, 'North Dakota': 0.0001, 'South Carolina': 0.0001, 'Montana': 0.0001, 'Washington': 0.0001, 'Idaho': 0.0001, 'Oregon': 0.0001, 'New Mexico': 0.0001, 'Utah': 0.0001, 'Maryland': 0.0001, 'New Jersey': 0.0001, 'Alaska': 0.0001, 'Wyoming': 0.0001, 'Maine': 0.0001, 'Massachusetts': 0.0001, 'Arizona': 0.0001, 'Vermont': 0.0001, 'Nevada': 0.0001, 'New Hampshire': 0.0001, 'Connecticut': 0.0001, 'Rhode Island': 0.0001, 'Hawaii': 0.0001, 'Delaware': 0.0001, 'Virgin Islands': 0.0001, 'Northern Mariana Islands': 0.0001, 'District of Columbia': 0.0001, 'Guam': 0.0001}

This is exactly what we need to pass to the sampleBy() function. It isn’t very readable, though; a simple loop prints one state per line:

>>> for i in fractions:
...     print(i, fractions[i])
...
Washington 0.0001
Illinois 1e-05
California 1e-05
Arizona 0.0001
Massachusetts 0.0001
… output truncated …

Now, let’s sample the dataset:

sampleDF = df.sampleBy("state", fractions, seed=1)
sampleDF.show(1000)
+-------------------+--------------------+-------------+-----+-----+------+
|               date|              county|        state| fips|cases|deaths|
+-------------------+--------------------+-------------+-----+-----+------+
|2020-03-24 00:00:00|                Yuma|      Arizona| 4027|    2|     0|
|2020-03-25 00:00:00|             Webster|    Louisiana|22119|    5|     1|
|2020-04-11 00:00:00|             Orleans|    Louisiana|22071| 5535|   232|
|2020-05-09 00:00:00|         Musselshell|      Montana|30065|    1|     0|
|2020-05-10 00:00:00|              Vernon|    Louisiana|22115|   18|     2|
|2020-05-20 00:00:00|               Glenn|   California| 6021|   12|     0|
|2020-05-23 00:00:00|         Carson City|       Nevada|32510|   83|     4|
|2020-06-01 00:00:00|Dillingham Census...|       Alaska| 2070|    1|     0|
… output truncated …

And that’s it; your working dataset is sampleDF, which has been carefully sampled with varying sampling ratios based on the actual data distribution, which will reduce the skew effect. This is merely a simple example of how you can apply just two ratios. You can modify the code that creates the df2 DataFrame to add as many ratios as you need. Ultimately the objective is to create the dictionary object to pass to the sampleBy() function.
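For more than two ratios, the PySpark expression can chain additional when() calls before otherwise(). The tier lookup itself is sketched below in plain Python. The tier boundaries and rates are hypothetical, chosen only to show the pattern. The resulting dict has the same shape as the fractions dict passed to sampleBy().

```python
import bisect

# Hypothetical tiers: (upper bound on row count, sampling fraction), ascending.
# Up to 1,000 rows -> 0.001; up to 16,000 -> 0.0001; anything larger -> 0.00001.
TIERS = [(1_000, 0.001), (16_000, 0.0001), (float("inf"), 0.00001)]

def tiered_fraction(count, tiers=TIERS):
    """Return the fraction of the first tier whose upper bound is >= count."""
    bounds = [upper for upper, _ in tiers]
    return tiers[bisect.bisect_left(bounds, count)][1]

# A few counts copied from the groupBy output above.
state_counts = {"Texas": 57096, "Massachusetts": 3859, "Delaware": 983, "Guam": 253}
fractions = {state: tiered_fraction(n) for state, n in state_counts.items()}
print(fractions)
```

Keeping the tiers in a list makes it easy to add or adjust tiers without touching the lookup logic.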

Takeaways

Here is a quick summary of the article.

  1. Spark provides a function called sample() that pulls a random sample of data from the original file. The sampling fraction is the same for all records. Because the selection is uniformly random, frequently occurring values show up more often in the sample, skewing the data.
  2. Spark provides another function called sampleBy() that also pulls a random sample, but the sampling fraction can depend on the value of a column. This lets you sample rare values at a higher rate and frequent values at a lower rate, reducing skew in the resulting sample.

Further Reads

You may read more about the sampleBy() function in the official Apache Spark documentation at https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.DataFrameStatFunctions.sampleBy.html. However, I should warn you that the documentation is very sparse. The parent documentation page, though, shows the extended syntax of many of the commands used in this article, which may be helpful.


Award winning data/analytics/ML and engineering leader, raspberry pi junkie, dad and husband.