How do I get the indices along dimension 0 of a [batch_size, 3] matrix whose rows match a [1, 3] vector?
E.g.
Matrix A = [[1,2,3],
[4,5,6],
[1,2,3]]
function(A, [1,2,3]) -> indices [0,2]
Hi @whoab,
there is most likely a better solution out there, but this should do the job. It requires every element of a row to match, not just one:
(a == torch.tensor([1, 2, 3])).all(dim=1).nonzero().squeeze(1)
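Here is a minimal, self-contained sketch of the row-matching idea. It compares the vector against every row through broadcasting and then keeps only the rows where all three elements are equal (`.all(dim=1)`). The helper name `matching_row_indices` is hypothetical and only used for illustration. The extra row `[1, 5, 6]` is also just an illustration: it matches the vector in its first element only.

```python
import torch


def matching_row_indices(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: indices i along dim 0 where a[i] equals v in every column."""
    # v may have shape [3] or [1, 3]; broadcasting compares it against each row of a.
    eq = a == v                     # bool tensor, shape [batch_size, 3]
    full_match = eq.all(dim=1)      # True only where all three elements match
    # nonzero(..., as_tuple=True) on a 1-D tensor returns a 1-tuple of index tensors.
    return torch.nonzero(full_match, as_tuple=True)[0]


A = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [1, 2, 3],
                  [1, 5, 6]])       # illustrative partial match: only the first element is 1
v = torch.tensor([[1, 2, 3]])      # shape [1, 3], as in the question

print(matching_row_indices(A, v).tolist())  # [0, 2]
```

Taking `.nonzero()` of the raw comparison without `.all(dim=1)` would also return row 3 for this `A`, because a single matching element is enough to produce a nonzero entry in that row.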
I hope I could help you.
Regards,
Unity05