My images have 6 channels. I would like to apply some transformations to them. The problem is that if I apply, for instance, a RandomHorizontalFlip, then the resulting image has only 3 channels.
I thus tried to use a Lambda function to apply the same transformation to each channel. I came up with the following class:
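import torch
from torchvision import transforms


class HorizontalFlipChannels(object):
    def transform(self, img):
        transformed_channels = []
        # Iterate over the channels of a (C, H, W) tensor
        for idx, channel in enumerate(img):
            print(idx)  # debug: index of the channel being processed
            channel = transforms.ToPILImage()(channel)
            # Note: each channel gets its own random flip decision here
            channel = transforms.RandomHorizontalFlip()(channel)
            channel = transforms.ToTensor()(channel)
            transformed_channels.append(channel)
        # Each channel comes back as (1, H, W); concatenate into (C, H, W)
        img = torch.cat(transformed_channels)
        return img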
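As an aside on the class above: because RandomHorizontalFlip is called once per channel, each channel gets its own random draw, so the channels of one image may end up flipped differently. A minimal sketch of flipping all channels together directly on the tensor, assuming the image is a (C, H, W) tensor; the function name random_hflip_all_channels and its parameter p are hypothetical names used for illustration:

```python
import torch


def random_hflip_all_channels(img: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Flip every channel of a (C, H, W) tensor together with probability p."""
    # One random draw decides the flip for the whole image, so channels stay aligned.
    if torch.rand(1).item() < p:
        # dims=[-1] reverses the width axis, i.e. a horizontal flip.
        return torch.flip(img, dims=[-1])
    return img


# Toy 6-channel image; p=1.0 forces the flip so the result is deterministic.
img = torch.arange(6 * 2 * 3, dtype=torch.float32).reshape(6, 2, 3)
flipped = random_hflip_all_channels(img, p=1.0)
print(flipped.shape)  # torch.Size([6, 2, 3])
```

This works for any number of channels and skips the conversion to a PIL image and back.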
The problem is that if I print the index of the channel inside transform (when I iterate over the images in one batch), I get something like the following:
0
0
1
2
1
0
2
1
0
3
3
4
2
1
5
3
4
2
4
5
3
...
Is this normal? I was expecting to see something like 0 1 2 3 4 5 0 1 2 3 4 5 ...
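One possible cause, which is only an assumption because the loading code is not shown here, is that the images are read through a DataLoader with num_workers > 0. In that case several worker processes would run the transform at the same time and their prints would interleave. A minimal sketch to check this by tagging each print with the process it comes from; worker_label is a hypothetical helper name:

```python
from torch.utils.data import get_worker_info


def worker_label() -> str:
    """Return a short label for the process that is running this code."""
    # get_worker_info() returns None in the main process and a
    # worker info object (with an integer .id) inside a DataLoader worker.
    info = get_worker_info()
    if info is None:
        return "main"
    return f"worker {info.id}"


# Hypothetical use inside transform(): print(worker_label(), idx)
print(worker_label())
```

If the labels show different workers, the indices within each worker should still run 0 to 5 in order, and only the combined output is interleaved.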