Deadlock with DataLoader and xarray/dask

xarray is a library for working with high-dimensional datasets. It uses dask to stream data from disk when it doesn’t fit in main memory. When I try to wrap an xarray DataArray in a torch DataLoader, the program stalls. It appears to be a deadlock.

I’ve tried setting lock=False in xr.open_mfdataset, which is passed through to dask.array.from_array, but it’s not helping.

Does anyone here have experience combining xarray or dask with PyTorch?

Here’s a minimal example of what I’m talking about:

from torch.utils.data import Dataset, DataLoader
import xarray as xr

class XarrayDataset(Dataset):
    '''A simple torch Dataset wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim):
        self.ar = ar
        self.batch_dim = batch_dim

    def __len__(self):
        return self.ar.sizes[self.batch_dim]

    def __getitem__(self, idx):
        return self.ar[{self.batch_dim: idx}].values

class XarrayDataLoader(DataLoader):
    '''A simple torch DataLoader wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim, **kwargs):
        ar = XarrayDataset(ar, batch_dim)
        super().__init__(ar, **kwargs)

# Open the dataset with xarray and select the feature to work with.
ds = xr.open_mfdataset(['NAM-NMM/nam.20170115/', 'NAM-NMM/nam.20170115/'])
ar = ds['TMP_SFC']

# If I load the data into main memory, the deadlock does not happen:
# ar = ar.load()

# Trying to load the data with num_workers > 0 results in deadlock.
dl = XarrayDataLoader(ar, batch_dim='reftime', num_workers=1)
i = iter(dl)

I experienced similar problems.
I think it’s related to the fact that netCDF I/O is not thread-safe, so there is some locking going on.
If you chunk your dataset with ds = ds.chunk(<your chunk options>), save it to disk in zarr format (ds.to_zarr()), and reopen it with xr.open_zarr(), then everything works fine :wink:
I additionally suggest setting dask.config.set(scheduler='synchronous') to speed up data loading when num_workers > 0 in your DataLoader.