Hello,
I have written a small class to be able to quickly experiment with new datasets. I intercept the __iter__
call of a data loader and apply some transformations to the data before returning them. I am now wondering whether the following code is safe and works correctly with multiple GPUs. This suspicion comes from the fact that I am battling a heisenbug at the moment, so I thought it could be related to CUDA synchronization or something similar. Note, though, that the heisenbug occurs even when using a single GPU.
Anyhow, here is the code for the TransformedDataLoader class:
class TransformedDataLoader:
    """Wrapper around a ``torch.utils.data.DataLoader`` that applies transforms to each batch.

    Every batch produced by the wrapped loader is (optionally) moved to CUDA,
    passed through the given transforms in order, and then yielded. If the
    final result is a one-element tuple/list, it is unwrapped to the bare element.

    Parameters
    ----------
    data_loader : torch.utils.data.DataLoader
        Loader to wrap. It is deep-copied so that the sampler replacement
        below cannot mutate the caller's loader.
        NOTE(review): ``copy.deepcopy`` also copies the underlying dataset,
        which may be expensive for large in-memory datasets — confirm this
        is intended.
    transforms : callable, list of callables, or None
        Transform(s) applied, in order, to each batch.
    use_cuda : bool
        If True (and ``helpers.USE_CUDA`` is set), tensors are moved to the
        default CUDA device before the transforms run.
    new_sampler : False, None, 'sequential', or torch.utils.data.sampler.Sampler
        If given, replaces the wrapped loader's sampler. The string
        ``'sequential'`` selects a ``SequentialSampler`` over the dataset.

    Raises
    ------
    ValueError
        If ``new_sampler`` is neither ``'sequential'`` nor a ``Sampler`` instance.
    """

    def __init__(self, data_loader, transforms, use_cuda=True, new_sampler=False):
        self.data_loader = copy.deepcopy(data_loader)
        # Short-circuits: helpers.USE_CUDA is only consulted when use_cuda is truthy.
        self.use_cuda = use_cuda and helpers.USE_CUDA
        # Normalize transforms to a list so __iter__ can always iterate it.
        if transforms is None:
            transforms = []
        if not isinstance(transforms, list):
            transforms = [transforms]
        self.transforms = transforms
        if new_sampler is not False and new_sampler is not None:
            if isinstance(new_sampler, str) and new_sampler.lower() == 'sequential':
                new_sampler = torch.utils.data.sampler.SequentialSampler(self.data_loader.dataset)
            if not isinstance(new_sampler, torch.utils.data.sampler.Sampler):
                raise ValueError('The sampler passed must be an instance of torch.utils.data.sampler.Sampler')
            # NOTE(review): recent PyTorch versions forbid assigning
            # sampler/batch_sampler on an initialized DataLoader (their
            # __setattr__ raises ValueError). If this path silently stops
            # working after a torch upgrade it could look like a heisenbug —
            # confirm the installed torch version allows this, or rebuild the
            # DataLoader from scratch instead.
            self.data_loader.sampler = new_sampler
            # Use the deep copy consistently (the original mixed the argument
            # and the copy; the values are identical, but this is clearer).
            self.data_loader.batch_sampler = torch.utils.data.sampler.BatchSampler(
                new_sampler,
                self.data_loader.batch_size,
                self.data_loader.drop_last)

    def __len__(self):
        """Number of batches, delegated to the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield transformed batches from the wrapped loader."""
        for data in self.data_loader:
            if self.use_cuda:
                if isinstance(data, (tuple, list)):
                    # Move only tensors: batches often contain non-tensor
                    # parts (e.g. lists of strings); the original code called
                    # .cuda() on every element and crashed on those.
                    data = [obj.cuda() if isinstance(obj, torch.Tensor) else obj
                            for obj in data]
                else:
                    # Assumed to be a single tensor here; a non-tensor batch
                    # will raise AttributeError, which is informative.
                    data = data.cuda()
            for transform in self.transforms:
                data = transform(data)
            # Unwrap single-element containers so callers get the bare batch.
            if isinstance(data, (tuple, list)) and len(data) == 1:
                data = data[0]
            yield data