To reason about this, we can start with the 1-D case. If `a = [2, 4]` and `b = [1, 2, 3]`, our output should be `[True, False]`. Intuitively, we want to build a 3 by 2 table (one row per element of `b`, one column per element of `a`) so that we can compute all the pairwise equalities at once.
To do this we can use the `Tensor.expand` method: `a.expand(3, 2)` gives us `[[2, 4], [2, 4], [2, 4]]`, and `b.expand(2, 3)` gives us `[[1, 2, 3], [1, 2, 3]]`.
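A quick way to confirm these two outputs (a minimal sketch; it also shows that `expand` returns a view, so the repeated rows share memory with the original tensor instead of being copied):

```python
import torch

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])

# expand() repeats a tensor along new leading (or existing size-1) dimensions.
# It returns a view: nothing is copied, the repeated dimension gets stride 0.
a_rows = a.expand(3, 2)  # 3 rows, each equal to a
b_rows = b.expand(2, 3)  # 2 rows, each equal to b

print(a_rows.tolist())   # [[2, 4], [2, 4], [2, 4]]
print(b_rows.tolist())   # [[1, 2, 3], [1, 2, 3]]
print(a_rows.stride())   # (0, 1)
```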
Note that these two results have shapes (3, 2) and (2, 3), so simply transposing the expanded `b` gives both the same shape: `a1 = [[2, 4], [2, 4], [2, 4]]` and `b1 = [[1, 1], [2, 2], [3, 3]]`.
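The transpose step can be checked the same way. The `unsqueeze(1).expand(...)` line is only an equivalent alternative added here for comparison; the approach itself does not need it:

```python
import torch

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])

a1 = a.expand(3, 2)    # shape (3, 2): every row is a
b1 = b.expand(2, 3).T  # shape (3, 2): row j is b[j] repeated
# An equivalent way to get b1 without transposing:
# add a trailing size-1 axis, then expand that axis.
b1_alt = b.unsqueeze(1).expand(3, 2)

print(a1.shape, b1.shape)       # torch.Size([3, 2]) torch.Size([3, 2])
print(b1.tolist())              # [[1, 1], [2, 2], [3, 3]]
print(torch.equal(b1, b1_alt))  # True
```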
If we now compute `a1 == b1`, which gives `[[False, False], [True, False], [False, False]]`, we are computing the equality of every pair in the Cartesian product of `a` and `b`. Taking the max of this tensor along `dim=0` (for booleans this is a logical OR, i.e. `.any(dim=0)`) gives `[True, False]`, as desired.
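Putting the 1-D steps together gives a small function. The name `isin_1d` is hypothetical, and `.any(dim=0)` stands in for the boolean max. For this 1-D case PyTorch also provides `torch.isin` (1.10 and later), which makes a convenient cross-check:

```python
import torch

def isin_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """For each element of 1-D tensor a, report whether it occurs in 1-D tensor b."""
    n, m = a.numel(), b.numel()
    a1 = a.expand(m, n)      # (m, n): every row is a
    b1 = b.expand(n, m).T    # (m, n): row j is b[j] repeated
    table = a1 == b1         # table[j, i] is True iff b[j] == a[i]
    return table.any(dim=0)  # boolean max down each column -> shape (n,)

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])
print(isin_1d(a, b).tolist())     # [True, False]
print(torch.isin(a, b).tolist())  # [True, False]
```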
Now let's generalize to tensors with more than one dimension. If you have tensors `a` and `b`, then for `a in b` along one dimension to make sense, their shapes must match along every other dimension. For example, if we want to compute `a in b` along `dim=1`, their shapes could be `(x, y, z)` and `(x, w, z)`. Working backwards, we know that our output needs to be one-dimensional with shape `(y,)`. So the idea is to compute a table with y rows and w columns, where entry (i, j) is True when the slice `a[:, i, :]` equals the slice `b[:, j, :]` in every element; the ith output is then the max of the ith row.
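Before vectorizing, a plain double loop makes the y-by-w table explicit. This is a slow reference sketch (the name `isin_slices_loop` and the example tensors are made up for illustration), mainly useful for checking a faster version. A slice only counts as a match if all of its elements agree, which is why the second slice of `a` below, which partly overlaps both slices of `b`, comes out False:

```python
import torch

def isin_slices_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Slow reference: out[i] is True iff a[:, i, :] equals some b[:, j, :] exactly.

    Assumes a has shape (x, y, z) and b has shape (x, w, z).
    """
    y, w = a.shape[1], b.shape[1]
    table = torch.zeros(y, w, dtype=torch.bool)  # y rows (slices of a), w columns (slices of b)
    for i in range(y):
        for j in range(w):
            # torch.equal is True only if shapes and all elements match
            table[i, j] = torch.equal(a[:, i, :], b[:, j, :])
    return table.any(dim=1)  # max (logical OR) of each row -> shape (y,)

# Hypothetical example: x = 2, y = 3, w = 2, z = 2
a = torch.tensor([[[2, 3], [2, 3], [0, 1]],
                  [[6, 7], [4, 5], [4, 5]]])  # shape (2, 3, 2)
b = torch.tensor([[[0, 1], [2, 3]],
                  [[4, 5], [6, 7]]])          # shape (2, 2, 2)
print(isin_slices_loop(a, b).tolist())  # [True, False, True]
```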
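Some untested code based on the intuition above follows (`.all()` and `.any()` play the roles of min and max on boolean tensors, since `.min()` does not accept a tuple of dims):

```python
import torch

x, y, z = a.shape   # a: (x, y, z)
w = b.shape[1]      # b: (x, w, z)
a1 = a.expand(w, x, y, z)                      # a1[j, :, i, :] is a[:, i, :]
b1 = b.expand(y, x, w, z).permute(2, 1, 0, 3)  # b1[j, :, i, :] is b[:, j, :]
t = (a1 == b1).all(dim=3).all(dim=1)           # shape: (w, y); True where slices fully match
out = t.any(dim=0)                             # shape: (y, )
```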
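The same table can also be built with plain broadcasting instead of `expand` plus `permute`: inserting a size-1 axis into each tensor with `unsqueeze` lets `==` broadcast straight to shape `(x, y, w, z)`. This is a swapped-in variant rather than the code above, sketched under the same shape assumptions; `isin_slices` is a hypothetical name:

```python
import torch

def isin_slices(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """out[i] is True iff a[:, i, :] equals some b[:, j, :] exactly.

    Assumes a has shape (x, y, z) and b has shape (x, w, z).
    """
    # (x, y, 1, z) == (x, 1, w, z) broadcasts to (x, y, w, z)
    eq = a.unsqueeze(2) == b.unsqueeze(1)
    # A slice pair matches only if every element matches: AND over z, then over x.
    table = eq.all(dim=3).all(dim=0)  # shape (y, w)
    return table.any(dim=1)           # shape (y,)

# Hypothetical example: x = 2, y = 3, w = 2, z = 2
a = torch.tensor([[[2, 3], [2, 3], [0, 1]],
                  [[6, 7], [4, 5], [4, 5]]])
b = torch.tensor([[[0, 1], [2, 3]],
                  [[4, 5], [6, 7]]])
print(isin_slices(a, b).tolist())  # [True, False, True]
```

Both versions materialize an x * y * w * z boolean tensor for the comparison, so memory grows with the product of all four sizes.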