Dask for Parallel Computing in Python¶
In past lectures, we learned how to use numpy, pandas, and xarray to analyze various types of geoscience data. In this lecture, we address an increasingly common problem: what happens if the data we wish to analyze is "big data"?
Aside: What is "Big Data"?¶
There is a lot of hype around the buzzword "big data" today. Some people may associate "big data" with specific software platforms (e.g. "Hadoop", "Spark"), while, for others, "big data" means specific machine learning techniques. But I think Wikipedia's definition is the most useful:
Big data is data sets that are so voluminous and complex that traditional data processing application software are inadequate to deal with them.
By this definition, a great many datasets we regularly confront in Earth science are big data.
A good threshold for when data becomes difficult to deal with is when the volume of data exceeds your computer's RAM. Most modern laptops have between 2 and 16 GB of RAM. High-end workstations and servers can have 1 TB (1000 GB) of RAM or more. If the dataset you are trying to analyze can't fit in your computer's memory, some special care is required to carry out the analysis. Data that can't fit in RAM but can fit on your hard drive is sometimes called "medium data."
The next threshold of difficulty is when the data can't fit on your hard drive. Most modern laptops have between 100 GB and 4 TB of storage space on the hard drive. If you can't fit your dataset on your internal hard drive, you can buy an external hard drive. However, at that point you are better off using a high-end server, HPC system, or cloud-based storage for your dataset. Once you have many TB of data to analyze, you are definitely in the realm of "big data".
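If you are not sure where your dataset falls, you can check your machine's RAM and disk capacity directly from Python. Here is a quick sketch (the psutil package is a third-party dependency, and the numbers are just whatever your system reports):

import shutil
import psutil  # third-party package for querying system resources

# total system RAM in GB
print(psutil.virtual_memory().total / 1e9)

# total and free disk space in GB on the current drive
disk = shutil.disk_usage('/')
print(disk.total / 1e9, disk.free / 1e9)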
What is Dask?¶
Dask is a tool that helps us easily extend our familiar python data analysis tools to medium and big data, i.e. datasets that can't fit in our computer's RAM. In many cases, dask also allows us to speed up our analysis by using multiple CPU cores. Dask can help us work more efficiently on our laptop, and it can also help us scale up our analysis on HPC and cloud platforms. Most importantly, dask is almost invisible to the user, meaning that you can focus on your science, rather than the details of parallel computing.
Dask was created by the brilliant Matt Rocklin. You can learn more about it in the Dask documentation (https://docs.dask.org).
Dask provides collections for big data and a scheduler for parallel computing. It is probably easiest to illustrate what these mean through examples, so we will jump right in.
Dask Arrays¶
A dask array looks and feels a lot like a numpy array. However, a dask array doesn't directly hold any data. Instead, it symbolically represents the computations needed to generate the data. Nothing is actually computed until the actual numerical values are needed. This mode of operation is called "lazy"; it allows one to build up complex, large calculations symbolically before turning them over to the scheduler for execution.
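As a minimal illustration of lazy evaluation, here is a sketch using dask.delayed, which wraps ordinary Python functions so that calling them builds a task graph instead of running immediately:

from dask import delayed

@delayed
def increment(x):
    # this body does not run when the function is called
    return x + 1

lazy_result = increment(increment(1))  # builds a graph of two tasks
lazy_result.compute()                  # only now is anything calculated

Dask arrays behave the same way, as we will now see.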
If we want to create a numpy array of all ones, we do it like this:
import numpy as np
shape = (1000, 4000)
ones_np = np.ones(shape)
ones_np
This array contains exactly 32 MB of data:
ones_np.nbytes / 1e6
Now let's create the same array using dask's array interface.
import dask.array as da
ones = da.ones(shape)
This did not work, because we didn't tell dask how to split up the array.
A crucial difference with dask is that we must specify the chunks argument. "Chunks" describes how the array is split up over many sub-arrays.
source: Dask Array Documentation
There are several ways to specify chunks. In this lecture, we will use a block shape.
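For reference, here is a sketch of some other ways to specify chunks; the string forms assume a reasonably recent version of dask:

# let dask pick the chunk shape automatically
ones_auto = da.ones(shape, chunks='auto')

# ask for chunks of roughly a given size in bytes
ones_16mb = da.ones(shape, chunks='16 MiB')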
chunk_shape = (1000, 1000)
ones = da.ones(shape, chunks=chunk_shape)
ones
Notice that we just see a symbolic representation of the array, including its shape, dtype, and chunksize.
No data has been generated yet.
When we call .compute() on a dask array, the computation is triggered and the dask array becomes a numpy array.
ones.compute()
In order to understand what happened when we called .compute(), we can visualize the dask graph, the symbolic operations that make up the array.
ones.visualize()
Our array has four chunks. To generate it, dask calls np.ones four times and then concatenates the results together into one array.
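We can confirm this chunk layout directly from the array's attributes:

print(ones.chunks)     # chunk sizes along each dimension
print(ones.numblocks)  # number of blocks along each dimension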
Rather than immediately loading a dask array (which puts all the data into RAM), it is more common to want to reduce the data somehow. For example:
sum_of_ones = ones.sum()
sum_of_ones.visualize()
Here we see dask's strategy for finding the sum. This simple example illustrates the beauty of dask: it automatically designs an algorithm appropriate for custom operations with big data.
If we make our operation more complex, the graph gets more complex.
fancy_calculation = (ones * ones[::-1, ::-1]).mean()
fancy_calculation.visualize()
A Bigger Calculation¶
The examples above were toy examples; the data (32 MB) is nowhere near big enough to warrant the use of dask.
We can make it a lot bigger!
bigshape = (200000, 4000)
big_ones = da.ones(bigshape, chunks=chunk_shape)
big_ones
big_ones.nbytes / 1e6
This dataset is 6.4 GB, rather than 32 MB! This is probably close to or greater than the amount of available RAM in your computer. Nevertheless, dask has no problem working on it.
Do not try to .visualize()
this array!
When doing a big calculation, dask also has some tools to help us understand what is happening under the hood.
from dask.diagnostics import ProgressBar
big_calc = (big_ones * big_ones[::-1, ::-1]).mean()
with ProgressBar():
result = big_calc.compute()
result
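For a more detailed breakdown, the dask.diagnostics module also provides profilers. Here is a sketch; the interactive plot assumes the bokeh package is installed:

from dask.diagnostics import Profiler, ResourceProfiler, visualize

# record which tasks ran and how much CPU and memory they used
with Profiler() as prof, ResourceProfiler() as rprof:
    big_calc.compute()

# render an interactive timeline of the computation
visualize([prof, rprof])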
Reduction¶
All the usual numpy methods work on dask arrays. You can also apply numpy functions directly to a dask array, and it will stay lazy.
big_ones_reduce = (np.cos(big_ones)**2).mean(axis=0)
big_ones_reduce
Plotting also triggers computation, since we need the actual values.
from matplotlib import pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (12,8)
plt.plot(big_ones_reduce)
Distributed Cluster¶
For a fancier visualization of what dask is doing, we can use the distributed scheduler.
from dask_kubernetes import KubeCluster
cluster = KubeCluster(n_workers=4)
cluster
from dask.distributed import Client
client = Client(cluster)
client
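If you don't have access to a Kubernetes cluster, the same distributed scheduler (and its dashboard) can run on a single machine. A sketch using LocalCluster:

from dask.distributed import Client, LocalCluster

# start a scheduler and four worker processes on this machine
local_cluster = LocalCluster(n_workers=4)
local_client = Client(local_cluster)
local_client

The client's repr includes a link to the dashboard, where you can watch tasks execute in real time.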
big_calc.compute()
# one dask array of 200 million normally-distributed random values
random_values = da.random.normal(size=(200_000_000,), chunks=(1_000_000,))

# a lazy histogram of the values
hist, bins = da.histogram(random_values, bins=100, range=[-5, 5])
hist

# bin centers and widths for plotting; the bar plot converts the dask
# array to numpy, which triggers the computation
x = 0.5 * (bins[1:] + bins[:-1])
width = np.diff(bins)
plt.bar(x, hist, width);
Dask + XArray¶
Xarray can automatically wrap its data in dask arrays. This capability turns xarray into an extremely powerful tool for Big Data Earth science.
To see this in action, we will download a fairly large dataset to analyze. This file contains 1 year of daily data from the AVISO sea-surface height satellite altimetry dataset.
! wget http://www.ldeo.columbia.edu/~rpa/aviso_madt_2015.tar.gz
! tar -xvzf aviso_madt_2015.tar.gz
! ls 2015 | wc -l
Let's load the first file as a regular xarray dataset.
import xarray as xr
ds_first = xr.open_dataset('2015/dt_global_allsat_madt_h_20150101_20150914.nc')
ds_first
ds_first.nbytes / 1e6
This one file is about 8 MB. So 365 of them will be nearly 3 GB. If we had downloaded all 25 years of data, it would be 73 GB. This is a good example of "medium data."
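Even a single file can be wrapped in dask arrays by passing a chunks argument to open_dataset. A sketch (the chunk size along time is arbitrary here):

# open the same file lazily; the data are stored as dask arrays
ds_first_lazy = xr.open_dataset(
    '2015/dt_global_allsat_madt_h_20150101_20150914.nc',
    chunks={'time': 1})
ds_first_lazy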
open_mfdataset¶
An incredibly useful function in xarray is open_mfdataset
.
help(xr.open_mfdataset)
Using open_mfdataset
we can easily open all the netcdf files into one Dataset
object.
# Note: I got a "Too many open files" OSError when running this.
# It's only 365 files. That shouldn't be too many.
# However, I discovered my ulimit was extremely low.
# One workaround is to call
# $ ulimit -S -n 4000
# from the command line before launching the notebook
ds = xr.open_mfdataset('2015/*.nc')
ds
Note that the values are not displayed, since that would trigger computation.
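We can still inspect the overall size and chunk layout without loading anything:

# total size of the lazy dataset in GB, and the chunk layout of one variable
print(ds.nbytes / 1e9)
print(ds.adt.chunks)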
# the "adt" variable is the absolute dynamic topography (sea-surface height)
ssh = ds.adt
ssh
ssh[0].plot()
# time-mean sea-surface height; .load() triggers the computation
# and keeps the result in memory
ssh_2015_mean = ssh.mean(dim='time')
ssh_2015_mean.load()
ssh_2015_mean.plot()
# anomaly relative to the 2015 mean, and its variance
# averaged over longitude and time
ssh_anom = ssh - ssh_2015_mean
ssh_variance_lonmean = (ssh_anom**2).mean(dim=('lon', 'time'))
ssh_variance_lonmean.plot()
# area weighting: grid cells shrink towards the poles as cos(latitude)
weight = np.cos(np.deg2rad(ds.lat))
weight /= weight.mean()

# area-weighted global-mean sea-surface height anomaly as a function of time
(ssh_anom * weight).mean(dim=('lon', 'lat')).plot()
Cloud Storage¶
import gcsfs

# open a mapping to a zarr store in a Google Cloud Storage bucket,
# then let xarray wrap each variable in a dask array
gcsmap = gcsfs.mapping.GCSMap('pangeo-data/dataset-duacs-rep-global-merged-allsat-phy-l4-v3-alt')
ds = xr.open_zarr(gcsmap)
ds
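With a client connected, the same analysis pattern works directly against the cloud copy. A sketch, assuming this dataset exposes an adt variable like the files above:

# lazily select sea-surface height and average over time,
# reading the chunks straight from cloud storage
ssh_cloud_mean = ds.adt.mean(dim='time')
ssh_cloud_mean.load()
ssh_cloud_mean.plot()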