How to check whether tensor values are in a different tensor in PyTorch?

I have two tensors of unequal size:

import torch

a = torch.tensor([[1, 2], [2, 3], [3, 4]])
b = torch.tensor([[4, 5], [2, 3]])

I want a boolean array indicating whether each row of a exists in b, without iterating. Something like

a in b

and the result should be

[False, True, False]

because only a[1] (the row [2, 3]) is in b.

To reason about this, we can first look at the 1-D case.

If a = [2, 4] and b = [1, 2, 3], our output should be [True, False]. Intuitively, we want to build a 2 by 3 table so that we can compute all pairwise equalities at once.
To do this we can use Tensor.expand (there is no standalone torch.expand function):

a.expand(3, 2) gives us [[2, 4], [2, 4], [2, 4]]
b.expand(2, 3) gives us [[1, 2, 3], [1, 2, 3]]

Note that the expanded shapes of a and b are (3, 2) and (2, 3), so simply transposing the expanded b gives both the same shape: a1 = [[2, 4], [2, 4], [2, 4]] and b1 = [[1, 1], [2, 2], [3, 3]].
If we now compute a1 == b1, which evaluates to [[False, False], [True, False], [False, False]], we are computing the equality of every pair in the Cartesian product of a and b. Taking the max of this tensor along dim=0 (for booleans, the same as any along dim=0) gives [True, False], as desired.
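The 1-D walkthrough can be reproduced as a small sketch. It uses Tensor.expand and .T exactly as described, and uses .any(dim=0) in place of max, which gives the same result on a boolean table.

```python
import torch

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])

# a1[i, j] = a[j]: a repeated as 3 rows -> shape (3, 2)
a1 = a.expand(len(b), len(a))
# b.expand(2, 3) repeats b as 2 rows; transposing gives b1[i, j] = b[i], shape (3, 2)
b1 = b.expand(len(a), len(b)).T

table = a1 == b1            # table[i, j] is True iff b[i] == a[j]
result = table.any(dim=0)   # collapse over the entries of b -> shape (2,)
print(table.tolist())       # [[False, False], [True, False], [False, False]]
print(result.tolist())      # [True, False]
```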

Now to generalize to cases where the dimension is > 1. For a in b to make sense, the shapes of a and b must match along all dimensions except the one we check membership along. For example, if we want to compute a in b along dim=1, their shapes could be (x, y, z) and (x, w, z). Working backwards, we know that our output needs to be one-dimensional with shape (y,). So the idea is to compute a table with y rows and w columns, where entry (j, k) says whether slice j of a equals slice k of b, and the jth output is the max of the jth row. (The code below stores the transpose of this table, with shape (w, y), and reduces along dim=0.)

Some untested code based on the intuition above:

x, y, z = a.shape                             # a: (x, y, z)
w = b.shape[1]                                # b: (x, w, z)
a1 = a.expand(w, x, y, z)                     # a1[k, i, j, l] = a[i, j, l]
b1 = b.expand(y, x, w, z).transpose(0, 2)     # b1[k, i, j, l] = b[i, k, l]
t = (a1 == b1).all(dim=3).all(dim=1)          # shape: (w, y)
t.any(dim=0)                                  # shape: (y, )
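A sketch of the same expand/transpose approach wrapped in a function (slices_in is a hypothetical name, not a PyTorch API). It assumes a has shape (x, y, z) and b has shape (x, w, z). To apply it to the original question, which compares rows (dim=0) of 2-D tensors, a leading dimension of size 1 is added so the rows become the dim=1 slices the helper expects. A plain broadcasting version for the 2-D case is shown for comparison.

```python
import torch

def slices_in(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. For a of shape (x, y, z) and b of shape (x, w, z),
    return a bool tensor of shape (y,) whose j-th entry is True iff the slice
    a[:, j, :] equals some slice b[:, k, :]."""
    x, y, z = a.shape
    w = b.shape[1]
    a1 = a.expand(w, x, y, z)                   # a1[k, i, j, l] = a[i, j, l]
    b1 = b.expand(y, x, w, z).transpose(0, 2)   # b1[k, i, j, l] = b[i, k, l]
    t = (a1 == b1).all(dim=3).all(dim=1)        # t[k, j]: slice j of a == slice k of b
    return t.any(dim=0)                         # shape (y,)

a = torch.tensor([[1, 2], [2, 3], [3, 4]])
b = torch.tensor([[4, 5], [2, 3]])

# Rows of the 2-D tensors become dim=1 slices of (1, rows, cols) tensors
result = slices_in(a.unsqueeze(0), b.unsqueeze(0))
print(result.tolist())  # [False, True, False]

# Broadcasting form for the 2-D case: (3, 1, 2) == (1, 2, 2) -> (3, 2, 2)
result_bc = (a[:, None, :] == b[None, :, :]).all(dim=2).any(dim=1)
print(result_bc.tolist())  # [False, True, False]
```

Both forms materialize the full comparison table (w * x * y * z elements in the general case), so memory grows with the product of the two tensors' sizes.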

Hi, excuse me, I have the same problem as yours… did you find a solution? Would you please help me?

Hi, this may help:

import torch

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])
print([(x in b) for x in a])  # [True, False]
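For 1-D tensors the Python loop can be avoided. This sketch assumes PyTorch 1.10 or newer, where torch.isin is available, and also shows the broadcasting equivalent. Note that torch.isin tests individual elements, so on its own it does not answer the row-wise question from the original post.

```python
import torch

a = torch.tensor([2, 4])
b = torch.tensor([1, 2, 3])

# Element-wise membership test (torch.isin, PyTorch >= 1.10)
isin_result = torch.isin(a, b)
print(isin_result.tolist())  # [True, False]

# Same result via broadcasting: (2, 1) == (3,) -> (2, 3) table, then any over b's axis
bc_result = (a[:, None] == b).any(dim=1)
print(bc_result.tolist())  # [True, False]
```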