How to convert a tensor and extract multiple images from it?

  1. The image tensor has shape (561, 561, 3).
  2. Convert/transform the tensor to shape (3, 1, 561, 561).
  3. Then get three images from the tensor.

NOTE: If the tensor is (561, 561, 4), then the converted shape should be (6, 1, 561, 561) and we should get 6 images.

Thanks in advance.

If you want to split the channels of the image, you could use:

import torch

x = torch.randn(561, 561, 3)
xs = x.split(1, dim=2)  # tuple of 3 tensors, each of shape (561, 561, 1)
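
To make the result of the split concrete, here is a minimal sketch that assumes the channels-last layout from the question. Each chunk keeps a channel dimension of size 1, and squeeze(2) is one way to drop it if you want plain 2-D images. The variable names are only illustrative.

```python
import torch

x = torch.randn(561, 561, 3)  # channels-last image, as in the question

# split(1, dim=2) returns a tuple with one chunk per channel
xs = x.split(1, dim=2)
num_chunks = len(xs)
chunk_shape = tuple(xs[0].shape)  # channel dim is kept with size 1

# squeeze(2) drops the trailing channel dim, leaving a 2-D image per channel
images = [chunk.squeeze(2) for chunk in xs]
image_shape = tuple(images[0].shape)

print(num_chunks, chunk_shape, image_shape)
```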

To permute the dimensions, use:

x_perm = x.permute(2, 0, 1)
x_perm = x_perm.unsqueeze(1)
print(x_perm.shape)
# > torch.Size([3, 1, 561, 561])
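
If the goal is to pull the individual images back out of the permuted tensor, something like the following might work. It is only a sketch: channels_to_images is a hypothetical helper name, not a PyTorch function, and it assumes the input is laid out as (H, W, C).

```python
import torch

def channels_to_images(x):
    """Hypothetical helper: (H, W, C) tensor -> tuple of C tensors of shape (1, H, W)."""
    x_perm = x.permute(2, 0, 1).unsqueeze(1)  # (C, 1, H, W)
    return x_perm.unbind(0)  # removes dim 0, giving one tensor per channel

x = torch.randn(561, 561, 3)
images = channels_to_images(x)
print(len(images), images[0].shape)
```

With a 4-channel input this approach returns 4 images, one per channel.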

I don’t understand how 4 channels could yield 6 “images”, so could you explain this use case a bit more?