I’m trying to apply a transform to minibatches as they’re yielded from the
DataLoader class. I understand that I can do this in the
Dataset, but in my case I’m applying a numpy function whose cost is nearly the same for a single sample as for a whole minibatch. So with
batch_size=4, it’s roughly 4x faster for me to apply this transform after the
DataLoader has assembled the batch.
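To illustrate the premise (with a hypothetical np.tanh standing in for the real transform): a vectorized numpy function produces the same result whether it’s called once per batch or once per sample, so the single batch-wise call saves roughly a factor of batch_size in call overhead.

```python
import numpy as np

def transform(X):
    # hypothetical stand-in for the real vectorized numpy transform
    return np.tanh(X)

X = np.random.rand(4, 3)  # one minibatch of 4 samples

# one batch-wise call vs. four sample-wise calls
batchwise = transform(X)
samplewise = np.stack([transform(x) for x in X])

assert np.allclose(batchwise, samplewise)
```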
The most obvious option is

for X, y in dataloader:
    X = transform(X)
But I’d also like to use this DataLoader with
PyTorch-Lightning, so I have to subclass it. I’ve done the following:
class SampleLoader(torch.utils.data.DataLoader):
    def __iter__(self):
        for X, y in super().__iter__():
            yield transform(X), y
My question is: is this safe with
num_workers > 0, and is it correct? It appears to be in limited testing. Thanks in advance!!