How to get the indices of a [batch_size, 3] matrix in dimension 0 that match a [1,3] vector


E.g.
Matrix A = [[1,2,3],
            [4,5,6],
            [1,2,3]]

function(A, [1,2,3]) -> indices [0,2]

Hi @whoab,

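Most likely there is a better solution somewhere out there, but this should do the job. Compare every row of A with the vector, keep only the rows where all three elements match, and take their indices:

(A == torch.tensor([1, 2, 3])).all(dim=1).nonzero()[:, 0]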
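A minimal, self-contained sketch of the row-matching idea, wrapped in a hypothetical helper match_rows (the name and signature are illustrative, not a PyTorch API). It assumes a is a 2-D tensor of shape [batch_size, n] and v holds n values, given either as shape [n] or [1, n]. The .all(dim=1) step is what excludes partial matches: taking the nonzero indices of the element-wise comparison alone would also return a row such as [1, 9, 9].

```python
import torch


def match_rows(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Return the dim-0 indices of rows in `a` that equal `v` element-wise.

    Hypothetical helper: `a` has shape [batch_size, n], `v` has shape [n]
    or [1, n] and is broadcast against every row of `a`.
    """
    v = v.reshape(1, -1)       # normalise [n] and [1, n] to [1, n]
    eq = a == v                # [batch_size, n] boolean comparison
    row_mask = eq.all(dim=1)   # True only if every element of the row matches
    # nonzero(as_tuple=True) returns one index tensor per dimension;
    # the mask is 1-D, so the first (and only) entry holds the row indices.
    return row_mask.nonzero(as_tuple=True)[0]


A = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [1, 2, 3]])
print(match_rows(A, torch.tensor([1, 2, 3])))  # tensor([0, 2])

# A partial match such as [1, 9, 9] is not returned.
B = torch.tensor([[1, 9, 9],
                  [1, 2, 3]])
print(match_rows(B, torch.tensor([[1, 2, 3]])))  # tensor([1])
```

For floating-point tensors, exact equality can miss values that differ only by rounding; a tolerance-based variant would replace a == v with torch.isclose(a, v), with v cast to the same dtype as a.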

I hope this helps.

Regards,
Unity05