Filter a tensor based on another tensor

I want to filter a tensor based on another tensor. I have two tensors:
tensorA has dimensions [264, 2, 3] and tensorB has dimensions [264]. tensorB contains data like 0,1,0,0,0,1 … I want to filter tensorA on the basis of tensorB so that the output tensor has dimensions [264, 3]: of the two indices (0 and 1) along dim = 1, select either 0 or 1 based on tensorB.
In other words, select the index along dim = 1 of tensorA according to tensorB.
I tried using the gather function, but it says the input and index dimensions are not the same. I have also tried a for loop, but that slows my model down a lot, so I need a fast way that works well on the GPU.
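For reference, the selection described above is out[i] = tensorA[i, tensorB[i]]. A minimal loop sketch of that semantics (the slow approach mentioned in the question), assuming random stand-in data with the stated shapes; the names tensor_a and tensor_b are illustrative:

```python
import torch

# Random stand-ins for tensorA [264, 2, 3] and tensorB [264] (values 0 or 1)
tensor_a = torch.rand((264, 2, 3))
tensor_b = torch.randint(0, 2, (264,))

# Reference semantics: for each row i, keep tensor_a[i, tensor_b[i], :]
# (one Python iteration per row, which is why this is slow, especially on GPU)
out = torch.empty((tensor_a.size(0), tensor_a.size(2)))
for i in range(tensor_a.size(0)):
    out[i] = tensor_a[i, int(tensor_b[i])]

print(out.shape)  # torch.Size([264, 3])
```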

You have to expand the dims of the indexing tensor for gather, because gather expects the index to have the same number of dimensions as the input (at least I’m not aware of a way to circumvent that). This should work:

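import torch
a = torch.rand((264, 2, 3))
b = torch.randint(0, 2, (264,))
# Repeat b along a new last dim and reshape to [264, 1, 3] so it has the same rank as a
i = torch.stack(3*[b], dim=1).view(-1, 1, 3)
# Pick index b[n] along dim 1 for each row n, then drop that singleton dim -> [264, 3]
r = a.gather(1, i).squeeze(1)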
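A sketch of two variants that also avoid the loop, assuming b is an int64 tensor on the same device as a. The first builds the gather index with expand, which returns a broadcast view rather than copying b three times as torch.stack does; the second swaps gather for advanced indexing, pairing each row index from torch.arange with its entry in b:

```python
import torch

a = torch.rand((264, 2, 3))
b = torch.randint(0, 2, (264,))  # int64 by default

# Variant 1: gather with an expanded (non-copying) index of shape [264, 1, 3]
idx = b.view(-1, 1, 1).expand(-1, 1, a.size(2))
r_gather = a.gather(1, idx).squeeze(1)  # [264, 3]

# Variant 2: advanced indexing, r_index[n] = a[n, b[n]]
rows = torch.arange(a.size(0), device=a.device)
r_index = a[rows, b]  # [264, 3]

print(r_gather.shape, torch.equal(r_gather, r_index))  # torch.Size([264, 3]) True
```

Both run as single vectorized kernels on the GPU; advanced indexing is often the more readable choice when only one dimension is being selected.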