Deadlock with DataLoader and xarray/dask

xarray is a library for working with high-dimensional datasets. It uses dask to stream data from disk when the data doesn’t fit in main memory. When I wrap an xarray DataArray in a torch DataLoader, the program stalls; it appears to be a deadlock.

I’ve tried setting lock=False in xr.open_mfdataset, which passes it on to dask.array.from_array, but it doesn’t help.

Does anyone here have experience combining xarray or dask with PyTorch?

Here’s a minimal example of what I’m talking about:

```python
from torch.utils.data import Dataset, DataLoader
import xarray as xr

class XarrayDataset(Dataset):
    '''A simple torch Dataset wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim):
        self.ar = ar
        self.batch_dim = batch_dim

    def __len__(self):
        # Number of samples = length of the batch dimension
        return len(self.ar[self.batch_dim])

    def __getitem__(self, idx):
        # Positional selection along batch_dim; .values forces dask to read the slice
        return self.ar[{self.batch_dim: idx}].values

class XarrayDataLoader(DataLoader):
    '''A simple torch DataLoader wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim, **kwargs):
        ar = XarrayDataset(ar, batch_dim)
        super().__init__(ar, **kwargs)

# Open the dataset with xarray and select the feature to work with.
ds = xr.open_mfdataset([
    'NAM-NMM/nam.20170115/nam.t00z.awphys.tm00.nc',
    'NAM-NMM/nam.20170115/nam.t12z.awphys.tm00.nc',
])
ar = ds['TMP_SFC']

# If I load the data into main memory, the deadlock does not happen
# ar.load()

# Trying to load the data with num_workers > 0 results in deadlock.
dl = XarrayDataLoader(ar, batch_dim='reftime', num_workers=1)
i = iter(dl)
next(i)
```

I experienced similar problems.
It seems to be related to the fact that netCDF I/O is not thread-safe, so some locking is going on.
If you chunk your dataset with ds = ds.chunk(<your chunk options>), save it to disk in zarr format (ds.to_zarr()), and reopen it with xr.open_zarr(), then everything works fine :wink:
I additionally suggest setting dask.config.set(scheduler='synchronous') to speed up data loading when num_workers > 0 in your DataLoader.
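A minimal sketch of the scheduler suggestion, reusing the question’s `XarrayDataset` idea on a synthetic dask-backed array. The array contents, its shape, the `make_array` helper, and the batch size are hypothetical and only for illustration. `dask.config.set(scheduler="synchronous")` makes dask computations run in the calling thread, so each DataLoader worker process computes its own slices without starting a dask thread pool on top of PyTorch’s worker processes. It is called at module level so that workers started with either the "fork" or the "spawn" start method pick it up.

```python
import dask
import dask.array as da
import numpy as np
import xarray as xr
from torch.utils.data import DataLoader, Dataset

# Run dask computations in the calling thread. Set at module level so that
# DataLoader workers started with either "fork" or "spawn" pick it up.
dask.config.set(scheduler="synchronous")


class XarrayDataset(Dataset):
    """Same idea as the question's dataset: one sample per index along batch_dim."""

    def __init__(self, ar, batch_dim):
        self.ar = ar
        self.batch_dim = batch_dim

    def __len__(self):
        return self.ar.sizes[self.batch_dim]

    def __getitem__(self, idx):
        # .values triggers the dask computation for this single slice.
        return self.ar[{self.batch_dim: idx}].values


def make_array():
    # Hypothetical dask-backed stand-in for ds['TMP_SFC']: 8 samples of 3x5.
    values = np.arange(8 * 3 * 5, dtype="float32").reshape(8, 3, 5)
    data = da.from_array(values, chunks=(1, 3, 5))
    return xr.DataArray(data, dims=("reftime", "y", "x"))


if __name__ == "__main__":
    dl = DataLoader(XarrayDataset(make_array(), "reftime"), batch_size=4, num_workers=2)
    batches = [b for b in dl]
    print([tuple(b.shape) for b in batches])  # [(4, 3, 5), (4, 3, 5)]
```

The `if __name__ == "__main__":` guard is needed on platforms that start workers with "spawn" (Windows, macOS), where each worker re-imports the main module.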