How to find the index of a list of tensors?

A = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
B = torch.tensor([[7,8,9],[4,5,6]])
some_function(A, B) -> torch.tensor([2, 1])

What I want is to get, for each row of B, the index of the matching row in A.
The output is [2, 1] because [7,8,9] is at index 2 of A and [4,5,6] is at index 1.
It looks quite simple, but I am finding it very hard to do. Can anyone help me, please?

Hi Minhyuk!

Try torch.where((A == B.unsqueeze(1)).prod(dim=2))[1].
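
For anyone who wants to run this step by step, here is a minimal sketch using the A and B from the original post. It swaps in .all(dim=2) for .prod(dim=2) (both reduce the last dimension to "every element matched"), and the names eq, row_match, and indices are just illustrative. It assumes each row of B occurs exactly once in A.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
B = torch.tensor([[7, 8, 9], [4, 5, 6]])

# B.unsqueeze(1) has shape (2, 1, 3); comparing it with A (4, 3)
# broadcasts to an elementwise result of shape (2, 4, 3).
eq = A == B.unsqueeze(1)

# A row of A matches a row of B only if all three elements are equal.
row_match = eq.all(dim=2)  # shape (2, 4)

# torch.where with a single argument returns one index tensor per dimension;
# element [1] holds the row indices into A.
indices = torch.where(row_match)[1]
print(indices)  # tensor([2, 1])
```

If a row of B is missing from A, or appears more than once, the result will have fewer or more entries than B has rows, so it is worth checking the length if that can happen in your data.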

Best.

K. Frank

A == B.unsqueeze(1) looks magical. Thanks!

Could you give me some insight into how two tensors of shape (4,3) and (2,1,3) can be compared with ==?

Hi Minhyuk!

The shapes of the tensors in your original post were [4, 3] and [2, 3].
So you can just squeeze() away the singleton dimension in your second
tensor and do the same thing that you did before.

Best.

K. Frank

It gives an error when I compare two tensors of shape (4,3) and (2,3):

RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

It definitely seems like unsqueezing the first dimension does something.
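
For what it's worth, this looks like ordinary broadcasting: shapes are aligned from the right, and a dimension of size 1 (or a missing one) is stretched to match the other tensor. Here is a rough sketch of both cases, using the same data as the original post; the names failed, eq, and explicit are only for illustration.

```python
import torch

A = torch.arange(1, 13).reshape(4, 3)       # shape (4, 3)
B = torch.tensor([[7, 8, 9], [4, 5, 6]])    # shape (2, 3)

# (4, 3) vs (2, 3): dimension 0 is 4 vs 2 and neither is 1,
# so the shapes cannot be broadcast and == raises.
try:
    A == B
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# (4, 3) is treated as (1, 4, 3); against (2, 1, 3) each size-1 dimension
# is stretched, giving a result of shape (2, 4, 3) where
# eq[i, j, k] is (A[j, k] == B[i, k]).
eq = A == B.unsqueeze(1)
print(eq.shape)  # torch.Size([2, 4, 3])

# The same comparison with the stretching written out explicitly.
explicit = A.unsqueeze(0).expand(2, 4, 3) == B.unsqueeze(1).expand(2, 4, 3)
print(torch.equal(eq, explicit))  # True
```

So the unsqueeze is what makes room for the "every row of B against every row of A" comparison; without it the two row counts collide in the same dimension.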