Is there a good way to do this kind of label operation?

Hello.

Suppose I have a tensor of class labels of arbitrary shape (say it is 2-dimensional, i.e., H x W).

I want to get a mask marking the positions of each class, as below.

# C = 3; H = 4; W = 4;
L = torch.LongTensor([[1,1,0,2],
                      [2,0,1,0],
                      [0,0,0,1],
                      [2,1,1,2]])
m0 = L.eq(0)
m1 = L.eq(1)
m2 = L.eq(2)

print(m0)
>>> tensor([[0, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 0]])

print(m1)
>>> tensor([[1, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [0, 1, 1, 0]])

print(m2)
>>> tensor([[0, 0, 0, 1],
            [1, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 1]])

Currently, there are two ways that I can do this:

#1. use a for loop: memory-efficient, but slow
for c in range(C):
    m = L.eq(c)     # m.shape: [H,W]
    do_something(m)
#2. broadcast against all classes at once: fast, but huge memory consumption
classes = torch.arange(C).view(C,1,1)
print(classes)
>>> tensor([[[0]],

            [[1]],

            [[2]]])

sparse_mask = L.eq(classes)  # boolean mask of shape [C,H,W] (mostly zeros, but stored densely)
do_something(sparse_mask)
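(I think #2 is essentially what torch.nn.functional.one_hot computes, just permuted; if anything, one_hot seems worse here, since it materializes an int64 tensor before any cast to bool:)

import torch.nn.functional as F

masks = F.one_hot(L, num_classes=C).permute(2, 0, 1).bool()  # also [C,H,W]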

The problem I am dealing with has H=1,000, W=10,000, and C=5,000.

Using #1 is very slow, since it has to call do_something 5,000 times in a loop.

On the other hand, #2 is almost prohibitive due to memory consumption: the full mask would have C x H x W = 5,000 x 1,000 x 10,000 = 5e10 elements, i.e., about 50 GB even as a boolean tensor at one byte per element.
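One compromise I have considered is chunking the class dimension, so that peak memory is K x H x W instead of C x H x W. A rough sketch (the chunk size K = 64 is arbitrary, and it assumes do_something can accept a batch of masks as in #2):

K = 64  # chunk size, tuned to available memory
for c0 in range(0, C, K):
    chunk = torch.arange(c0, min(c0 + K, C)).view(-1, 1, 1)
    masks = L.eq(chunk)  # shape [k,H,W] with k <= K
    do_something(masks)

But that still means C/K iterations, so it only dilutes the problem rather than solving it.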

Is there any good solution for this?

P.S.: While writing this, I noticed that there might be a way to exploit the sparsity of the mask in #2, but I do not have much experience with sparse operations…
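For example, since every position belongs to exactly one class, the [C,H,W] mask has exactly H*W nonzero entries, so storing it as a sparse COO tensor would take O(H*W) memory instead of O(C*H*W). A sketch of what I have in mind (untested; I am not sure which downstream operations accept sparse inputs):

H, W = L.shape
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
indices = torch.stack([L.flatten(), ys.flatten(), xs.flatten()])  # [3, H*W]
values = torch.ones(H * W)  # one entry per position
sparse_mask = torch.sparse_coo_tensor(indices, values, size=(C, H, W))

Would something along these lines be the right direction?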