Torch.eq for multiple values

Given a 2D tensor A and a 1D tensor B of length > 1:

A = [ [1, 2], [1, 4], [4, 3]]
B = [1, 4]

I would like to write a function that returns one mask per element of B (two in this case), each with the same shape as A, where an entry is 1 if the corresponding value of A equals that element of B:

res = [
[[1., 0.],[1., 0.],[0., 0.]], 
[[0., 0.],[0., 1.],[1., 0.]]
]

A naive solution is a for loop over B that calls torch.eq between A and each element of B.
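For reference, the loop approach might look like the sketch below. It assumes A and B are integer tensors holding the example values, and it stacks the per-element masks along a new leading dimension to get the layout of res:

```python
import torch

A = torch.tensor([[1, 2], [1, 4], [4, 3]])
B = torch.tensor([1, 4])

# One torch.eq call per element of B; each call yields a bool mask shaped like A.
# torch.stack adds a leading dimension, and .float() turns True/False into 1./0.
masks = torch.stack([torch.eq(A, b) for b in B]).float()

# masks.shape == torch.Size([2, 3, 2])
```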

But is there a smarter (and faster) way to do this?

You want broadcasting, which is best described on the "Broadcasting semantics" page of the PyTorch 1.12 documentation.
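As a sketch of how broadcasting applies here (the helper name eq_masks is hypothetical): viewing B as shape (k, 1, 1) and comparing it with A of shape (m, n) broadcasts both to (k, m, n), so a single torch.eq produces all k masks at once, with mask i equal to A == B[i]:

```python
import torch

def eq_masks(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (m, n), B: (k,). B[:, None, None] has shape (k, 1, 1); torch.eq
    # broadcasts it against A to shape (k, m, n) in one vectorized call.
    return torch.eq(B[:, None, None], A).float()

A = torch.tensor([[1, 2], [1, 4], [4, 3]])
B = torch.tensor([1, 4])
res = eq_masks(A, B)
# res.shape == torch.Size([2, 3, 2])
```

This replaces the Python loop with one kernel launch. The output still takes k * m * n elements, the same as the stacked loop result, so memory use does not change.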

Best regards

Thomas