Multiple equals

Hello,

I have a tensor with values in [-1, 0, 1, 2, 3, 4, 5].
I want to create a mask for each value.

So far I did:

import torch

cats = [-1, 0, 1, 2, 3, 4, 5]
# one boolean mask per value, stacked along dim 1 and cast to int
masks = [labels == cat_id for cat_id in cats]
masks = torch.stack(masks, dim=1).int()

Do you know a faster and cleaner way to implement this?
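One possible vectorized alternative, sketched under the assumption that `labels` is a 1-D integer tensor (the example values below are made up for illustration): broadcasting a `(N, 1)` view of the labels against the category tensor compares every label with every category in one call and gives the same `(N, len(cats))` result as the loop-and-stack version. Because the categories here happen to be the contiguous range -1..5, shifting by 1 and using `torch.nn.functional.one_hot` is another option.

```python
import torch
import torch.nn.functional as F

# Hypothetical 1-D labels; the real `labels` tensor from the question is assumed to be 1-D.
labels = torch.tensor([3, -1, 0, 5, 2, 2, 4, 1])
cats = torch.tensor([-1, 0, 1, 2, 3, 4, 5])

# Broadcasting: (N, 1) == (C,) -> (N, C); column j is the mask for cats[j].
masks = (labels.unsqueeze(1) == cats).int()

# Only valid because the categories are the contiguous range -1..5:
# shift to 0..6, then one_hot (expects int64 input) gives the same layout.
masks_oh = F.one_hot(labels + 1, num_classes=len(cats)).int()

# Reference: the loop-and-stack approach from the question.
masks_loop = torch.stack([labels == c for c in cats.tolist()], dim=1).int()

assert torch.equal(masks, masks_loop)
assert torch.equal(masks_oh, masks_loop)
print(masks)
```

The broadcasting version works for any set of category values; the `one_hot` version relies on the values being consecutive integers. For labels with more than one dimension, `labels.unsqueeze(-1) == cats` puts the category axis last instead of at dim 1.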

Thanks in advance,
Shon