- Image tensor shape = (561, 561, 3)
- Convert/transform the tensor to shape (3, 1, 561, 561)
- Then get three images from the tensor.

NOTE: If the tensor is (561, 561, 4), the converted shape should be (6, 1, 561, 561) and we should get 6 images.

Thanks in advance.

If you want to split the first image tensor into its individual channels, you could use:

```
import torch

x = torch.randn(561, 561, 3)
# split along the channel dim (dim=2): a tuple of 3 tensors, each (561, 561, 1)
xs = x.split(1, dim=2)
```
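Note that `split` keeps the split dimension, so each chunk still has a trailing size-1 channel dim. A minimal sketch of turning the chunks into plain 2D images, assuming random data stands in for the real image, either by squeezing each chunk or by using `unbind`, which drops the dimension directly:

```python
import torch

x = torch.randn(561, 561, 3)  # (H, W, C) image tensor; random data for illustration

# split() keeps dim 2, so each chunk is (561, 561, 1)
xs = x.split(1, dim=2)

# squeeze the trailing singleton dim to get (561, 561) images
images = [t.squeeze(2) for t in xs]

# unbind() removes dim 2 and returns a tuple of (561, 561) tensors directly
images_unbind = x.unbind(dim=2)

# both approaches yield the same per-channel images
for a, b in zip(images, images_unbind):
    assert torch.equal(a, b)

print(len(images), images[0].shape)
# > 3 torch.Size([561, 561])
```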

To permute the dimensions, use:

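```
# (H, W, C) -> (C, H, W): move the channel dim to the front
x_perm = x.permute(2, 0, 1)
# (C, H, W) -> (C, 1, H, W): insert a size-1 dim so each entry is a 1-channel image
x_perm = x_perm.unsqueeze(1)
print(x_perm.shape)
# > torch.Size([3, 1, 561, 561])
```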
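To get the individual images back out of the permuted tensor, you can index its first dimension. The sketch below wraps the permute/unsqueeze steps in a hypothetical helper `to_channel_images` (not a PyTorch function) and assumes random data in place of a real image. It also shows that the same code handles any channel count, so a 4-channel input gives 4 images of shape (1, 561, 561):

```python
import torch

def to_channel_images(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (H, W, C) -> (C, 1, H, W), one single-channel image per channel."""
    # permute returns a view; call .contiguous() if a later op (e.g. .view()) needs it
    return x.permute(2, 0, 1).unsqueeze(1)

x = torch.randn(561, 561, 3)
batch = to_channel_images(x)

# indexing the first dim gives one (1, H, W) image per channel
images = [batch[c] for c in range(batch.shape[0])]
print(batch.shape, images[0].shape)
# > torch.Size([3, 1, 561, 561]) torch.Size([1, 561, 561])

# works for any channel count, e.g. 4 channels -> (4, 1, 561, 561)
x4 = torch.randn(561, 561, 4)
print(to_channel_images(x4).shape)
# > torch.Size([4, 1, 561, 561])
```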

I don’t understand how 4 channels could yield 6 “images”, so could you explain this use case a bit more?