Is modifying a data loader safe?


I have written a small class to quickly experiment with new datasets. It intercepts the __iter__ call of a data loader and applies some transformations to the data before returning them. I am now wondering whether the following code is safe and works correctly with multiple GPUs. My suspicion comes from a heisenbug I am currently battling, which I thought could be related to CUDA synchronization or something similar. Note, though, that the heisenbug occurs even when using a single GPU.
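The core idea of intercepting `__iter__` can be sketched in plain Python. This is a minimal illustration, not the class below: `MappedLoader` is a hypothetical name, and a plain list of batches stands in for a real `DataLoader`. Because `__iter__` returns a new generator on every call, the wrapper can be iterated once per epoch, just like the loader it wraps.

```python
class MappedLoader:
    """Hypothetical minimal wrapper that applies `fn` to every batch of `loader`."""

    def __init__(self, loader, fn):
        self.loader = loader  # any re-iterable object, e.g. a DataLoader
        self.fn = fn

    def __len__(self):
        # Same number of batches as the wrapped loader
        return len(self.loader)

    def __iter__(self):
        # Each call creates a new generator, so every epoch restarts the wrapped loader
        for batch in self.loader:
            yield self.fn(batch)


# A plain list of (inputs, targets) batches stands in for a real DataLoader
batches = [([1, 2], [0, 1]), ([3, 4], [1, 0])]
wrapped = MappedLoader(batches, lambda b: (b[0], [t * 2 for t in b[1]]))

for epoch in range(2):
    for inputs, targets in wrapped:
        print(epoch, inputs, targets)
```

Writing `__iter__` itself as a generator is equivalent to defining an inner generator function and returning a call to it.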

Anyhow, here is the code for the TransformedDataLoader class:

import copy

import helpers  # project-specific module that provides the USE_CUDA flag


class TransformedDataLoader:

    """
    Wrapper around a data loader. It applies the transformations passed as input
    to each batch before returning it.
    If use_cuda is True, it outputs CUDA tensors.
    """

    def __init__(self, data_loader, transforms, use_cuda=True, new_sampler=False):

        self.data_loader = copy.deepcopy(data_loader)
        self.use_cuda = use_cuda and helpers.USE_CUDA

        if transforms is None:
            transforms = []

        if not isinstance(transforms, list):
            transforms = [transforms]

        self.transforms = transforms

        if new_sampler is not False and new_sampler is not None:
            if isinstance(new_sampler, str) and new_sampler.lower() == 'sequential':
                new_sampler =
            if not isinstance(new_sampler,
                raise ValueError('The sampler passed must be an instance of')

            self.data_loader.sampler = new_sampler
            self.data_loader.batch_sampler =,

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):

        def modified_iter():
            for data in self.data_loader:
                if self.use_cuda:
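                    if isinstance(data, (tuple, list)):
                        # Move every element of an (inputs, targets, ...) batch to the GPU
                        data = [obj.cuda() for obj in data]
                    else:
                        # A single tensor batch is moved as a whole
                        data = data.cuda()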

                for transform in self.transforms:
                    data = transform(data)

                if isinstance(data, (tuple, list)) and len(data) == 1:
                    data = data[0]

                yield data

        return modified_iter()
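To check the per-batch logic in isolation, without a GPU, the body of `modified_iter` and the transform normalization from `__init__` can be pulled out into plain functions. This is a sketch under assumptions: `normalize_transforms`, `process_batch` and its `move` parameter are hypothetical names, and `fake_cuda` stands in for `.cuda()` so the logic runs on any machine. It uses an explicit `else` so that `.cuda()` is never called on a list.

```python
def normalize_transforms(transforms):
    """Mirror of the constructor logic: None -> [], single callable -> [callable]."""
    if transforms is None:
        return []
    if not isinstance(transforms, list):
        return [transforms]
    return transforms


def process_batch(data, transforms, move=None):
    """Apply the same steps as modified_iter to one batch (hypothetical helper)."""
    if move is not None:
        if isinstance(data, (tuple, list)):
            # Move each element of an (inputs, targets, ...) batch separately
            data = [move(obj) for obj in data]
        else:
            # A single tensor-like batch is moved as a whole
            data = move(data)
    for transform in normalize_transforms(transforms):
        data = transform(data)
    # Unwrap one-element batches, as the original class does
    if isinstance(data, (tuple, list)) and len(data) == 1:
        data = data[0]
    return data


def fake_cuda(x):
    # Stand-in for obj.cuda(): tags the value instead of copying it to a GPU
    return f"cuda:{x}"


print(process_batch((1, 2), None, move=fake_cuda))
print(process_batch([1, 2], [lambda d: [x * 2 for x in d], lambda d: d[:1]]))
```

One behaviour worth knowing: when batches are moved, a tuple batch comes back as a list, because the list comprehension builds a new list.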