Hi everyone, I'm having a problem indexing a tensor and I hope someone here can help.
I have a 3D tensor T of shape [batch_size, H, W] and two index tensors, h and w, both of shape [batch_size].
What I need is to gather (or build a mask that selects), for every sample b in the minibatch, the value at position [b, h[b], w[b]].
For example, for sample T[0] I would like to get the value T[0, h[0], w[0]].
Right now I'm trying something like this:
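To pin down the intended result, here is a minimal reference sketch using a plain Python loop over the batch. The sizes and the values in T, h and w are hypothetical, chosen only for illustration; the loop is slow but makes the target semantics, value[b] = T[b, h[b], w[b]], explicit.

```python
import torch

# Hypothetical example sizes and data, for illustration only.
batch_size, H, W = 4, 3, 5
T = torch.arange(batch_size * H * W, dtype=torch.float32).reshape(batch_size, H, W)
h = torch.tensor([0, 2, 1, 2])  # row index for each sample (int64 by default)
w = torch.tensor([4, 0, 3, 1])  # column index for each sample

# Reference semantics: value_loop[b] = T[b, h[b], w[b]] for every b.
value_loop = torch.stack([T[b, h[b], w[b]] for b in range(batch_size)])
print(value_loop)
```

With this data T[b, i, j] = 15*b + 5*i + j, so the loop picks 4, 25, 38 and 56.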
h = h.long()
w = w.long()
mask = torch.zeros_like(T).bool()
mask[:,h, w] = 1
print(torch.sum(mask)) # should be equal to batch_size
value = torch.masked_select(T, mask)
But it is not working.
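A likely reason the mask attempt fails: in mask[:, h, w] the ':' keeps every batch row, while h and w are only paired with each other, so every row b gets all of the positions (h[i], w[i]) set. The sum then becomes batch_size times the number of distinct (h, w) pairs instead of batch_size. The sketch below (same hypothetical data as above) shows that effect and a corrected mask that also indexes the batch dimension with torch.arange.

```python
import torch

# Hypothetical example data, for illustration only.
batch_size, H, W = 4, 3, 5
T = torch.arange(batch_size * H * W, dtype=torch.float32).reshape(batch_size, H, W)
h = torch.tensor([0, 2, 1, 2])
w = torch.tensor([4, 0, 3, 1])

# Original attempt: ':' spans all rows, so each row gets every (h[i], w[i]).
bad_mask = torch.zeros_like(T, dtype=torch.bool)
bad_mask[:, h, w] = True
print(bad_mask.sum())  # batch_size * (number of distinct (h, w) pairs)

# Fix: index the batch dimension too, so position i of all three index
# tensors refers to the same sample.
batch_idx = torch.arange(batch_size, device=T.device)
mask = torch.zeros_like(T, dtype=torch.bool)
mask[batch_idx, h, w] = True

# masked_select reads the mask in row-major order; with exactly one True
# per batch row, the selected values come out in batch order.
value = torch.masked_select(T, mask)
```

Building the mask works, but it allocates a tensor as large as T just to pick batch_size values; indexing T directly avoids that.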
EDIT:
I found the solution to my problem; posting it in case anyone needs it!
value = T[torch.arange(T.shape[0]), h, w]  # value[b] = T[b, h[b], w[b]]
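Why this works: with advanced indexing, torch.arange(batch_size), h and w are broadcast together to shape [batch_size], so element b of the result is T[b, h[b], w[b]]. The sketch below checks this against the same hypothetical data and, as an equivalent alternative not from the original post, uses torch.gather on a flattened view, where (h, w) becomes the flat offset h * W + w. Both assume h and w are int64 (long) tensors.

```python
import torch

# Hypothetical example data, for illustration only.
batch_size, H, W = 4, 3, 5
T = torch.arange(batch_size * H * W, dtype=torch.float32).reshape(batch_size, H, W)
h = torch.tensor([0, 2, 1, 2])
w = torch.tensor([4, 0, 3, 1])

# Advanced indexing: the three index tensors are matched element-wise.
value = T[torch.arange(T.shape[0], device=T.device), h, w]

# Alternative with torch.gather: flatten each sample's [H, W] grid and
# convert (h, w) to the flat offset h * W + w.
flat_idx = (h * W + w).unsqueeze(1)  # shape [batch_size, 1]
value_gather = T.reshape(batch_size, -1).gather(1, flat_idx).squeeze(1)

print(value, value_gather)
```

The indexing version is the more readable one; the gather form can be handy when the lookup is already expressed as flat offsets.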
Thanks anyway, happy coding!