Correct way to subclass DataLoader __iter__ method to apply transforms to minibatch?

Hi,

I’m trying to apply a transform to mini-batches as they’re yielded from a DataLoader. I understand that I could do this in the Dataset, but in my case I’m applying a numpy function whose runtime is nearly the same whether it is applied to a single sample or to a whole mini-batch. Therefore, with batch_size=4, applying the transform to each collated batch is about 4x faster than applying it per sample.

The most obvious option is:

```python
for X, y in dataloader:
    X = transform(X)
```

But I’d also like to use this DataLoader with PyTorch Lightning, so the transform has to live inside the loader itself, which means subclassing it. I’ve done the following:

```python
import torch

class SampleLoader(torch.utils.data.DataLoader):
    def __iter__(self):
        # Wrap the parent iterator; `transform` is my batch-wise numpy function
        for X, y in super().__iter__():
            yield transform(X), y
```

My question: is this correct, and is it safe with num_workers > 0? It appears to be, based on limited testing. Thanks in advance!
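One way to probe the num_workers > 0 question is to record *where* each hook runs. The sketch below is illustrative only: `where_am_i`, `TaggingLoader`, `tagging_collate` and `run` are hypothetical names, and `TaggingLoader` follows the same pattern as `SampleLoader` but tags each batch instead of transforming it. It relies on `torch.utils.data.get_worker_info()`, which returns `None` in the main process and a worker-info object inside a DataLoader worker.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, get_worker_info
from torch.utils.data.dataloader import default_collate


def where_am_i() -> str:
    # get_worker_info() is None in the main process, a WorkerInfo inside a worker.
    info = get_worker_info()
    return "main" if info is None else f"worker {info.id}"


class TaggingLoader(DataLoader):
    # Hypothetical variant of SampleLoader: tags each batch instead of transforming it.
    def __iter__(self):
        for X, y in super().__iter__():
            yield X, y, where_am_i()  # this generator body runs in the main process


def tagging_collate(samples):
    X, y = default_collate(samples)
    return X, y, where_am_i()  # runs wherever collation happens


def run(num_workers: int):
    dataset = TensorDataset(torch.randn(8, 3), torch.arange(8))
    via_iter = [
        tag for _, _, tag in TaggingLoader(dataset, batch_size=4, num_workers=num_workers)
    ]
    via_collate = [
        tag
        for _, _, tag in DataLoader(
            dataset, batch_size=4, num_workers=num_workers, collate_fn=tagging_collate
        )
    ]
    return via_iter, via_collate


if __name__ == "__main__":
    print(run(num_workers=2))
```

With num_workers > 0, the `__iter__` tags stay `"main"` while the `collate_fn` tags name worker processes. So the subclass approach is safe, but its transform runs serially in the main process after the workers hand over each batch, whereas work done in `collate_fn` is spread across the workers.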

You can also write a custom collation function that applies the transform as part of collation and pass it to DataLoader through the collate_fn argument.
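A minimal sketch of that idea, assuming each sample is an `(X, y)` pair: the custom collate function first builds the batch with PyTorch's `default_collate` and then applies the transform once to the stacked inputs. `batch_transform` (here just `np.tanh`) and `transform_collate` are hypothetical stand-ins for the numpy function in the question.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate


def batch_transform(X: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the question's numpy transform:
    # one vectorised numpy call over the whole (batch_size, ...) array.
    return torch.from_numpy(np.tanh(X.numpy()))


def transform_collate(samples):
    # Build the batch exactly as DataLoader would by default...
    X, y = default_collate(samples)
    # ...then apply the transform once to the stacked inputs.
    return batch_transform(X), y


dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
loader = DataLoader(dataset, batch_size=4, collate_fn=transform_collate)

for X, y in loader:
    print(X.shape, y.shape)
```

Because this is still a plain DataLoader, it can be returned from a PyTorch Lightning dataloader hook without subclassing. When num_workers > 0, collate_fn runs inside the worker processes; with the "spawn" start method it must be picklable, so keep it a module-level function rather than a lambda.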

The automatic batching section of the torch.utils.data documentation describes how DataLoader applies collate_fn.
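In the single-process, unshuffled case, automatic batching roughly amounts to drawing lists of indices from a `BatchSampler`, fetching those samples, and passing the list to `collate_fn`. The sketch below (hypothetical helper `manual_batches`; it ignores workers, shuffling and pinned memory) reproduces that loop and compares it with a real DataLoader.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset
from torch.utils.data.dataloader import default_collate


def manual_batches(dataset, batch_size, collate_fn=default_collate):
    # Single-process sketch of automatic batching (no shuffling, no workers).
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size, drop_last=False)
    for indices in batch_sampler:
        # Fetch individual samples, then hand the whole list to collate_fn.
        yield collate_fn([dataset[i] for i in indices])


dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
manual = list(manual_batches(dataset, batch_size=4))
builtin = list(DataLoader(dataset, batch_size=4))

# The hand-rolled loop and the real DataLoader yield identical batches.
for (mX, my), (bX, by) in zip(manual, builtin):
    assert torch.equal(mX, bX) and torch.equal(my, by)
```

This is why a transform placed in collate_fn sees the whole mini-batch exactly once.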
