DataLoader parallelization/synchronization with zarr/xarray/dask

xarray is a common library for high-dimensional datasets (typically in geoinformation sciences; see the example below). It can use dask under the hood to load data lazily from disk when the data would not fit in memory. xarray datasets can be conveniently saved as zarr stores.

When I load my xarray.Dataset from my zarr store with xarray.open_zarr() inside a torch Dataset, and then wrap that Dataset in a torch DataLoader, the program stalls when num_workers > 0. From reading elsewhere, I think it is a synchronization issue in accessing the data in the zarr store. I have tried playing with the arguments of xarray.open_zarr(), without success. Clearly my understanding of the topic (parallelization/synchronization) is too limited at the moment, yet this is the (big) bottleneck in my model training experiments.

Below is a code snippet to reproduce the issue, followed by my current workaround, which is to use DataLoader(batch_size = 1, num_workers = 0) but return multiple items from each call to the Dataset.__getitem__ method.

import numpy as np
import xarray as xr
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch

### Create some data
precipitation = np.arange(16*16*4000).reshape(16,16,4000)
temperature = np.arange(16*16*4000).reshape(16,16,4000)
lat_pixel = np.arange(0,16)
lon_pixel = np.arange(0,16)
global_id = np.arange(4000)

data_vars = {
    'fnf': (["lat_pixel", "lon_pixel", "global_id"], temperature),
    'precipitation': (["lat_pixel", "lon_pixel", "global_id"], precipitation),
}
coords = {"lat_pixel": lat_pixel,
          "lon_pixel": lon_pixel,
          "global_id": global_id}

ds = xr.Dataset(data_vars, coords)

### Save it as zarr store
ds.to_zarr("dummy.zarr", mode="w")

### Define the Dataset
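class XRData(Dataset):
    def __init__(self):
        # Open the store lazily; to_array() stacks the variables along a new leading "variable" dimension
        self.data = xr.open_zarr("dummy.zarr").to_array()

    def __len__(self):
        return self.data.sizes["global_id"]

    def __getitem__(self, idx):
        # Select one item along the last dimension (global_id) and load it into memory
        image_npy = self.data[..., idx].to_numpy()
        image = torch.as_tensor(image_npy, dtype = torch.float)
        return image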

train_data = XRData()
train_data.__getitem__(0).shape # torch.Size([2, 16, 16]) as expected

### Define and test the Dataloader. This will stall for num_workers > 0 and prefetch_factor > 0.
train_dataloader = DataLoader(train_data, batch_size= 32, num_workers = 0, prefetch_factor=None)
for X in tqdm(train_dataloader):
    np.matmul(X,X) # do something
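
### Workaround: each __getitem__ call returns a whole batch of items
class XRBatchData(Dataset):
    def __init__(self, batch_size):
        # Put global_id first so that a slice along it has shape (batch, variable, lat_pixel, lon_pixel)
        self.data = xr.open_zarr("dummy.zarr").to_array().transpose("global_id", ..., "lat_pixel", "lon_pixel")
        self.batch_size = batch_size

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped
        return int(self.data.sizes["global_id"] / self.batch_size)

    def __getitem__(self, idx):
        image_npy = self.data.isel(global_id = slice(idx*self.batch_size, (idx+1)*self.batch_size)).to_numpy()
        image = torch.as_tensor(image_npy, dtype = torch.float)
        return image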

batch_size = 32
train_data = XRBatchData(batch_size)

train_dataloader = DataLoader(train_data, batch_size= 1, num_workers = 0, prefetch_factor=None)
for X in tqdm(train_dataloader):
    np.matmul(X,X) # do something
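
The index arithmetic behind the workaround can be checked in isolation, without xarray or torch. The sketch below is only an illustration: batch_slices is a hypothetical helper, not part of the code above. It mirrors how __len__ counts the batches and how __getitem__ maps idx to a slice along global_id.

```python
def batch_slices(n_items, batch_size):
    """Return one slice per full batch (hypothetical helper for illustration)."""
    # Floor division plays the role of int(n_items / batch_size) in __len__:
    # a trailing partial batch is dropped.
    n_batches = n_items // batch_size
    return [slice(i * batch_size, (i + 1) * batch_size) for i in range(n_batches)]


slices = batch_slices(4000, 32)
print(len(slices))             # 125
print(slices[0], slices[-1])   # slice(0, 32, None) slice(3968, 4000, None)

# With a length that is not a multiple of the batch size, the remainder is skipped.
ragged = batch_slices(4010, 32)
print(len(ragged), 4010 - ragged[-1].stop)  # 125 10
```

With 4000 items and a batch size of 32 nothing is lost, but for other lengths the last few items would never be returned unless __len__ rounds up instead.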