How to get a batch of binary masks from segmentation map without using `for` loop

Hi, suppose I have a segmentation map `a` with shape torch.Size([1, 1, 6, 6]):

print(a)
tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],
          [ 0., 15., 15., 16., 16.,  0.],
          [ 0., 15., 15., 16., 16.,  0.],
          [ 0., 13., 13.,  9.,  9.,  0.],
          [ 0., 13., 13.,  9.,  9.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.]]]])

How can I get the binary masks for each id without using a for loop? The binary masks should have a shape of torch.Size([1, 4, 6, 6]). Currently I'm doing something like the code below. The reason I want it without a for loop is that the dimensions of `a` might change and there might be more or fewer classes. Thanks.

a1 = torch.where(a == 15, 1, 0)
a2 = torch.where(a == 16, 1, 0)
a3 = torch.where(a == 13, 1, 0)
a4 = torch.where(a == 9, 1, 0)
b = torch.cat((a1, a2, a3, a4), dim=1)

print(b)
tensor([[[[0, 0, 0, 0, 0, 0],
          [0, 1, 1, 0, 0, 0],
          [0, 1, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 1, 0],
          [0, 0, 0, 1, 1, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 1, 1, 0, 0, 0],
          [0, 1, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 1, 0],
          [0, 0, 0, 1, 1, 0],
          [0, 0, 0, 0, 0, 0]]]])

Does torch.nn.functional.one_hot work for you?

import torch
import time
a = torch.tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0.,  0.,  0.,  0.,  0.,  0.]]]])
# using CPU timing so playing fast and loose with warmup/syncs
t1 = time.time()
a1 = torch.where(a == 15, 1, 0)
a2 = torch.where(a == 16, 1, 0)
a3 = torch.where(a == 13, 1, 0)
a4 = torch.where(a == 9, 1, 0)
b = torch.cat((a1, a2, a3, a4), dim=1)
t2 = time.time()
o = torch.nn.functional.one_hot(a.long())
t3 = time.time()
print(torch.all(a1 == o[:, :, :, :, 15]))
print(torch.all(a2 == o[:, :, :, :, 16]))
print(torch.all(a3 == o[:, :, :, :, 13]))
print(torch.all(a4 == o[:, :, :, :, 9]))
print(f"loop took {t2-t1} one_hot took {t3-t2}")
tensor(True)
tensor(True)
tensor(True)
tensor(True)
loop took 0.0006623268127441406 one_hot took 0.0002810955047607422
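One caveat with one_hot: it puts the class dimension last, so the result has shape (1, 1, 6, 6, 17) rather than the (N, C, H, W) layout you asked for. As a sketch, you can squeeze the singleton channel and permute the class axis to dim 1 (this gives you one channel per possible id up to the max, 17 here, not just the 4 ids present):

```python
import torch

a = torch.tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0.,  0.,  0.,  0.,  0.,  0.]]]])

# one_hot appends the class dimension last: (1, 1, 6, 6) -> (1, 1, 6, 6, 17)
o = torch.nn.functional.one_hot(a.long())
# drop the singleton channel dim and move classes to dim 1: (1, 17, 6, 6)
masks = o.squeeze(1).permute(0, 3, 1, 2)
print(masks.shape)  # torch.Size([1, 17, 6, 6])
```

Channel k is then the binary mask for id k, e.g. masks[:, 15] matches a1 above.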

Hi @eqy, thanks for the idea. I just found another solution using broadcasting:

b = torch.where(a == torch.tensor([15, 16, 13, 9]).view(1, 4, 1, 1), 1, 0)
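Since you mentioned the number of classes might change, a sketch of the same broadcast trick without hard-coding the ids could use torch.unique (this assumes 0 is a background label to drop; note that unique returns the ids sorted, so the channel order will be [9, 13, 15, 16] rather than your hand-written order):

```python
import torch

a = torch.tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 15., 15., 16., 16.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0., 13., 13.,  9.,  9.,  0.],
                    [ 0.,  0.,  0.,  0.,  0.,  0.]]]])

# collect the class ids actually present (sorted), dropping background 0
ids = a.unique()
ids = ids[ids != 0]
# broadcast compare: (1, 1, 6, 6) == (1, K, 1, 1) -> (1, K, 6, 6)
b = (a == ids.view(1, -1, 1, 1)).long()
print(b.shape)  # torch.Size([1, 4, 6, 6])
```

This way the number of output channels K follows whatever ids appear in the map.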