To reason about this we can look at the 1-D case. With
a = [2, 4] and
b = [1, 2, 3], our output should be
[True, False], since 2 appears in b but 4 does not. Intuitively we want to create some kind of table so we can compute all the pairwise equalities at once.
To do this we can use Tensor.expand:
a.expand(3, 2) gives us
[[2, 4], [2, 4], [2, 4]]
b.expand(2, 3) gives us
[[1, 2, 3], [1, 2, 3]]
Note that the expanded shapes are (3, 2) and (2, 3), so simply transposing the expanded b gives both the same shape:
a1 = [[2, 4], [2, 4], [2, 4]]
b1 = [[1, 1], [2, 2], [3, 3]].
As you can see, if we now do
a1 == b1, which computes as
[[False, False], [True, False], [False, False]], we are simply computing the equality across all pairs in the Cartesian product of a and b. If we then take the max (the boolean OR) of this tensor along
dim=0, we get
[True, False] as desired.
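The 1-D walk-through above can be checked directly; a minimal sketch using PyTorch's Tensor.expand:

```python
import torch

# 1-D example from above: which elements of a appear anywhere in b?
a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])

a1 = a.expand(3, 2)      # [[2, 4], [2, 4], [2, 4]]
b1 = b.expand(2, 3).T    # [[1, 1], [2, 2], [3, 3]]

# max along dim=0 acts as a boolean OR down each column.
result = (a1 == b1).max(dim=0).values
print(result)  # tensor([ True, False])
```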
Now to generalize to cases where the dimension is > 1. If we have tensors a and
b, by definition their shapes must be the same along all dimensions except one. For example if we want to compute
a in b along
dim=1 their shapes could be
(x, y, z) and
(x, w, z). Working backwards, we know that our output needs to be one-dimensional with shape
(y,). So the idea here is to compute a table of slice equalities with w rows and y columns, where the jth output is the max of the jth column.
Some untested code based on the intuition above (min/max don't accept a tuple of dims, so we reduce one dim at a time using all, the boolean min):
a1 = a.expand(w, x, y, z)                      # shape: (w, x, y, z)
b1 = b.expand(y, x, w, z).permute(2, 1, 0, 3)  # shape: (w, x, y, z)
t = (a1 == b1).all(dim=3).all(dim=1)           # shape: (w, y)
t.max(dim=0).values                            # shape: (y,)
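The sketch above can be wrapped into a runnable function; the name isin_along_dim1 and the test tensors below are hypothetical, chosen just to exercise the logic:

```python
import torch

def isin_along_dim1(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # For each slice a[:, j, :], check whether it equals some slice b[:, k, :].
    x, y, z = a.shape
    _, w, _ = b.shape
    a1 = a.expand(w, x, y, z)                      # (w, x, y, z)
    b1 = b.expand(y, x, w, z).permute(2, 1, 0, 3)  # (w, x, y, z)
    # all() is the boolean min: a slice matches only if every element matches.
    t = (a1 == b1).all(dim=3).all(dim=1)           # (w, y)
    # any() is the boolean max over b's slices.
    return t.any(dim=0)                            # (y,)

a = torch.arange(12).reshape(2, 3, 2)
# b contains a's middle slice and one slice that appears nowhere in a.
b = torch.stack([a[:, 1, :], a[:, 1, :] + 100], dim=1)  # shape (2, 2, 2)
print(isin_along_dim1(a, b))  # tensor([False,  True, False])
```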