Hi everyone,
I am trying to pick out values with torch.gather.
I have a tensor of shape [B,K,64,64] and an index tensor of shape [B,K,2] that holds 2D (row, column) indices into the last two dimensions.
The result should have shape [B,K,1].
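You don't strictly need torch.gather here; direct (advanced) indexing should work:
```python
import torch

B, K = 2, 3
x = torch.randn(B, K, 64, 64)
idx = torch.randint(0, 64, (B, K, 2))  # idx[..., 0] indexes dim 2, idx[..., 1] indexes dim 3

# The four index tensors broadcast against each other to shape [B, K]
out = x[torch.arange(B).unsqueeze(1), torch.arange(K), idx[:, :, 0], idx[:, :, 1]]
print(out.shape)
# torch.Size([2, 3])
print(out.unsqueeze(-1).shape)  # add a trailing dim if [B, K, 1] is required
# torch.Size([2, 3, 1])
```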
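Since the question asks about torch.gather specifically, here is a minimal sketch of a gather-based version. It assumes, as in the indexing example, that idx[..., 0] is the row (dim 2) and idx[..., 1] the column (dim 3); the names H, W, flat and lin are illustrative only. torch.gather needs an index tensor with the same number of dimensions as its input, so the two spatial dims are flattened first and each (row, col) pair becomes a single linear index row * W + col.

```python
import torch

B, K, H, W = 2, 3, 64, 64
x = torch.randn(B, K, H, W)
# Hypothetical index tensor: [..., 0] = row in [0, H), [..., 1] = column in [0, W)
rows = torch.randint(0, H, (B, K))
cols = torch.randint(0, W, (B, K))
idx = torch.stack([rows, cols], dim=-1)            # [B, K, 2], dtype int64

# Flatten the spatial dims so each (row, col) pair maps to one linear index.
flat = x.reshape(B, K, H * W)                      # [B, K, H*W]
lin = idx[..., 0] * W + idx[..., 1]                # [B, K]
out = torch.gather(flat, 2, lin.unsqueeze(-1))     # [B, K, 1]
print(out.shape)
# torch.Size([2, 3, 1])

# Cross-check against direct advanced indexing.
ref = x[torch.arange(B).unsqueeze(1), torch.arange(K), idx[..., 0], idx[..., 1]]
print(torch.equal(out.squeeze(-1), ref))
# True
```

The gather route returns [B, K, 1] directly, which matches the requested shape; the indexing route is shorter but yields [B, K] and needs an unsqueeze.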