Access 3d tensor by 2d index

I am trying to index a 3D tensor with a 2D index matrix. The result should be a 2D matrix too. Here is a loop version of what I am trying to achieve:


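import torch

dim1 = 2
dim2 = 3
dim3 = 4
source = torch.randn(dim1, dim2, dim3)  # values drawn from a standard normal distribution

check = source > 0.0                      # boolean mask of positive entries
index = torch.argmax(check.int(), dim=0)  # first position along dim 0 where check is True, shape (dim2, dim3)

ret = torch.empty(dim2, dim3)
for i in range(dim2):
    for j in range(dim3):
        ret[i, j] = source[index[i, j], i, j]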
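A small illustration of what `torch.argmax` does on such a mask, using a hand-made mask rather than the random `source` above. Since `argmax` returns the first maximal index on ties, it picks the first True along dim 0. Note that a column with no True at all also yields 0, which looks the same as "first True at row 0"; `mask.any(dim=0)` is one way to tell those cases apart. The cast to `int` is a precaution, assuming some PyTorch versions do not implement `argmax` for bool tensors.

```python
import torch

# Hand-made mask of shape (3, 4): 3 candidates along dim 0 for each column.
mask = torch.tensor([
    [False, True,  False, False],
    [True,  True,  False, False],
    [True,  False, True,  False],
])

# On ties argmax returns the first maximal index, i.e. the first True per column.
first = torch.argmax(mask.int(), dim=0)
print(first)        # tensor([1, 0, 2, 0])

# The last column has no True at all but still gets index 0.
has_positive = mask.any(dim=0)
print(has_positive)  # tensor([ True,  True,  True, False])
```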

I think you want to use torch.gather:

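import torch
dim1 = 2
dim2 = 3
dim3 = 4
source = torch.randn(dim1, dim2, dim3)     # initialized values; FloatTensor(...) alone leaves memory uninitialized
index = source.argmax(dim=0).unsqueeze(0)  # shape (1, dim2, dim3), dtype int64
ret = source.gather(0, index).squeeze(0)   # ret[i, j] = source[index[0, i, j], i, j], shape (dim2, dim3)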
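A sketch that checks the `gather` approach against the loop from the question, using the question's index (first positive entry along dim 0) instead of `source.argmax`. For `dim=0`, `gather` computes `out[0, i, j] = source[index[0, i, j], i, j]`, which is exactly the loop body once the extra leading dimension is added with `unsqueeze(0)`. The last line shows plain advanced indexing with broadcast `torch.arange` tensors as an alternative to `gather`; it is a swapped-in technique, not part of the answer above, and the seed and variable names are illustrative.

```python
import torch

torch.manual_seed(0)
dim1, dim2, dim3 = 2, 3, 4
source = torch.randn(dim1, dim2, dim3)

# Index as in the question: first position along dim 0 where source > 0.
index = torch.argmax((source > 0.0).int(), dim=0)  # shape (dim2, dim3), int64

# Loop version from the question.
ret_loop = torch.empty(dim2, dim3)
for i in range(dim2):
    for j in range(dim3):
        ret_loop[i, j] = source[index[i, j], i, j]

# gather along dim 0 needs an index with the same number of dims as source.
ret_gather = source.gather(0, index.unsqueeze(0)).squeeze(0)

# Alternative: advanced indexing; the arange tensors broadcast to (dim2, dim3).
rows = torch.arange(dim2).unsqueeze(1)  # shape (dim2, 1)
cols = torch.arange(dim3)               # shape (dim3,)
ret_adv = source[index, rows, cols]

print(torch.equal(ret_loop, ret_gather), torch.equal(ret_loop, ret_adv))
```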

It is working, thank you so much!