Match tensor values

Hey everybody,

I have a problem with two tensors. I need to compare two tensors with shapes ten_A = (342, 2) and ten_B = (2). ten_A contains rows like [0,1], [0,2], [0,3], … and ten_B is [0,1]. How can I compare the values and get the index of the matching row in ten_A?

Thanks for your help.

If you want to compare each row of ten_A to ten_B and get the row index for a match, this code should work:

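import torch

a = torch.randint(0, 4, (10, 2))  # 10 random rows with values in [0, 3]
b = torch.tensor([0, 1])          # the row to search for
b = b.unsqueeze(0)                # shape (1, 2), broadcasts against a

# Compare row-wise; nonzero() returns the indices of rows where both values match
print((a == b).all(dim=1).nonzero())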
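Since the snippet above uses random data, a match is not guaranteed on every run. Here is a small sketch with the shapes from the question. The contents of ten_A are made up for illustration (rows [0,0], [0,1], [0,2], ...), so treat them as an assumption. Note that nonzero() returns a tensor of shape (N, 1) by default; passing as_tuple=True and taking the first element should give a flat 1-D index tensor instead.

```python
import torch

# Hypothetical stand-in for ten_A: rows [0, 0], [0, 1], ..., [0, 341], shape (342, 2)
ten_A = torch.stack(
    [torch.zeros(342, dtype=torch.long), torch.arange(342)], dim=1
)
ten_B = torch.tensor([0, 1])  # shape (2,)

# ten_B broadcasts from (2,) against (342, 2), so unsqueeze is optional here
row_matches = (ten_A == ten_B).all(dim=1)      # bool mask, shape (342,)
idx = row_matches.nonzero(as_tuple=True)[0]    # 1-D tensor of matching row indices

print(idx)  # tensor([1])
```

If ten_A can contain the same row more than once, idx will hold every matching row index, and it will be empty when nothing matches.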