Hello,
I have a tensor, labels, whose values are in [-1, 0, 1, 2, 3, 4, 5].
I want to create a mask for each value.
So far I have:
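import torch
cats = [-1, 0, 1, 2, 3, 4, 5]
# one boolean mask per category value
masks = [labels == cat_id for cat_id in cats]
# stack the masks along dim 1 and convert them to int
masks = torch.stack(masks, dim=1).int()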
Do you know of a faster and cleaner way to implement this?
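A possible sketch, assuming labels is an integer tensor and that the category axis should end up at dim 1, as in the torch.stack call above. A single broadcast comparison replaces the Python loop. Because these particular categories are the contiguous range -1..5, torch.nn.functional.one_hot applied to labels + 1 is another option. The helper names masks_broadcast and masks_one_hot are hypothetical, and neither version has been benchmarked here.

```python
import torch
import torch.nn.functional as F

cats = [-1, 0, 1, 2, 3, 4, 5]

def masks_broadcast(labels: torch.Tensor, cats: list) -> torch.Tensor:
    # Put the category values on dim 1, with singleton dims after it,
    # so they broadcast against labels.unsqueeze(1).
    cats_t = torch.tensor(cats, dtype=labels.dtype, device=labels.device)
    cats_t = cats_t.view(1, -1, *([1] * (labels.dim() - 1)))
    # One comparison instead of a loop; the category axis lands at dim 1.
    return (labels.unsqueeze(1) == cats_t).int()

def masks_one_hot(labels: torch.Tensor, num_classes: int = 7) -> torch.Tensor:
    # Only valid because the categories are the contiguous range -1..5:
    # shifting by +1 maps them to class indices 0..6.
    onehot = F.one_hot(labels.long() + 1, num_classes=num_classes)
    # one_hot appends the class axis last; move it to dim 1
    # to match torch.stack(..., dim=1).
    return onehot.movedim(-1, 1).int()

# Usage: both sketches should agree with the loop-based version.
labels = torch.tensor([[-1, 0, 5], [2, 3, 4]])
loop = torch.stack([labels == c for c in cats], dim=1).int()
assert torch.equal(masks_broadcast(labels, cats), loop)
assert torch.equal(masks_one_hot(labels), loop)
```

The broadcast version works for any list of category values. The one_hot version depends on the values being a contiguous integer range, so it would need a different offset, or a lookup step, for other category sets.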
Thanks in advance,
Shon