Transform CNN input dimension

Hi everyone :slight_smile:

I have a question regarding possible transformations of a CNN input. I know that the input should have the shape batch_size x number_of_channels x height x width. Let's assume I have a tensor of shape [2, 3, 5, 5], i.e. two images with three channels each, where each image is 5 by 5 pixels.

Is there a neat way to transform this input into 3 x [2, 1, 5, 5]? Basically, I would like to split each channel of every image in the batch and then have 3 such inputs. All I can think of are a bunch of nested for loops but maybe there is a better / more common way to do this?

Any help is very much appreciated!

All the best

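```python
import torch

x = torch.randn(2, 3, 5, 5)          # [batch, channels, height, width]
# Swap batch and channel dims -> [3, 2, 5, 5], then re-insert a size-1 channel dim -> [3, 2, 1, 5, 5]
x_ = x.transpose(0, 1).unsqueeze(2)
# Iterating over dim 0 yields 3 tensors, each of shape [2, 1, 5, 5]
x_ = list(x_)
```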
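A possible alternative that skips the transpose is a minimal sketch using `torch.split` with a split size of 1 along the channel dimension (`dim=1`). It returns a tuple of three tensors that already have shape [2, 1, 5, 5] (`torch.chunk(x, 3, dim=1)` gives the same result). The `Conv2d` layer and its sizes below are hypothetical, only there to show that each piece can be fed to a layer expecting one input channel. As with the transpose version, the pieces are views into `x`, so in-place changes to them also change `x`.

```python
import torch

# Example batch: 2 images, 3 channels each, 5x5 pixels (N x C x H x W)
x = torch.randn(2, 3, 5, 5)

# Split along the channel dimension into chunks of size 1.
# Each chunk keeps the channel dim, so its shape is [2, 1, 5, 5].
per_channel = torch.split(x, 1, dim=1)

for i, t in enumerate(per_channel):
    print(i, tuple(t.shape))

# Hypothetical single-channel layer (sizes chosen only for illustration).
conv = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
outputs = [conv(t) for t in per_channel]  # each output: [2, 4, 5, 5]
```

If you do not need the size-1 channel dimension, `x.unbind(1)` returns three tensors of shape [2, 5, 5] instead.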

@klory Thank you so much, that is exactly what I was looking for! :slight_smile: