How to take the value of the center frame?

I have a tensor of shape (256, 7, 16, 2) where the 2nd dimension is the number of frames (e.g. 7). How can I take the value of the center frame so that the result has shape (256, 1, 16, 2)?

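I assume by “center frame” you mean the frame at index 3 (the middle of 7 frames, i.e. 7 // 2). In that case you can either index it directly and unsqueeze to restore the removed dimension, or slice the tensor with a range of length 1, which keeps the dimension: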

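```python
import torch

x = torch.randn(256, 7, 16, 2)

# Option 1: index the center frame (drops dim 1), then restore it with unsqueeze
y = x[:, 3].unsqueeze(1)
print(y.shape)
# torch.Size([256, 1, 16, 2])

# Option 2: slice a range of length 1, which keeps dim 1
z = x[:, 3:4]
print(z.shape)
# torch.Size([256, 1, 16, 2])
```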
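If the number of frames can vary, computing the index as `x.size(1) // 2` avoids hard-coding 3. Below is a minimal sketch, assuming the frames always sit along dimension 1; `center_frame` is a hypothetical helper name, not a PyTorch function, while `Tensor.narrow` and `Tensor.index_select` are standard PyTorch methods.

```python
import torch


def center_frame(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: return the middle frame along `dim`, keeping that dim with size 1.

    For an odd frame count T this is index T // 2 (index 3 for T = 7).
    For an even T there is no single middle frame; T // 2 picks the later of the two.
    """
    center = x.size(dim) // 2
    # narrow(dim, start, length) returns a length-1 view along `dim`, like x[:, center:center + 1]
    return x.narrow(dim, center, 1)


x = torch.randn(256, 7, 16, 2)
y = center_frame(x)
print(y.shape)  # torch.Size([256, 1, 16, 2])

# Same values as the slicing approach
print(torch.equal(y, x[:, 3:4]))  # True

# index_select gives the same values but returns a copy rather than a view
w = x.index_select(1, torch.tensor([x.size(1) // 2]))
print(torch.equal(w, y))  # True
```

`narrow`, like the `x[:, 3:4]` slice, returns a view that shares memory with `x`, so in-place changes to `y` also show up in `x`; `index_select` copies the data instead.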