I work with large 3D medical image volumes on which data augmentations are applied. These augmentations take some time, so using the DataLoader's multiprocessing (num_workers > 0) is a must. Consider the following example code:
```python
import time

import numpy as np
from torch.utils.data import DataLoader, Dataset


class ExpensiveTransform:
    # Stands in for a slow 3D augmentation.
    def __call__(self, sample):
        print('Sleeping...')
        time.sleep(5)
        return sample


class ToPatches:
    # Splits one volume into 20 "patches" (here just copies of the sample id).
    def __call__(self, sample):
        patches = [sample['id']] * 20
        return patches


class CustomDataset(Dataset):
    def __init__(self, transform):
        self.transform = transform

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        # Dummy 3D volume of shape (640, 580, 72).
        sample = np.random.rand(640, 580, 72)
        sample = {'id': idx, 'image': sample}
        if self.transform:
            # Apply the transforms in order; the last one returns a list of patches.
            for transform in self.transform:
                sample = transform(sample)
        return sample


def main():
    dataset = CustomDataset(transform=[ExpensiveTransform(), ToPatches()])
    loader = DataLoader(dataset,
                        batch_size=1,
                        num_workers=2,
                        persistent_workers=True)
    for epoch in range(10):
        # Each iteration yields the 20 patches of a single volume.
        for patches in loader:
            print(patches)


if __name__ == "__main__":
    main()
```
This gives the following output:
```
Sleeping...
Sleeping...
[tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0])]
[tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1]), tensor([1])]
Sleeping...
Sleeping...
[tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2]), tensor([2])]
[tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3]), tensor([3])]
Sleeping...
Sleeping...
[tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4]), tensor([4])]
[tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5]), tensor([5])]
Sleeping...
Sleeping...
```
However, as the final result I want batches of, say, 4 patches drawn in random order, for example:

```
[tensor([2]), tensor([7]), tensor([0]), tensor([1])]
[tensor([5]), tensor([4]), tensor([5]), tensor([0])]
```
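The requested behavior (flatten each volume's list of patches, shuffle the individual patches, then regroup them into batches of 4) can be sketched framework-free with three generators. This is a minimal sketch: the names `unbatch`, `shuffle_buffer` and `rebatch` are hypothetical, and the buffered shuffle only approximates a full shuffle when `buffer_size` is smaller than the total number of patches, similar in spirit to the buffer used by tf.data's `shuffle()`.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def unbatch(batches: Iterable[List[T]]) -> Iterator[T]:
    # Flatten a stream of lists (one list of patches per volume) into single patches.
    for batch in batches:
        yield from batch


def shuffle_buffer(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    # Keep up to buffer_size items; once the buffer is full, emit a random one
    # for every new item that arrives.
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def rebatch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    # Group single items into lists of batch_size; the last list may be shorter.
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


if __name__ == "__main__":
    # Hypothetical stand-in for the per-volume patch lists produced by ToPatches.
    volumes = ([volume_id] * 20 for volume_id in range(10))
    rng = random.Random(0)
    for batch in rebatch(shuffle_buffer(unbatch(volumes), buffer_size=50, rng=rng), batch_size=4):
        print(batch)
```

Larger `buffer_size` values mix patches from more volumes at the cost of keeping more patches in memory.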
In TensorFlow, tf.data provides the unbatch() and shuffle() methods to achieve this. How can the same be achieved in PyTorch?
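One possible PyTorch approach, sketched here under stated assumptions rather than as an established recipe, is to wrap the map-style dataset in an `IterableDataset` that yields single patches through a shuffle buffer, and let the `DataLoader` do the re-batching via `batch_size=4`. `ShuffledPatchStream` and `ToyVolumes` are hypothetical names; `get_worker_info()` is used so that each worker loads a disjoint subset of volumes instead of every worker loading all of them.

```python
import random
from typing import Iterator, List

from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info


class ShuffledPatchStream(IterableDataset):
    """Hypothetical wrapper: turns a map-style dataset whose items are lists of
    patches into a stream of single patches, shuffled through a buffer."""

    def __init__(self, volume_dataset: Dataset, buffer_size: int = 100):
        self.volume_dataset = volume_dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator:
        worker = get_worker_info()
        worker_id = worker.id if worker is not None else 0
        num_workers = worker.num_workers if worker is not None else 1
        # Unseeded RNG: a different order in every worker and every epoch
        # (not reproducible; seed it if you need determinism).
        rng = random.Random()

        # Give each worker a disjoint subset of volumes so none is loaded twice.
        indices = list(range(worker_id, len(self.volume_dataset), num_workers))
        rng.shuffle(indices)

        buffer: List = []
        for idx in indices:
            # The expensive per-volume transforms run here, inside the worker.
            buffer.extend(self.volume_dataset[idx])  # "unbatch" the patch list
            while len(buffer) >= self.buffer_size:
                # Emit a random buffered patch (swap it to the end, then pop).
                j = rng.randrange(len(buffer))
                buffer[j], buffer[-1] = buffer[-1], buffer[j]
                yield buffer.pop()
        # Drain the remaining patches in random order.
        rng.shuffle(buffer)
        yield from buffer


class ToyVolumes(Dataset):
    # Hypothetical stand-in for CustomDataset: item idx is a list of 20 patch ids.
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return [idx] * 20


if __name__ == "__main__":
    loader = DataLoader(ShuffledPatchStream(ToyVolumes(), buffer_size=50),
                        batch_size=4,
                        num_workers=2,
                        persistent_workers=True)
    for epoch in range(2):
        for batch in loader:
            # Default collation stacks 4 integer patch ids into a 1-D tensor.
            print(batch)
```

Because ToPatches returns a list, CustomDataset could presumably be wrapped the same way, e.g. `ShuffledPatchStream(CustomDataset(transform=[ExpensiveTransform(), ToPatches()]))`. Note that with `num_workers > 0` each worker fills and shuffles its own buffer and every batch is built from a single worker's patches, so patches only mix across the volumes that worker has loaded.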