Torchvision transforms for arbitrary channels

Hi, I have “images” with a large number of channels, and there are some transformations I would like to apply to them, specifically rotations, cropping, etc. These are all implemented in torchvision, but they require PIL images as input. This is a problem, as PIL images support only a specific, small number of channels. Is there a way to support an arbitrary number of channels?

If not, is there an efficient way to take a tensor, convert each channel to a PIL image, perform my transformations, convert back to tensors, and stack along the channel dimension? Kind of like an along-a-dimension map-reduce function (that is hopefully efficient, and would in and of itself be endlessly useful for other applications I am working on).

Ideally I would like to not rely on PIL at all.
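For what it's worth, flips, 90-degree rotations, and crops are purely spatial operations, so they can be written directly as array ops that never look at the channel axis. A minimal sketch (my own helper names, not a torchvision API; assumes a channels-first `(C, H, W)` layout):

```python
import numpy as np

# Sketch: common geometric augmentations as pure array ops. They act
# only on the spatial axes, so any number of channels works.

def hflip(img):
    """Flip left-right (last axis)."""
    return np.flip(img, axis=-1)

def vflip(img):
    """Flip top-bottom (second-to-last axis)."""
    return np.flip(img, axis=-2)

def rot90(img, k=1):
    """Rotate by k * 90 degrees in the spatial plane."""
    return np.rot90(img, k=k, axes=(-2, -1))

def crop(img, top, left, h, w):
    """Crop an h x w window starting at (top, left)."""
    return img[..., top:top + h, left:left + w]

img = np.arange(2 * 4 * 4).reshape(2, 4, 4)  # 2-channel "image"
out = crop(hflip(img), 0, 0, 2, 2)
print(out.shape)  # (2, 2, 2)
```

Arbitrary-angle rotation and resizing need interpolation, so they are not this simple, but the same channel-agnostic idea carries over.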


Hi,

I think what you want can be done by writing a loader function that loads your multi-channel images and converts them into numpy.ndarrays, since torchvision.transforms.ToTensor can handle both PIL.Image and numpy.ndarray inputs.

How does this solve the problem?
Transforms like torchvision.transforms.RandomHorizontalFlip need to take in a PIL image, so using ToTensor doesn’t make sense.

@zeneofa were you able to figure out a solution?

I managed to solve it in two ways:

  1. Apply torchvision’s transforms channel by channel:
import random
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

tfms = transforms.Compose([transforms.Resize(224),
                           transforms.RandomHorizontalFlip(),
                           transforms.RandomVerticalFlip(),
                           transforms.ToTensor()])

x = []
#set a seed so the same transforms are applied to each channel
seed = np.random.randint(2147483647)
for ch in img:
    random.seed(seed)
    x.append(tfms(Image.fromarray(ch)))

#this is the multichannel transformed image (a torch tensor)
img_tfm = torch.cat(x)
  2. Use fastai’s transforms. I think this is the best approach because it is easy and ready to go:
#use fastai's transforms
tfms = get_transforms(flip_vert=True,max_warp=None,max_rotate=45,max_lighting=None,max_zoom=1.1)

#img is a torch tensor (with however many channels you need)
fastai_img = fastai.vision.Image(img)

#apply the transforms to the image
for tfm in tfms:
    fastai_img = fastai_img.apply_tfms(tfm,size=224)

#this is the multichannel transformed image (a torch tensor)
img_tfm = fastai_img.data
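The per-channel loop in the first approach can be factored into a small helper, close to the “along-a-dimension map” the original question asked for. A sketch (the name `channel_map` is mine, not a library function):

```python
import numpy as np

def channel_map(fn, img, axis=0):
    """Apply fn to each slice of img along `axis` and restack the results.

    fn takes a 2-D array and returns an array of the same shape; the
    outputs are stacked back along the same axis. This is a plain map,
    not vectorized, which is fine for augmentation where fn dominates
    the cost anyway.
    """
    slices = np.moveaxis(img, axis, 0)
    return np.stack([fn(ch) for ch in slices], axis=axis)

img = np.arange(3 * 2 * 2).reshape(3, 2, 2)
doubled = channel_map(lambda ch: ch * 2, img)
print(doubled.shape)  # (3, 2, 2)
```

With the PIL route above, `fn` would wrap `Image.fromarray`, the transform pipeline, and the conversion back to an array, reseeding inside `fn` so every channel draws the same random parameters.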

Actually, you should use torch.manual_seed and not Python’s random.seed, because the transform functions use torch.rand(1) to generate random numbers, so you may get different results for different channels.


We need to set both random seeds:

random.seed(seed) 
torch.manual_seed(seed)

For details, see: The random seed for torchvision transforms is not fixed in PyTorch 1.6 · Issue #42331 · pytorch/pytorch · GitHub
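The reseed-per-channel pattern itself can be shown with plain Python `random` (a sketch with a stand-in transform; with torchvision you would call both `random.seed(seed)` and `torch.manual_seed(seed)` at the same spot, since different transforms draw from different generators):

```python
import random

def fake_random_transform(ch):
    """Stand-in for a random augmentation: draws one random parameter
    (here an angle) and tags the channel with it."""
    angle = random.uniform(-45, 45)
    return (ch, angle)

channels = ["ch0", "ch1", "ch2"]
seed = 1234

# Reseed before every channel so each one draws the *same* parameters.
out = []
for ch in channels:
    random.seed(seed)  # with torchvision, also torch.manual_seed(seed)
    out.append(fake_random_transform(ch))

angles = [a for _, a in out]
print(len(set(angles)))  # 1 -- every channel got the same angle
```

Without the reseed inside the loop, each channel would draw a fresh angle and the channels would end up geometrically misaligned.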