Compare two tensors and return total number of matches and the matching values

Hi all,

I have two 3D tensors of the same size, each populated with 4 possible values (0, 1, 2, 3).

I want to compare these two tensors element-wise and return the total number of positions where both tensors have 0, the total number of positions where both have 1, and so on up to 3.

Any efficient way of doing this? Thank you!

Does it need to be differentiable?

Otherwise:

import torch

a = torch.randint(0, 4, [2, 3, 4])
b = torch.randint(0, 4, [2, 3, 4])
print(a)
print(b)

# Keep the values where a and b agree, then count each distinct value
print(torch.unique(a[a == b], return_counts=True))

Output:

tensor([[[2, 1, 2, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 1]],

        [[1, 3, 2, 3],
         [3, 3, 0, 0],
         [3, 1, 3, 3]]])

tensor([[[0, 3, 1, 3],
         [2, 0, 3, 3],
         [2, 0, 3, 1]],

        [[2, 2, 2, 1],
         [0, 2, 3, 0],
         [1, 0, 0, 2]]])

(tensor([0, 1, 2]), tensor([3, 1, 1]))
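If you need this more than once, the one-liner can be wrapped in a small helper that returns a plain dict. This is a minimal sketch: `count_matches` is a hypothetical name, and the tensors are the ones printed in the output above, typed in by hand so the result is reproducible.

```python
import torch
from typing import Dict

def count_matches(a: torch.Tensor, b: torch.Tensor) -> Dict[int, int]:
    """Map each value found at matching positions of a and b to its count."""
    # Keep only the elements where both tensors agree (result is 1D).
    matched = a[a == b]
    # torch.unique returns the sorted distinct values and how often each occurs.
    values, counts = torch.unique(matched, return_counts=True)
    return {int(v): int(c) for v, c in zip(values, counts)}

# The two tensors printed in the example above, written out explicitly.
a = torch.tensor([[[2, 1, 2, 1],
                   [0, 0, 0, 0],
                   [0, 0, 0, 1]],
                  [[1, 3, 2, 3],
                   [3, 3, 0, 0],
                   [3, 1, 3, 3]]])
b = torch.tensor([[[0, 3, 1, 3],
                   [2, 0, 3, 3],
                   [2, 0, 3, 1]],
                  [[2, 2, 2, 1],
                   [0, 2, 3, 0],
                   [1, 0, 0, 2]]])

print(count_matches(a, b))  # {0: 3, 1: 1, 2: 1}
```

As with the one-liner, a value that never matches (3 here) simply does not appear as a key.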

Thank you so much. This is what I needed. Just a quick question: can you please tell me what a[a == b] does? I have never seen it before.

Also, in your example there is no match for 3, so it would not return a value for that?

Thanks again!

We can divide it into 2 steps:

a == b returns a boolean tensor (a mask) that is True wherever a and b have the same value. With PyTorch, this comparison is performed element-wise: it checks every position (each channel, row, and column), and the resulting mask has the same shape as a and b.
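A tiny hand-made pair of tensors makes the mask easier to inspect. In this sketch the 2x2 values are arbitrary, chosen only for illustration; it also shows that summing the mask gives the total number of matching positions, regardless of which value matched.

```python
import torch

# Two small tensors chosen by hand for illustration.
a = torch.tensor([[0, 1], [2, 3]])
b = torch.tensor([[0, 2], [2, 1]])

mask = a == b                  # True where a and b agree
print(mask)
print(mask.dtype)              # torch.bool
print(mask.shape == a.shape)   # True

# Summing a boolean tensor counts its True entries, i.e. the total number of matches.
print(int(mask.sum()))         # 2
```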

Once we have this mask, we can use it to select items from the original tensors. It doesn't matter whether we write a[a == b] or b[a == b], since we only extract the values at positions where a and b are the same. The result is a flattened 1D tensor with one value per True entry in the mask.
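A quick way to convince yourself that a[mask] and b[mask] are interchangeable is to check it directly. The seed below is an arbitrary assumption added only so reruns are repeatable; the checks should hold for any random tensors.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, fixed only so reruns are repeatable
a = torch.randint(0, 4, [2, 3, 4])
b = torch.randint(0, 4, [2, 3, 4])
mask = a == b

from_a = a[mask]
from_b = b[mask]

# Boolean indexing returns a flat 1D tensor with one entry per True in the mask,
# taken in row-major order.
print(from_a.dim())                          # 1
print(from_a.numel() == int(mask.sum()))     # True
# At the selected positions a and b are equal by construction.
print(torch.equal(from_a, from_b))           # True
```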

We can see how this works in practice:

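import torch

a = torch.randint(0, 4, [2, 3, 4])
b = torch.randint(0, 4, [2, 3, 4])
print(a)
print(b)

# Boolean tensor: True wherever a and b hold the same value
mask = a == b
print(mask)

# Boolean indexing: pick the elements of a where mask is True (returns a 1D tensor)
extracted_values = a[mask]
print(extracted_values)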

Output:

# a
tensor([[[0, 1, 2, 3],
         [3, 3, 2, 2],
         [1, 0, 0, 0]],

        [[1, 2, 1, 1],
         [3, 2, 3, 3],
         [3, 2, 0, 3]]])

# b
tensor([[[3, 3, 0, 0],
         [1, 2, 0, 3],
         [3, 0, 3, 0]],

        [[1, 0, 1, 3],
         [3, 0, 1, 1],
         [2, 2, 2, 1]]])

# mask
tensor([[[False, False, False, False],
         [False, False, False, False],
         [False,  True, False,  True]],

        [[ True, False,  True, False],
         [ True, False, False, False],
         [False,  True, False, False]]])

# extracted_values
tensor([0, 0, 1, 1, 3, 2])

We then count how many times each value appears in extracted_values:

values, counts = torch.unique(extracted_values, return_counts=True)
print(values)
print(counts)

Output:

tensor([0, 1, 2, 3])
tensor([2, 2, 1, 1])

So for these new random tensors, a and b have two 0s in the same positions, two 1s in the same positions, one 2 in the same position, and one 3 in the same position.
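The original question asked for the number of positions where both tensors hold 0, where both hold 1, and so on. A more literal way to express that is to AND two per-value masks. It is less efficient (one pass over the tensors per value), but it reports 0 for values that never match. This sketch reuses the second pair of tensors printed above and assumes the values are limited to 0..3.

```python
import torch

# The second pair of tensors printed above, written out explicitly.
a = torch.tensor([[[0, 1, 2, 3],
                   [3, 3, 2, 2],
                   [1, 0, 0, 0]],
                  [[1, 2, 1, 1],
                   [3, 2, 3, 3],
                   [3, 2, 0, 3]]])
b = torch.tensor([[[3, 3, 0, 0],
                   [1, 2, 0, 3],
                   [3, 0, 3, 0]],
                  [[1, 0, 1, 3],
                   [3, 0, 1, 1],
                   [2, 2, 2, 1]]])

# For each value v, build a mask that is True only where BOTH tensors equal v,
# then count its True entries.
counts = {v: int(((a == v) & (b == v)).sum()) for v in range(4)}
print(counts)  # {0: 2, 1: 2, 2: 1, 3: 1}
```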

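Exactly, that is how torch.unique() works: it only returns values that actually occur in its input, so a value with no matches (like 3 in the first example) does not appear in the output at all.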
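If you want an entry for every value, including those with zero matches, one option is torch.bincount with minlength. This is a sketch assuming the values are non-negative integers below 4; it reuses the first pair of tensors from this thread, where 3 never matches.

```python
import torch

# The first pair of tensors from this thread, written out explicitly.
a = torch.tensor([[[2, 1, 2, 1],
                   [0, 0, 0, 0],
                   [0, 0, 0, 1]],
                  [[1, 3, 2, 3],
                   [3, 3, 0, 0],
                   [3, 1, 3, 3]]])
b = torch.tensor([[[0, 3, 1, 3],
                   [2, 0, 3, 3],
                   [2, 0, 3, 1]],
                  [[2, 2, 2, 1],
                   [0, 2, 3, 0],
                   [1, 0, 0, 2]]])

matched = a[a == b]
# bincount counts occurrences of each non-negative integer; minlength=4 ensures
# the result has an entry for every value 0..3, even ones that never match.
counts = torch.bincount(matched, minlength=4)
print(counts)  # tensor([3, 1, 1, 0])
```

Here index i of the result is the count for value i, so there is no separate tensor of values to keep track of.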


Thank you so much. This is so helpful. Being able to create a mask with a == b and then pass it in to extract values from tensors is super cool of PyTorch. Thanks again!!!