Channels Last and DataLoader

After reading the tutorial "(beta) Channels Last Memory Format in PyTorch", I decided to try the channels last format, but instead of any gain I got a roughly 6x slowdown. I suspect this is because the memory reorganisation happens on the GPU: the only place the input tensor is directly exposed is inside the training loop. Restructuring the data after it has already been moved to the GPU is obviously undesirable.
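To make the "memory reorganisation" concrete, here is a minimal sketch with toy sizes (not taken from the tutorial). It shows that channels_last keeps the logical NCHW shape and values and only changes the strides, so the elements end up stored in NHWC order. Converting a tensor that is not already in this layout makes a copy, which is why it matters where the conversion runs.

```python
import torch

# A small NCHW tensor: batch 2, 3 channels, 4x5 spatial
x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
cl = x.contiguous(memory_format=torch.channels_last)

# Logical shape and values are unchanged...
print(cl.shape)     # torch.Size([2, 3, 4, 5])
# ...but the channel dimension now has stride 1 (NHWC order in memory)
print(cl.stride())  # (60, 1, 15, 3)

# Viewed as NHWC, the channels_last tensor is an ordinary contiguous tensor
nhwc = x.permute(0, 2, 3, 1).contiguous()
print(cl.permute(0, 2, 3, 1).is_contiguous())     # True
print(torch.equal(cl.permute(0, 2, 3, 1), nhwc))  # True
```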

However, is there any way to change the channel ordering when loading data with the DataLoader class or with ImageFolder? Neither class has a memory format of its own, since both are iterables, and passing lambda x: x.to(memory_format=torch.channels_last) or lambda x: x.contiguous(memory_format=torch.channels_last) to transforms.Lambda results in this error:
AttributeError: Can't pickle local object 'main.<locals>.<lambda>'

Would a dedicated transform be needed for this?

Could you try to pass a custom transformation method instead of a lambda to the transforms?
Alternatively you could also transform the data inside the Dataset in case you are using a custom implementation (or want to use one).
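The pickling error comes from how the DataLoader starts its workers: with num_workers > 0 and the "spawn" start method (the default on Windows, and on macOS since Python 3.8), the dataset and its transforms are pickled and sent to each worker process. The 'main.<locals>' part of the message indicates the lambda was defined inside a function called main. A minimal sketch, using hypothetical names, of why a module-level callable class can be pickled while a locally defined lambda cannot:

```python
import pickle

import torch


class ChannelsLastTransform:
    # Hypothetical name; defined at module level, so pickle stores it by reference
    def __call__(self, x):
        return x.contiguous(memory_format=torch.channels_last)


def make_local_lambda():
    # A lambda created inside a function is a "local object" that pickle cannot look up by name
    return lambda x: x.contiguous(memory_format=torch.channels_last)


def can_pickle(obj):
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


print(can_pickle(ChannelsLastTransform()))  # True
print(can_pickle(make_local_lambda()))      # False
```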

You mean defining a function like:

import torch

def channels_last(x):
    return x.contiguous(memory_format=torch.channels_last)

and then passing that to transforms.Lambda? That gives the same “can’t pickle” error (AttributeError: Can't pickle local object 'main.<locals>.channels_last') if the function is defined locally. If it is defined globally, it appears to get stuck somewhere inside the training loop: no error is thrown, but the epoch never completes (and system utilisation drops to nothing).

Writing a custom dataset would entail rewriting the whole functionality of ImageFolder, which I was hoping to avoid.

No, I was thinking about writing a transformation class, such as:

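import torch


class ToChannelsLast:
    """Convert a [C, H, W] (or [N, C, H, W]) tensor to channels_last."""

    def __call__(self, x):
        if x.ndim == 3:
            # channels_last needs a 4D tensor, so add a batch dimension
            x = x.unsqueeze(0)
        elif x.ndim != 4:
            raise RuntimeError(f"expected a 3D or 4D tensor, got {x.ndim}D")
        return x.to(memory_format=torch.channels_last)

    def __repr__(self):
        return self.__class__.__name__ + '()'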


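import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    ToChannelsLast()
])

x = torch.rand(3, 256, 256)  # values in [0, 1), as ToPILImage expects for float input
out = transform(x)
print(out.shape)  # torch.Size([1, 3, 224, 224])
print(out.is_contiguous(memory_format=torch.channels_last))  # True
print(out.stride())  # (150528, 1, 672, 3)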

Ah, I wasn’t sure how transforms were written. It gives the same results as the lambda approach, though: it won’t pickle if defined locally, and it hangs indefinitely when defined globally.

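I was thinking the easiest way would be to subclass ImageFolder and change the format before returning the item, e.g.:

from typing import Any, Tuple

import torch
from torchvision import datasets


class MyImageFolder(datasets.ImageFolder):
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # super() is already bound to the instance, so self must not be passed again
        sample, target = super().__getitem__(index)
        # NOTE: channels_last requires a 4D tensor, so this raises for a 3D [C, H, W] sample
        sample = sample.to(memory_format=torch.channels_last)
        return sample, target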

But that hangs in exactly the same way. The model trains fine (albeit with a big performance penalty) if the conversion is applied to the input inside the training loop, or if it is left implicit (which, oddly, has a smaller performance hit, though still about 4x slower than not changing the memory format).

I worked out (part of) the problem whilst working on another dataset: the hang appears to be a multiprocessing issue, since removing the num_workers argument from the DataLoader (so that data is loaded in the main process) gets past that step. That, however, introduces a different error: channels last only works on 4D tensors, i.e. batches. If the unsqueeze happens, each sample becomes [1, C, H, W] and the default collation stacks them into a 5D [B, 1, C, H, W] tensor that is fed to the network; if it doesn't, no reformatting can happen.
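The 5D shape comes from the default collation, which stacks the samples along a new leading dimension. A small sketch with random tensors standing in for ImageFolder samples (a plain list works as a map-style dataset here), also showing that a 3D tensor is rejected by channels_last:

```python
import torch
from torch.utils.data import DataLoader

# Stand-ins for samples after a ToChannelsLast-style transform that unsqueezes to [1, C, H, W]
samples = [torch.rand(3, 8, 8).unsqueeze(0) for _ in range(4)]

loader = DataLoader(samples, batch_size=4)  # default collation stacks along a new dim 0
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 1, 3, 8, 8])

# Without the unsqueeze, the per-sample tensor is 3D and channels_last refuses it
rejected = False
try:
    torch.rand(3, 8, 8).contiguous(memory_format=torch.channels_last)
except RuntimeError as err:
    rejected = True
    print("RuntimeError:", err)
```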

At the DataLoader level, then, should it be a custom collate function that changes the memory format of each batch?
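A collate function is one way to do that. The sketch below assumes ImageFolder-style (image, class_index) samples with 3D [C, H, W] image tensors, and channels_last_collate is a hypothetical name. It stacks the images itself, converts the whole 4D batch, and is defined at module level so it stays picklable for worker processes.

```python
import torch
from torch.utils.data import DataLoader


def channels_last_collate(batch):
    """Stack (image, target) pairs into a batch and convert it to channels_last.

    Assumes each image is a 3D [C, H, W] tensor and each target is an int.
    """
    images = torch.stack([image for image, _ in batch])  # [B, C, H, W], standard NCHW layout
    targets = torch.tensor([target for _, target in batch])
    return images.contiguous(memory_format=torch.channels_last), targets


# Stand-in for ImageFolder: (image, class_index) pairs
dataset = [(torch.rand(3, 32, 32), i % 2) for i in range(8)]

loader = DataLoader(dataset, batch_size=4, collate_fn=channels_last_collate)
images, targets = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, 32, 32])
print(images.is_contiguous(memory_format=torch.channels_last))  # True
```

With num_workers > 0 the conversion then runs on the CPU inside the worker processes. Since Tensor.to keeps the memory format by default (memory_format=torch.preserve_format), images.to(device) should keep the batch in channels_last, so the GPU no longer has to reorder it. Whether this removes the slowdown would still need to be measured.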