Compare two tensors and return total number of matches and the matching values

Hi all,

I have two 3D tensors of the same size, each populated with four possible values (0, 1, 2, 3).

I want to compare these two tensors element-wise and return, for each value from 0 to 3, the total number of positions where both tensors hold that value: the number of positions where both are 0, the number where both are 1, and so on up to 3.

Any efficient way of doing this? Thank you!

Does it need to be differentiable?

Otherwise:

import torch

a = torch.randint(0, 4, [2,3,4])
b = torch.randint(0, 4, [2,3,4])
print(a)
print(b)

print(torch.unique(a[a == b], return_counts=True))

output:

tensor([[[2, 1, 2, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 1]],

        [[1, 3, 2, 3],
         [3, 3, 0, 0],
         [3, 1, 3, 3]]])

tensor([[[0, 3, 1, 3],
         [2, 0, 3, 3],
         [2, 0, 3, 1]],

        [[2, 2, 2, 1],
         [0, 2, 3, 0],
         [1, 0, 0, 2]]])

(tensor([0, 1, 2]), tensor([3, 1, 1]))

Thank you so much. This is what I needed. Just a quick question: can you please tell me what a[a == b] does? I have never seen it before.

Also, in your example there is no match for 3. Does that mean it would not return a count for that value?

Thanks again!

We can divide it into 2 steps:

a == b returns a boolean tensor (a mask) whose values are True wherever a and b have the same value. Conveniently, PyTorch performs this comparison element-wise, so it checks every item across every channel, row, and column.

Once we have this mask, we can use it to select items from the original tensors. It doesn’t matter if we were to do a[a == b] or b[a == b] since we only extract the values where a and b are the same.

We can see how this works in practice:

import torch

a = torch.randint(0, 4, [2,3,4])
b = torch.randint(0, 4, [2,3,4])
print(a)
print(b)

mask = a == b
print(mask)

extracted_values = a[mask]
print(extracted_values)

Output:

# a
tensor([[[0, 1, 2, 3],
         [3, 3, 2, 2],
         [1, 0, 0, 0]],

        [[1, 2, 1, 1],
         [3, 2, 3, 3],
         [3, 2, 0, 3]]])

# b
tensor([[[3, 3, 0, 0],
         [1, 2, 0, 3],
         [3, 0, 3, 0]],

        [[1, 0, 1, 3],
         [3, 0, 1, 1],
         [2, 2, 2, 1]]])

# mask
tensor([[[False, False, False, False],
         [False, False, False, False],
         [False,  True, False,  True]],

        [[ True, False,  True, False],
         [ True, False, False, False],
         [False,  True, False, False]]])

# extracted_values
tensor([0, 0, 1, 1, 3, 2])
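
As a quick check of the point that a[mask] and b[mask] are interchangeable, here is a minimal sketch with small hand-written tensors. The values are made up for illustration and are not the random tensors printed above.

```python
import torch

# Small hand-written tensors (illustrative values, not from the thread)
a = torch.tensor([[0, 1, 2], [3, 3, 0]])
b = torch.tensor([[0, 2, 2], [1, 3, 0]])

mask = a == b        # boolean tensor with the same shape as a and b
from_a = a[mask]     # 1-D tensor of the values of a where mask is True
from_b = b[mask]     # same positions, read from b instead

print(mask)
print(from_a)                        # tensor([0, 2, 3, 0])
print(torch.equal(from_a, from_b))   # True
```

Note that boolean indexing flattens the result: the selected values come back as a 1-D tensor in row-major order, regardless of the shape of the inputs.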

We then count how many times each value appears in extracted_values:

values, counts = torch.unique(extracted_values, return_counts=True)
print(values)
print(counts)

Output:

tensor([0, 1, 2, 3])
tensor([2, 2, 1, 1])

So we can see that for these new random tensors, a and b have two 0s in the same position, two 1s in the same position, one 2 in the same position, and one 3 in the same position.

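And yes, exactly: if no 3s match, you get no entry for 3. That is how torch.unique() works, since it only returns values that actually occur in its input.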
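
If you would rather always get one count per possible value, including a 0 for values that never match, one option is torch.bincount with minlength. This is only a sketch and not part of the answer above. It assumes the tensors hold non-negative integers in the range 0..3, as in the question, and num_values is a name chosen here for illustration. The tensors are copied from the first example in this thread, where no 3s match.

```python
import torch

# Tensors copied from the first example in this thread, where no 3s match
a = torch.tensor([[[2, 1, 2, 1], [0, 0, 0, 0], [0, 0, 0, 1]],
                  [[1, 3, 2, 3], [3, 3, 0, 0], [3, 1, 3, 3]]])
b = torch.tensor([[[0, 3, 1, 3], [2, 0, 3, 3], [2, 0, 3, 1]],
                  [[2, 2, 2, 1], [0, 2, 3, 0], [1, 0, 0, 2]]])

num_values = 4  # assumption: values are integers in 0..3, as in the question

matched = a[a == b]  # 1-D integer tensor of the matching values
# bincount needs non-negative integers; minlength pads missing values with 0
counts = torch.bincount(matched, minlength=num_values)
print(counts)  # tensor([3, 1, 1, 0])
```

Here counts[v] is the number of positions where both tensors equal v, so the result always has four entries and can be indexed directly by value.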

Thank you so much. This is so helpful. It is super cool that in PyTorch a == b creates a mask that we can then use to extract values from the tensors. Thanks again!