Lazy loading of a wide dataset

Hi PyTorch community,

I am training a model on a very wide dataset (~500,000 features). To read the data from disk without loading all of it into memory at once, I use dask to lazily load an xarray.core.dataarray.DataArray object. I can load a subset of the data into memory as a NumPy array like this: xarray[0:64,:].values. This loads 64 samples into memory in about 2 seconds. I then want to feed the data to the model one batch at a time with a batch size of 64, which, given the total sample size (~6000), should give an estimated epoch time of about 3 minutes. My problem is implementing this with PyTorch's Dataset class. I want a class with which a DataLoader loads the data from disk into memory one batch at a time, with a specified batch size. To that end, I wrote the following Dataset class and wrapped it in a DataLoader:

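from torch.utils.data import DataLoader, Dataset

class load_wide(Dataset):

    def __init__(self, xarray, labels):
        self.xarray = xarray  # lazily loaded (dask-backed) DataArray
        self.labels = labels

    def __getitem__(self, item):
        # reads a single sample (one row) from disk per call
        data = self.xarray[item,:].values
        labels = self.labels[item]
        return data, labels

    def __len__(self):
        return len(self.xarray)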

I then load the data like this:

train_ds = load_wide(xarray_train, labels_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)

for i, data in enumerate(train_dl):
    feats, labels = data
    preds = net(feats)
    loss = criterion(preds, labels)
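To see how often a Dataset like load_wide is actually hit, it may help to know that with batch_size=64 and the default automatic batching, the DataLoader does not ask the Dataset for a whole batch: it calls __getitem__ once per sample index and then stacks the results. The sketch below uses a hypothetical in-memory CountingDataset (not the poster's class) that just records every index it is asked for.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CountingDataset(Dataset):
    """Hypothetical stand-in for load_wide that records every requested index."""

    def __init__(self, n_samples, n_features):
        self.data = torch.zeros(n_samples, n_features)
        self.requested = []

    def __getitem__(self, item):
        self.requested.append(item)  # one call per sample, not per batch
        return self.data[item], item

    def __len__(self):
        return len(self.data)


ds = CountingDataset(n_samples=200, n_features=8)
dl = DataLoader(ds, batch_size=64, shuffle=True)

feats, idx = next(iter(dl))  # fetch a single batch
print(len(ds.requested), feats.shape)
```

With shuffle=True the 64 recorded indices are in random order, so for a lazily loaded array each batch means 64 separate single-row reads at scattered positions.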

This works, but loading a single batch takes about 90 seconds, which makes the training time unreasonable. What is wrong with my data-loading implementation that makes it so slow? Going through the DataLoader slows data loading down about 45-fold (90 s instead of 2 s per batch). Can someone explain the cause of this?
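For scale, the timings above translate into epoch times as follows. This is plain arithmetic on the approximate figures from the post (about 6000 samples, batch size 64, 2 s vs. 90 s per batch) and ignores the forward and backward passes.

```python
import math

n_samples = 6000  # approximate training-set size from the post
batch_size = 64
batches_per_epoch = math.ceil(n_samples / batch_size)  # 94

seconds_per_batch_manual = 2       # contiguous slice read with .values
seconds_per_batch_dataloader = 90  # observed through the DataLoader

epoch_manual_min = batches_per_epoch * seconds_per_batch_manual / 60
epoch_dataloader_h = batches_per_epoch * seconds_per_batch_dataloader / 3600
slowdown = seconds_per_batch_dataloader / seconds_per_batch_manual

print(batches_per_epoch, round(epoch_manual_min, 1), round(epoch_dataloader_h, 2), slowdown)
```

So roughly 3.1 minutes per epoch with contiguous 2 s reads versus roughly 2.35 hours per epoch at 90 s per batch.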
Also, how can I evaluate the model during training on a validation set that is also loaded one batch at a time?
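A common pattern for batch-wise validation, sketched here as a hypothetical evaluate helper: wrap the validation set in its own DataLoader, switch the model to eval mode, disable gradient tracking with torch.no_grad(), and accumulate the loss weighted by batch size. It assumes the criterion averages over the batch (the default reduction='mean' of losses such as nn.MSELoss or nn.CrossEntropyLoss).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def evaluate(net: nn.Module, loader: DataLoader, criterion) -> float:
    """Hypothetical helper: mean loss over a loader that yields one batch at a time."""
    was_training = net.training
    net.eval()  # disables dropout, uses running batch-norm statistics
    total, count = 0.0, 0
    with torch.no_grad():  # no autograd graph: less memory, faster
        for feats, labels in loader:
            preds = net(feats)
            # assumes criterion returns the batch mean, so re-weight by batch size
            total += criterion(preds, labels).item() * labels.shape[0]
            count += labels.shape[0]
    net.train(was_training)  # restore the previous mode
    return total / count
```

Calling it once per epoch (e.g. val_loss = evaluate(net, val_dl, criterion) after the training loop) keeps only one validation batch in memory at a time, just like the training loader.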

I'm not familiar with xarray's internal implementation, but what seems to be different is the shuffling: with shuffle=True the DataLoader requests samples at random, non-contiguous indices, whereas your manual test read the contiguous slice xarray[0:64,:]. Could you repeat your simple timing test with 64 random indices instead of contiguous ones and compare the loading speed?
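A minimal sketch of that comparison, assuming array is the lazily loaded DataArray (a plain NumPy array also works for trying it out). The helper names time_read and compare_contiguous_vs_random are hypothetical; np.asarray forces a dask-backed DataArray to actually compute.

```python
import time

import numpy as np


def time_read(array, rows):
    """Seconds taken to materialise array[rows, :] in memory."""
    start = time.perf_counter()
    np.asarray(array[rows, :])  # triggers the actual read for a lazy array
    return time.perf_counter() - start


def compare_contiguous_vs_random(array, batch_size=64, seed=0):
    """Time one contiguous block of rows against the same number of random rows."""
    rng = np.random.default_rng(seed)
    n = array.shape[0]
    contiguous = slice(0, batch_size)
    random_rows = np.sort(rng.choice(n, size=batch_size, replace=False))
    return time_read(array, contiguous), time_read(array, random_rows)
```

For example, t_contig, t_random = compare_contiguous_vs_random(xarray_train) returns the two read times in seconds.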

That was exactly the cause of the issue. xarray does not seem to be a good fit for PyTorch's DataLoader class. I will look for alternative ways of loading the data. Thank you!
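One possible workaround, sketched under assumptions rather than tested on the poster's data: keep the lazily loaded array but make every batch a contiguous slice, and shuffle only the order of the blocks. Passing batch_size=None disables the DataLoader's automatic batching, so whatever the sampler yields (here a slice object) is handed directly to __getitem__, giving one contiguous read per batch. ShuffledBlockSampler and WideBatchDataset are hypothetical names. The trade-off is coarser shuffling: samples in the same block always land in the same batch.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ShuffledBlockSampler(Sampler):
    """Hypothetical sampler: yields contiguous row slices in a new random order each epoch."""

    def __init__(self, n_samples, batch_size, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.starts = list(range(0, n_samples, batch_size))
        self.rng = random.Random(seed)

    def __iter__(self):
        starts = self.starts[:]
        self.rng.shuffle(starts)  # shuffle block order, not individual samples
        for start in starts:
            yield slice(start, min(start + self.batch_size, self.n_samples))

    def __len__(self):
        return len(self.starts)


class WideBatchDataset(Dataset):
    """Hypothetical batch-level dataset: __getitem__ receives a whole slice."""

    def __init__(self, xarray, labels):
        self.xarray = xarray  # lazily loaded array (or any array-like)
        self.labels = labels

    def __getitem__(self, block):
        data = np.asarray(self.xarray[block, :])  # one contiguous read per batch
        labels = np.asarray(self.labels[block])
        return torch.from_numpy(data), torch.from_numpy(labels)

    def __len__(self):
        return len(self.xarray)
```

Usage would look like train_dl = DataLoader(WideBatchDataset(xarray_train, labels_train), sampler=ShuffledBlockSampler(len(labels_train), batch_size=64), batch_size=None); the training loop itself stays unchanged.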