Binary 2D mask implementation

Hi,

I have two sets of labels, for example:
labels_a = torch.Tensor([0,1,2,3]) with a shape of torch.Size([4])
labels_b = torch.Tensor([0,1,2,3,0,1,3,1,2,3,2]) with a shape of torch.Size([11])

I was wondering if there is an efficient way to make a binary mask of shape [labels_a_size, labels_b_size], i.e. [4, 11] here, in which mask[i, j] = 1 if labels_a[i] == labels_b[j] and 0 otherwise?

Also, this is my first post, so let me know if there is anything I can improve on.

Broadcasting is the superpower you need:

mask = (labels_a[:, None] == labels_b[None, :])
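A small self-contained sketch of the one-liner above, applied to the labels from the question. labels_a[:, None] has shape [4, 1] and labels_b[None, :] has shape [1, 11], so == broadcasts both to [4, 11] and compares element-wise. Two assumptions on my part: I build the labels with torch.tensor so they get an integer dtype (the float tensors from torch.Tensor compare the same way for these values), and I add .long() in case a 0/1 integer mask is wanted rather than the boolean one == returns.

```python
import torch

# Integer labels from the question (torch.tensor infers int64 here).
labels_a = torch.tensor([0, 1, 2, 3])
labels_b = torch.tensor([0, 1, 2, 3, 0, 1, 3, 1, 2, 3, 2])

# [4, 1] == [1, 11] broadcasts to [4, 11]; mask[i, j] is labels_a[i] == labels_b[j].
mask = labels_a[:, None] == labels_b[None, :]  # dtype: torch.bool

# Optional: turn True/False into 1/0 if a numeric mask is needed.
mask_int = mask.long()

print(mask.shape)  # torch.Size([4, 11])
print(mask_int)
```

Row 0 of mask_int is [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], since labels_b holds a 0 at positions 0 and 4. Because no Python loop is involved, this stays fast on large label sets and also runs on the GPU, but it allocates the full len(labels_a) x len(labels_b) mask in memory.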

Best regards

Thomas