Remove one channel from an image tensor (convert a 4-channel image to 3 channels)

I have a tensor of shape (4, 512, 512).
It is an image with 4 channels.
How can I remove the 4th channel and get a (3, 512, 512) tensor?

Hi,
Basic slicing along the first (channel) dimension will do the job for you. See this small example below:

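import torch

x = torch.rand(4, 512, 512)  # 4-channel image, channels first (random values as placeholder data)
print(x.shape)               # torch.Size([4, 512, 512])
y = x[:3, :, :]              # keep channels 0, 1, 2 and drop the 4th
print(y.shape)               # torch.Size([3, 512, 512])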
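A slightly extended sketch, assuming the same channels-first layout. The tensor names below and the choice of channel 1 as the one to drop are hypothetical, picked only for illustration. It shows three related points: slicing returns a view that shares memory with the original (use .clone() if you need an independent copy), a batched (N, C, H, W) tensor is sliced on the second dimension instead, and a channel other than the last can be dropped with index_select.

```python
import torch

# Hypothetical example data: one 4-channel image, channels first (C, H, W).
img = torch.rand(4, 512, 512)

# Slicing gives a view: it shares storage with `img`, so writes to one show up in the other.
rgb_view = img[:3]
# .clone() makes an independent copy with its own storage.
rgb_copy = img[:3].clone()

# Hypothetical batched input of shape (N, C, H, W): slice the channel dimension (dim 1).
batch = torch.rand(8, 4, 512, 512)
batch_rgb = batch[:, :3]

# Dropping a channel that is not the last one (channel 1, chosen arbitrarily here):
# select the channels to keep along dim 0. index_select returns a new tensor, not a view.
keep = torch.tensor([0, 2, 3])
img_without_c1 = img.index_select(0, keep)

print(rgb_view.shape)        # torch.Size([3, 512, 512])
print(batch_rgb.shape)       # torch.Size([8, 3, 512, 512])
print(img_without_c1.shape)  # torch.Size([3, 512, 512])
```

Whether you need the view or the copy depends on whether you plan to modify the result in place; for read-only use the view avoids copying 3 × 512 × 512 values.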