import torch
A = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
B = torch.tensor([[7,8,9],[4,5,6]])
some_function(A, B)  # should return tensor([2, 1])
For each row of B, I want to get the index of the matching row in A.
The output should be [2, 1] because [7,8,9] is row 2 of A and [4,5,6] is row 1 of A.
It looks simple, but I can't find a way to do it. Can anyone help?
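One possible way to write `some_function` (a sketch, assuming every row of B appears exactly once in A) is to compare all rows of B against all rows of A with broadcasting, then read off the column index of each match:

```python
import torch

def some_function(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Broadcast B[:, None, :] (shape [len(B), 1, D]) against A[None, :, :]
    # (shape [1, len(A), D]); the elementwise comparison has shape [len(B), len(A), D].
    # all(dim=2) marks the (B row, A row) pairs where every element matches.
    matches = (B[:, None, :] == A[None, :, :]).all(dim=2)  # shape [len(B), len(A)]
    # nonzero() returns (row_of_B, row_of_A) pairs in row-major order,
    # so column 1 holds the index into A for each row of B, in B's order.
    return matches.nonzero()[:, 1]

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
B = torch.tensor([[7, 8, 9], [4, 5, 6]])
print(some_function(A, B))  # tensor([2, 1])
```

If a row of B is missing from A it is silently skipped, and if it appears several times in A you get several indices, so check `matches.sum(dim=1)` first if that can happen.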
The tensors in your original post have shapes [4, 3] and [2, 3].
So you can just squeeze() away the singleton dimension in your second
tensor and then do the same thing you did before.
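A minimal sketch of that step, assuming (hypothetically) that the second tensor has shape [2, 1, 3]. It squeezes only dim 1, since a bare `squeeze()` would also drop the row dimension if B happened to contain a single row:

```python
import torch

def some_function(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Match each row of B against every row of A, then keep the A-side index.
    matches = (B[:, None, :] == A[None, :, :]).all(dim=2)
    return matches.nonzero()[:, 1]

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
# Hypothetical second tensor with an extra singleton dimension: shape [2, 1, 3]
B = torch.tensor([[[7, 8, 9]], [[4, 5, 6]]])
B = B.squeeze(1)  # now shape [2, 3], same as in the original post
print(some_function(A, B))  # tensor([2, 1])
```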