Select columns based on a binary 2D tensor

Consider that I have a matrix x = torch.randn(2,5)

I also have a binary matrix b = torch.Tensor([[1,0,1,0,1],[0,1,1,1,0]]) # assume that b.sum(-1) is the same for every row; namely, each row has the same number of ones

I’m wondering if there is an efficient way to get a matrix of shape NxD, where N is the number of rows in b and D is the number of ones in each row of b. Since we assume that each row of b has the same number of ones, the result is always rectangular, so this should not raise an error.

The returned matrix should have x[0,0], x[0,2], x[0,4] as its first row and x[1,1], x[1,2], x[1,3] as its second row, i.e. the entries of x at the positions where b is 1.

Note that x is the output of a neural network, and the returned matrix will be used in the loss function, so I still need to backpropagate through it. Any ideas?


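If you are sure every row contains the same number of nonzero elements D, you can do x[b.bool()].view(N, D). The mask has to be a boolean tensor, so convert b first. Mask indexing returns the selected elements as a 1D tensor in row-major order (all picks from row 0, then all picks from row 1), so reshaping directly to [N, D] gives the rows you want and no transpose is needed. The indexing is differentiable with respect to x, so you can backpropagate through it.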
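Here is a minimal sketch of the masked selection on the example from the question. It assumes every row of b has the same number of ones. The requires_grad=True tensor only stands in for the network output, and the names selected, expected and loss are made up for illustration. Boolean-mask indexing returns the picked elements in row-major order, which is why the result is reshaped straight to (N, D).

```python
import torch

# Stand-in for the network output; requires_grad lets us check backprop.
x = torch.randn(2, 5, requires_grad=True)

# Boolean mask; assumes each row has the same number of True entries.
b = torch.tensor([[1, 0, 1, 0, 1],
                  [0, 1, 1, 1, 0]], dtype=torch.bool)

N = b.shape[0]
D = int(b[0].sum())  # number of ones per row (assumed equal across rows)

# x[b] is 1D, ordered row by row, so a direct reshape to (N, D) is correct.
selected = x[b].view(N, D)

# Cross-check against explicit per-row indexing.
expected = torch.stack([x[0, [0, 2, 4]], x[1, [1, 2, 3]]])
assert torch.equal(selected, expected)

# Gradients flow back only to the selected positions of x.
loss = selected.sum()
loss.backward()
assert torch.equal(x.grad, b.to(x.dtype))
```

If the rows of b could contain different numbers of ones, the view call would fail or silently mix rows, so it may be worth asserting that b.sum(-1) is constant before reshaping.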