Apply transformation to each channel of an image

My images have 6 channels. I would like to apply some transformations to them. The problem is that if I apply, for instance, a RandomHorizontalFlip, then the resulting image has 3 channels only.

I thus tried to use a Lambda function to apply the same transformation to each channel. I came up with the following class:

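import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

class HorizontalFlipChannels(object):
    # e.g. used as transforms.Lambda(HorizontalFlipChannels().transform)

    def __init__(self, p=0.5):
        self.p = p  # probability of flipping the whole image

    def transform(self, img):
        # img: C x H x W tensor (C = 6 here).
        # Draw the random decision once per image: calling
        # RandomHorizontalFlip() inside the loop would draw separately
        # for each channel, so the channels could end up misaligned.
        flip = torch.rand(1).item() < self.p

        transformed_channels = []

        for idx, channel in enumerate(img):
            print(idx)
            channel = transforms.ToPILImage()(channel)  # H x W -> 8-bit 1-channel PIL image
            if flip:
                channel = TF.hflip(channel)
            channel = transforms.ToTensor()(channel)    # -> 1 x H x W tensor
            transformed_channels.append(channel)

        # Concatenate the channels back along dim 0: C x H x W
        img = torch.cat(transformed_channels)

        return img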
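As an aside, if img is already a C x H x W tensor, the per-channel PIL round trip can be skipped entirely by flipping the whole tensor in one call, which keeps all 6 channels, shares one random draw across them and avoids quantizing to 8 bits. A minimal sketch; the class name `FlipAllChannels` and its `p` parameter are hypothetical, and recent torchvision versions may also accept tensors directly in `RandomHorizontalFlip` (check your version).

```python
import torch


class FlipAllChannels:
    """Hypothetical transform: flip a C x H x W tensor horizontally with probability p.

    All channels share one random draw, so a 6-channel image stays aligned.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        if torch.rand(1).item() < self.p:
            # dims=[-1] reverses the width axis, i.e. a horizontal flip
            img = torch.flip(img, dims=[-1])
        return img


img = torch.arange(6 * 2 * 3, dtype=torch.float32).reshape(6, 2, 3)
out = FlipAllChannels(p=1.0)(img)
print(out.shape)  # torch.Size([6, 2, 3])
```

Because it is an ordinary callable, it can be placed in a transforms.Compose pipeline without a Lambda wrapper.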

The problem is that when I print the channel index (while iterating over the images in a batch), I get something like this:

0
0
1
2
1
0
2
1
0
3
3
4
2
1
5
3
4
2
4
5
3
...

Is this normal? I was expecting to see something like 0 1 2 3 4 5 0 1 2 3 4 5 ...

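Do you by any chance have 4 workers in your DataLoader? With num_workers greater than 0, several worker processes load and transform samples in parallel, so their print statements get interleaved in the console.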
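One way to see this is to tag each print with the id of the process that produced it. A minimal sketch, assuming a hypothetical `SixChannelDataset` of random tensors standing in for the real images; `torch.utils.data.get_worker_info()` returns `None` in the main process and an object with an `id` field inside a worker.

```python
import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class SixChannelDataset(Dataset):
    """Hypothetical stand-in dataset: random 6 x 8 x 8 tensors."""

    def __len__(self):
        return 16

    def __getitem__(self, index):
        img = torch.rand(6, 8, 8)
        info = get_worker_info()  # None in the main process
        worker = info.id if info is not None else "main"
        for idx in range(img.shape[0]):
            # Tag every channel index with the process that printed it
            print(f"worker={worker} sample={index} channel={idx}")
        return img


if __name__ == "__main__":
    loader = DataLoader(SixChannelDataset(), batch_size=4, num_workers=4)
    for batch in loader:
        print(batch.shape)  # torch.Size([4, 6, 8, 8])
```

Filtering the output by a single worker= tag should show the channels of each sample printed as 0 to 5 in order; only lines from different workers are mixed together.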


Yes, I have 4 workers. But I was wondering whether the final order of the channels for all the images will be the same.

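Yes, it will, so no worries. Each image is still transformed entirely inside one worker, where the loop visits channels 0 to 5 in order and torch.cat stacks them in that order. It's just the multiprocessing that mixes up the printouts.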
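This can be checked directly. A minimal sketch, assuming a hypothetical `TaggedChannels` dataset in which channel c of every sample is filled with the value c, so each channel's position can be read back after loading. The flip here is applied to the whole tensor at once, which is an illustrative choice rather than the original per-channel class.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TaggedChannels(Dataset):
    """Hypothetical dataset: channel c of every sample is filled with the value c."""

    def __init__(self, n=16, channels=6, size=4):
        self.n, self.channels, self.size = n, channels, size

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        img = torch.arange(self.channels, dtype=torch.float32).view(-1, 1, 1)
        img = img.expand(-1, self.size, self.size).clone()  # C x H x W
        if torch.rand(1).item() < 0.5:
            img = torch.flip(img, dims=[-1])  # same flip for all channels
        return img


def channels_in_order(num_workers):
    """Return True if every loaded sample still has channel c at index c."""
    loader = DataLoader(TaggedChannels(), batch_size=4, num_workers=num_workers)
    for batch in loader:
        for c in range(batch.shape[1]):
            if not torch.all(batch[:, c] == c):
                return False
    return True


if __name__ == "__main__":
    print(channels_in_order(num_workers=4))  # True
```

Since each channel is constant, a horizontal flip leaves its values unchanged, so any mismatch could only come from channels being reordered.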
