Indexing a 3D tensor using two 1D tensors

Hi everyone, I’m facing a problem indexing a tensor and I hope someone here can help.

I have a 3D tensor T of shape [batch_size, h, w] and two other tensors, H and W (h and w in the code below), both of shape [batch_size].

What I need to do is gather (or build a mask that selects), for every position b in the minibatch, the value at position [b, H[b], W[b]].

So, for example, for the element T[0] I would like to get the value T[0, H[0], W[0]].
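Spelled out as a plain Python loop, the operation described above looks like the sketch below. It is slow for large batches but makes the intended result explicit. The function name `pick_per_sample_loop` and the toy sizes are illustrative assumptions, not from the original post.

```python
import torch

def pick_per_sample_loop(T, H, W):
    # Reference version (hypothetical helper): out[b] = T[b, H[b], W[b]].
    # Assumes T has shape [batch_size, h, w] and H, W have shape [batch_size].
    out = torch.empty(T.shape[0], dtype=T.dtype, device=T.device)
    for b in range(T.shape[0]):
        out[b] = T[b, H[b], W[b]]
    return out

# Toy example: a batch of 2, each sample a 3x4 grid.
T = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
H = torch.tensor([0, 2])
W = torch.tensor([1, 3])
print(pick_per_sample_loop(T, H, W).tolist())  # [1, 23]
```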

Right now I’m trying something like this:

        h = h.long()
        w = w.long()

        mask = torch.zeros_like(T).bool()
        mask[:,h, w] = 1
        print(torch.sum(mask))  # should be equal to batch_size
        value = torch.masked_select(T, mask)

but it is not working.
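A likely cause, reproduced on toy data below: in `mask[:, h, w] = 1` the `:` keeps the whole batch dimension, while `h` and `w` act as paired index tensors on the last two dimensions. Every batch element therefore receives every (h[i], w[i]) pair, so the mask can hold up to batch_size × batch_size True entries instead of batch_size. Replacing `:` with `torch.arange(batch_size)` pairs each batch index with its own (h[b], w[b]). The sizes and index values here are assumptions chosen for illustration.

```python
import torch

batch_size, n_rows, n_cols = 3, 4, 5
T = torch.randn(batch_size, n_rows, n_cols)
h = torch.tensor([0, 1, 3])
w = torch.tensor([2, 4, 0])

# Original attempt: ':' spans the batch, so each sample gets all 3 (h[i], w[i]) pairs.
bad_mask = torch.zeros_like(T, dtype=torch.bool)
bad_mask[:, h, w] = True
print(bad_mask.sum().item())  # 9 = batch_size * batch_size (the pairs are distinct here)

# Fix: index the batch dimension with arange so sample b only gets (h[b], w[b]).
batch_idx = torch.arange(batch_size)
mask = torch.zeros_like(T, dtype=torch.bool)
mask[batch_idx, h, w] = True
print(mask.sum().item())  # 3 = batch_size

# masked_select walks T in row-major order; with exactly one True per sample,
# the selected values come out in batch order.
value = torch.masked_select(T, mask)
```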

I found the solution to my problem; posting it in case anyone needs it!
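The solution itself is not included in the post. One common approach, which may or may not be what was used here, skips the mask entirely and uses advanced indexing: pairing `torch.arange(batch_size)` with `h` and `w` picks exactly one element per sample and returns a tensor of shape [batch_size]. An equivalent `gather`-based variant flattens the last two dimensions into a single index `h * n_cols + w`. The helper names `pick_per_sample` and `pick_per_sample_gather` and the toy values are hypothetical.

```python
import torch

def pick_per_sample(T, h, w):
    # Returns out[b] = T[b, h[b], w[b]] for every b; T: [B, n_rows, n_cols], h and w: [B].
    h = h.long()
    w = w.long()
    batch_idx = torch.arange(T.shape[0], device=T.device)
    # The three index tensors broadcast together, so element b is T[b, h[b], w[b]].
    return T[batch_idx, h, w]

def pick_per_sample_gather(T, h, w):
    # Same result via gather: flatten each [n_rows, n_cols] grid, pick one flat index per row.
    B, _, n_cols = T.shape
    flat_idx = (h.long() * n_cols + w.long()).unsqueeze(1)  # shape [B, 1]
    return T.reshape(B, -1).gather(1, flat_idx).squeeze(1)  # shape [B]

T = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
h = torch.tensor([0, 2])
w = torch.tensor([1, 3])
print(pick_per_sample(T, h, w).tolist())         # [1, 23]
print(pick_per_sample_gather(T, h, w).tolist())  # [1, 23]
```

Both versions avoid materializing a [batch_size, h, w] boolean mask and remain differentiable with respect to T; the plain indexing form is usually the more readable of the two.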


Thanks anyway, happy coding!