DataLoader parallelization/synchronization with zarr/xarray/dask

xarray is a common library for high-dimensional datasets (typically in the geoinformation sciences; see the example below). It uses dask under the hood to lazily access data from disk when it does not fit in memory. xarray datasets can be conveniently saved as zarr stores.
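For example, opening a zarr store (like the dummy.zarr created in the snippet further below) gives lazily loaded, dask-backed arrays rather than numpy arrays in memory:

import xarray as xr

ds = xr.open_zarr("dummy.zarr")
print(type(ds["precipitation"].data))  # dask.array.core.Array, not numpy.ndarray
print(ds["precipitation"].chunks)      # chunked; values stay on disk until computed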

When I load my xarray.Dataset from my zarr store with xarray.open_zarr() inside a torch.Dataset, and then wrap the torch.Dataset in a torch.DataLoader, the program stalls when num_workers > 0. From reading elsewhere (e.g.) I think it is a synchronization issue when accessing the data in the zarr store. I have tried playing with the arguments of the xarray.open_zarr() function without success. Clearly my understanding of the topic (parallelization/synchronization) is too limited at the moment, yet this is the (big) bottleneck in my model training experiments.

Below is a code snippet to reproduce the issue, followed by my current workaround, which is to use torch.DataLoader(batch_size=1, num_workers=0) but return multiple items from the torch.Dataset.__getitem__ method.

import numpy as np
import xarray as xr
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch

### Create some data
precipitation = np.arange(16*16*4000).reshape(16,16,4000)
temperature = np.arange(16*16*4000).reshape(16,16,4000)
lat_pixel = np.arange(0,16)
lon_pixel = np.arange(0,16)
global_id = np.arange(4000)

data_vars = {
    'fnf': (["lat_pixel", "lon_pixel", "global_id"], temperature),
    'precipitation': (["lat_pixel", "lon_pixel", "global_id"], precipitation),
}
coords={"lat_pixel": lat_pixel,
        "lon_pixel": lon_pixel,
        "global_id": global_id
        }

ds = xr.Dataset(data_vars, coords)

### Save it as zarr store
ds.to_zarr("dummy.zarr", mode="w")  # overwrite if the store already exists
### Define the Dataset
class XRData(Dataset):
    def __init__(self):
        self.data = xr.open_zarr("dummy.zarr").to_array()
  
    def __len__(self):
        return len(self.data.global_id)

    def __getitem__(self, idx):
        image_npy = self.data[..., idx].to_numpy()
        image = torch.as_tensor(image_npy, dtype = torch.float)
        return image

train_data = XRData()
train_data[0].shape  # torch.Size([2, 16, 16]) as expected

### Define and test the DataLoader. This will stall for num_workers > 0 and prefetch_factor > 0.
train_dataloader = DataLoader(train_data, batch_size= 32, num_workers = 0, prefetch_factor=None)
for X in tqdm(train_dataloader):
    np.matmul(X,X) # do something
### WORKAROUND
class XRBatchData(Dataset):
    def __init__(self, batch_size):
        self.data = xr.open_zarr("dummy.zarr").to_array().transpose("global_id",..., "lat_pixel", "lon_pixel")
        self.batch_size = batch_size
        
    def __len__(self):
        return len(self.data.global_id) // self.batch_size

    def __getitem__(self, idx):
        image_npy = self.data.isel(global_id = slice(idx*self.batch_size, (idx+1)*self.batch_size)).to_numpy()
        image = torch.as_tensor(image_npy, dtype = torch.float)
        return image

batch_size = 32
train_data = XRBatchData(batch_size)
train_data[0].shape

train_dataloader = DataLoader(train_data, batch_size= 1, num_workers = 0, prefetch_factor=None)
for X in tqdm(train_dataloader):
    np.matmul(X,X) # do something
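An alternative I have been thinking about, but have not tested, is to open the zarr store lazily inside each worker instead of once in __init__, so that the dask-backed arrays are never shared across the forked worker processes. A sketch (reusing the imports from above; _ensure_open is just an illustrative helper):

class XRWorkerData(Dataset):
    def __init__(self, path="dummy.zarr"):
        self.path = path
        self.data = None  # opened lazily, once per worker process
        # read only the dataset length up front; this handle is discarded
        self.length = xr.open_zarr(path).sizes["global_id"]

    def _ensure_open(self):
        if self.data is None:
            self.data = xr.open_zarr(self.path).to_array()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self._ensure_open()
        image_npy = self.data[..., idx].to_numpy()
        return torch.as_tensor(image_npy, dtype=torch.float)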

Did you manage to address your issue?

Well, reading this thread, Deadlock with DataLoader and xarray/dask, I found that adding this to your imports

import dask
dask.config.set(scheduler='synchronous')

solves the issue of the DataLoader stalling when num_workers > 0.
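For example, with the scheduler set as above before any data is loaded, your first snippet can use workers (just a sketch; num_workers = 4 is an arbitrary value):

train_data = XRData()
train_dataloader = DataLoader(train_data, batch_size=32, num_workers=4)
for X in tqdm(train_dataloader):
    np.matmul(X, X)  # do something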

I tested your code and confirmed that it stalls when num_workers > 0. However, with the dask config changed at the beginning of the script, it works as intended.