We can divide it into 2 steps:
First, a == b returns a boolean tensor (a mask) whose values are True wherever a and b have the same value. The good thing is that PyTorch performs this comparison element-wise, so it checks each item at each channel, row, and column.
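To make the element-wise behaviour concrete, here is a minimal sketch on two small hand-written tensors (x, y and mask_xy are illustrative names, not part of the original example). It checks that the mask keeps the shape of its inputs, has dtype torch.bool, and that the == operator gives the same result as torch.eq:

```python
import torch

# Small hand-written tensors so the result is deterministic (illustrative names).
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[1, 0, 3], [0, 5, 0]])

mask_xy = x == y  # element-wise comparison, equivalent to torch.eq(x, y)

print(mask_xy)
print(mask_xy.shape == x.shape)               # the mask keeps the input shape
print(mask_xy.dtype)                          # torch.bool
print(torch.equal(mask_xy, torch.eq(x, y)))   # operator and function agree
```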
Once we have this mask, we can use it to select items from the original tensors. It doesn't matter whether we use a[a == b] or b[a == b], since we only extract the values where a and b are the same.
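A quick way to convince yourself of this is to index both tensors with the same mask and compare the results. The sketch below is only a sanity check: the fixed seed is an assumption added for reproducibility, and from_a and from_b are illustrative names.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so repeated runs give the same tensors
a = torch.randint(0, 4, [2, 3, 4])
b = torch.randint(0, 4, [2, 3, 4])

mask = a == b

# Boolean indexing returns a flat 1-D tensor of the selected items.
from_a = a[mask]
from_b = b[mask]

print(torch.equal(from_a, from_b))        # both selections hold the same values
print(from_a.ndim)                        # 1
print(from_a.numel() == int(mask.sum()))  # one item per True entry in the mask
```

Note that the result is always 1-D, whatever the shape of a and b, so the positions of the matches are not preserved in the extracted values.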
We can see how this works in practice:
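import torch

# Two random integer tensors with values in [0, 4) and shape [2, 3, 4]
a = torch.randint(0, 4, [2, 3, 4])
b = torch.randint(0, 4, [2, 3, 4])
print(a)
print(b)
# Step 1: boolean mask, True where a and b hold the same value
mask = a == b
print(mask)
# Step 2: use the mask to pull the matching values out of a
extracted_values = a[mask]
print(extracted_values)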
Output:
# a
tensor([[[0, 1, 2, 3],
         [3, 3, 2, 2],
         [1, 0, 0, 0]],

        [[1, 2, 1, 1],
         [3, 2, 3, 3],
         [3, 2, 0, 3]]])
# b
tensor([[[3, 3, 0, 0],
         [1, 2, 0, 3],
         [3, 0, 3, 0]],

        [[1, 0, 1, 3],
         [3, 0, 1, 1],
         [2, 2, 2, 1]]])
# mask
tensor([[[False, False, False, False],
         [False, False, False, False],
         [False,  True, False,  True]],

        [[ True, False,  True, False],
         [ True, False, False, False],
         [False,  True, False, False]]])
# extracted_values
tensor([0, 0, 1, 1, 3, 2])
We then count how many times each value appears in extracted_values:
values, counts = torch.unique(extracted_values, return_counts=True)
print(values)
print(counts)
Output:
# values
tensor([0, 1, 2, 3])
# counts
tensor([2, 2, 1, 1])
So we can see that for these new random tensors, a and b have two 0s in the same position, two 1s in the same position, one 2 in the same position, and one 3 in the same position.
Exactly, that is how torch.unique() works.
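Putting the two steps together, one possible way to package them is a small helper that returns the counts as a plain dictionary. This is only a sketch: count_matching_values is a hypothetical name, not a PyTorch function, and the tensors are the ones from the printed output above, written out by hand so the result is reproducible.

```python
import torch

def count_matching_values(a, b):
    """Hypothetical helper: count values that a and b share at the same position."""
    mask = a == b  # step 1: element-wise comparison
    # step 2: select the matching values and count each distinct one
    values, counts = torch.unique(a[mask], return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))

# The tensors shown in the output above, entered by hand.
a = torch.tensor([[[0, 1, 2, 3], [3, 3, 2, 2], [1, 0, 0, 0]],
                  [[1, 2, 1, 1], [3, 2, 3, 3], [3, 2, 0, 3]]])
b = torch.tensor([[[3, 3, 0, 0], [1, 2, 0, 3], [3, 0, 3, 0]],
                  [[1, 0, 1, 3], [3, 0, 1, 1], [2, 2, 2, 1]]])

result = count_matching_values(a, b)
print(result)  # {0: 2, 1: 2, 2: 1, 3: 1}
```

Returning a dictionary is just a convenience for reading the result; if you need to stay on the GPU or keep working with tensors, use values and counts directly.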