Deadlock with DataLoader and xarray/dask

xarray is a library for working with high-dimensional datasets. It uses dask to stream data from disk when it doesn’t fit in main memory. When I try to wrap an xarray DataArray in a torch DataLoader, the program stalls. It appears to be a deadlock.

I’ve tried setting lock=False in xr.open_mfdataset, which is passed through to dask.array.from_array, but it’s not helping.

Does anyone here have experience combining xarray or dask with PyTorch?

Here’s a minimal example of what I’m talking about:

from torch.utils.data import Dataset, DataLoader
import xarray as xr

class XarrayDataset(Dataset):
    '''A simple torch Dataset wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim):
        self.ar = ar
        self.batch_dim = batch_dim

    def __len__(self):
        return self.ar.sizes[self.batch_dim]

    def __getitem__(self, idx):
        return self.ar[{self.batch_dim: idx}].values

class XarrayDataLoader(DataLoader):
    '''A simple torch DataLoader wrapping xr.DataArray'''
    def __init__(self, ar, batch_dim, **kwargs):
        ar = XarrayDataset(ar, batch_dim)
        super().__init__(ar, **kwargs)

# Open the dataset with xarray and select the feature to work with.
ds = xr.open_mfdataset(['NAM-NMM/nam.20170115/', 'NAM-NMM/nam.20170115/'])
ar = ds['TMP_SFC']

# If I load the data into main memory, the deadlock does not happen:
# ar = ar.load()

# Trying to load the data with num_workers > 0 results in deadlock.
dl = XarrayDataLoader(ar, batch_dim='reftime', num_workers=1)
i = iter(dl)

I experienced similar problems.
I think it’s related to the fact that netCDF I/O is not thread-safe, so there is some locking going on.
If you chunk your dataset with ds = ds.chunk(<your chunk options>), save it to disk in zarr format (ds.to_zarr()), and reopen it with xr.open_zarr(), then everything works fine :wink:
I additionally suggest setting dask.config.set(scheduler='synchronous') to speed up data loading when num_workers > 0 in your DataLoader.