Hi,

I’m trying to apply a transform to minibatches as they’re yielded from the `DataLoader` class. I understand that I could do this in the `Dataset`, but in my case I’m applying a numpy function whose runtime is nearly the same for a single sample as for a whole mini-batch. So with `batch_size=4`, it’s roughly 4x faster for me to apply this transform after the `DataLoader`.
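For context, here’s the shape of the trade-off I mean — `np.fft.rfft` below is just a stand-in, not my actual function:

```python
import time
import numpy as np

def transform(x):
    # Stand-in for my transform: a numpy op whose per-call overhead
    # dominates, so one batched call beats N per-sample calls.
    return np.fft.rfft(x, axis=-1)

batch = np.random.randn(4, 1024).astype(np.float32)

t0 = time.perf_counter()
for _ in range(200):
    for sample in batch:      # per-sample: 4 calls per batch
        transform(sample)
per_sample = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(200):
    transform(batch)          # batched: 1 call per batch
batched = time.perf_counter() - t0
```

On my machine the batched version comes out well ahead, which is why I want the transform to see whole batches.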
The most obvious option is:

```python
for X, y in dataloader:
    X = transform(X)
```
But I’d also like to use this `DataLoader` with PyTorch Lightning, so I have to subclass it. I’ve done the following:
```python
class SampleLoader(torch.utils.data.DataLoader):
    def __iter__(self):
        for batch in super().__iter__():
            yield transform(batch[0]), batch[1]
```
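In case it helps, here’s the self-contained toy version I’ve been testing with — the `TensorDataset` and the doubling transform are just placeholders for my real data and numpy function:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def transform(X):
    # Placeholder batch-wise transform (my real one is a numpy function).
    return X * 2

class SampleLoader(DataLoader):
    def __iter__(self):
        for batch in super().__iter__():
            yield transform(batch[0]), batch[1]

dataset = TensorDataset(torch.arange(8.0).unsqueeze(1), torch.zeros(8))
# num_workers=0 here so the sketch stands alone; my question is about >0.
loader = SampleLoader(dataset, batch_size=4, num_workers=0)

batches = [(X, y) for X, y in loader]
```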
My question is: is this correct, and is it safe with `num_workers>0`? It appears to be, based on limited testing. Thanks in advance!