Check if n-th tensor row contains n-th vector element

Hi. I have the following problem. Suppose I have a tensor M of shape (N, 5) and a vector v of size N.
I am looking for a vectorized solution that tells me whether the nth row of M contains the nth element of v. Example:

import torch

M = torch.tensor([
    [0, 1, 2, 3, 4],
    [7, 5, 3, 4, 8],
    [0, 1, 9, 7, 5],
    [1, 4, 7, 2, 5],
])

v = torch.tensor([0, 3, 2, 6])

Then the function I am looking for should output something like the following:

does_intersect(M, v) == torch.tensor([1, 1, 0, 0])  # or boolean indicators

I know that I can just iterate over the rows and check whether v[n] in M[n], but I want a vectorized solution.
One additional note: each row of M contains unique elements.
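For reference, the row-by-row loop mentioned above might look like the sketch below. The helper name does_intersect_loop is hypothetical and only serves as a baseline to check a vectorized version against. It relies on the `in` operator for tensors, which returns a Python bool.

```python
import torch

def does_intersect_loop(M: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hypothetical non-vectorized baseline: test v[n] in M[n] for each row n
    return torch.tensor([v[n] in M[n] for n in range(M.shape[0])])

M = torch.tensor([
    [0, 1, 2, 3, 4],
    [7, 5, 3, 4, 8],
    [0, 1, 9, 7, 5],
    [1, 4, 7, 2, 5],
])
v = torch.tensor([0, 3, 2, 6])

print(does_intersect_loop(M, v))  # tensor([ True,  True, False, False])
```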

You could unsqueeze v, compare the tensors directly (v will be broadcast against M), and then check each row across its columns for any match:

print((M == v.unsqueeze(1)).any(1))
> tensor([ True,  True, False, False])
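Wrapped into the does_intersect function from the question, the one-liner above could look like this sketch. Here v.unsqueeze(1) has shape (N, 1), so it broadcasts against M's (N, 5) shape and each v[n] is compared only with row n. Calling .long() on the result gives the 0/1 indicators from the question instead of booleans.

```python
import torch

def does_intersect(M: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # (N, 1) vs (N, 5): broadcasting compares v[n] with every entry of row n
    matches = M == v.unsqueeze(1)  # (N, 5) boolean mask
    # Reduce over the columns: True where row n contains v[n]
    return matches.any(dim=1)      # (N,)

M = torch.tensor([
    [0, 1, 2, 3, 4],
    [7, 5, 3, 4, 8],
    [0, 1, 9, 7, 5],
    [1, 4, 7, 2, 5],
])
v = torch.tensor([0, 3, 2, 6])

print(does_intersect(M, v))         # tensor([ True,  True, False, False])
print(does_intersect(M, v).long())  # tensor([1, 1, 0, 0])
```

Because each row of M contains unique elements, (M == v.unsqueeze(1)).sum(1) would also produce the 0/1 indicators directly, since at most one entry per row can match.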

Thanks, couldn’t figure it out