Problems using DataLoader for dask/xarray/netCDF data

Using multiple workers (num_workers > 0) to load the data with DataLoader results in strange behavior. Here is a sample code snippet:

import torch
import xarray as xr

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, datapath):
        # The file is opened once here, in the main process, before the workers start
        self.ds = xr.open_dataset(datapath)
        self.labels = ...  # placeholder: read the labels from some dimensions of self.ds

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return ...  # placeholder: choose some value from the dimensions of self.ds


When I load the data this way, I get NaNs or extremely large values that are not expected at all. There is another issue that might also be related.

This behavior might point towards missing synchronization.
I'm not familiar with the implementation and usage of xarray, so could you post an example using random data that reproduces this behavior?
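A sketch of the kind of random-data example asked for above, assuming xarray, PyTorch and a netCDF backend (netCDF4, h5netcdf or scipy) are installed. It mirrors the snippet from the question: the file is opened once in __init__ and read lazily in __getitem__. The file layout (variable t2m, dimensions time/lat/lon) and all function and class names are hypothetical, and whether this actually produces bad values with num_workers > 0 has not been verified; it may depend on the netCDF/HDF5 build.

```python
import numpy as np
import torch
import xarray as xr
from torch.utils.data import DataLoader, Dataset


def write_random_netcdf(path, n_time=64, n_lat=32, n_lon=32, seed=0):
    """Write a NaN-free dataset with values in [0, 1) to a netCDF file (hypothetical layout)."""
    rng = np.random.default_rng(seed)
    data = rng.random((n_time, n_lat, n_lon), dtype=np.float32)
    xr.Dataset({"t2m": (("time", "lat", "lon"), data)}).to_netcdf(path)


class NetCDFDataset(Dataset):
    """Same pattern as the question: open once in __init__, read lazily per item."""

    def __init__(self, path):
        self.ds = xr.open_dataset(path)

    def __len__(self):
        return self.ds.sizes["time"]

    def __getitem__(self, idx):
        # .values reads this time step from the file; torch.tensor copies it
        return torch.tensor(self.ds["t2m"].isel(time=idx).values)


def check_for_bad_values(path, num_workers):
    """Count loaded values that could not have come from the file."""
    loader = DataLoader(NetCDFDataset(path), batch_size=8, num_workers=num_workers)
    bad = 0
    for batch in loader:
        # The file only holds values in [0, 1), so anything else is a corrupted read
        bad += int((~torch.isfinite(batch) | (batch < 0) | (batch >= 1)).sum())
    return bad
```

Comparing check_for_bad_values(path, num_workers=0), which should return 0, with the result for, say, num_workers=4 would show whether the bad values only appear with multiple workers.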

I experienced similar problems.
Are you reading the dataset from netCDF or GRIB files? I guess it is related to the fact that netCDF I/O is not thread-safe and some locking is going on.
If you chunk your dataset with ds = ds.chunk(<your chunk options>), save it to disk in zarr format with ds.to_zarr(), and reopen it with xr.open_zarr(), everything should work fine.
I also suggest setting dask.config.set(scheduler='synchronous') to speed up data loading when num_workers > 0 in your DataLoader.
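A sketch of combining the synchronous dask scheduler with a multi-worker DataLoader, assuming dask and PyTorch are installed. The thread only states that this speeds things up; one plausible reason (an assumption, not confirmed above) is that the DataLoader already runs num_workers processes in parallel, so a dask thread pool inside each worker adds overhead. TimeStepDataset, make_loader and the time dimension are hypothetical names.

```python
import dask
import torch
from torch.utils.data import DataLoader, Dataset

# Global setting: dask graphs are computed in the calling thread, so each
# DataLoader worker reads its samples without using a dask thread pool.
dask.config.set(scheduler="synchronous")


class TimeStepDataset(Dataset):
    """Hypothetical dataset returning one time step of a dask-backed variable."""

    def __init__(self, ds, var):
        self.da = ds[var]

    def __len__(self):
        return self.da.sizes["time"]

    def __getitem__(self, idx):
        # .values computes only the chunk(s) needed for this time step
        return torch.tensor(self.da.isel(time=idx).values)


def make_loader(ds, var, batch_size=8, num_workers=4):
    """Build a DataLoader over one variable of an xarray dataset."""
    return DataLoader(
        TimeStepDataset(ds, var), batch_size=batch_size, num_workers=num_workers
    )
```

For example, make_loader(xr.open_zarr("data.zarr"), "t2m", num_workers=4) would combine this with the zarr store from the previous suggestion. dask.config.set can also be used as a context manager, but the module-level call keeps the setting in effect for the whole loading process.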