Hello.
Suppose I have a tensor of class labels of arbitrary shape (say it is 2-dimensional, i.e., H x W).
I want to get the mask (the corresponding positions) of each class, as below.
# C = 3; H = 4; W = 4;
L = torch.LongTensor([[1, 1, 0, 2],
                      [2, 0, 1, 0],
                      [0, 0, 0, 1],
                      [2, 1, 1, 2]])
m0 = L.eq(0)
m1 = L.eq(1)
m2 = L.eq(2)
print(m0)
>>> tensor([[0, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 0]])
print(m1)
>>> tensor([[1, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [0, 1, 1, 0]])
print(m2)
>>> tensor([[0, 0, 0, 1],
            [1, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 1]])
Currently, there are two ways that I can do this:
#1. use a for loop: memory-efficient, but slow
for c in range(C):
    m = L.eq(c)  # m.shape: [H, W]
    do_something(m)
#2. broadcasting: fast, but huge memory consumption
classes = torch.arange(C).view(C, 1, 1)
print(classes)
>>> tensor([[[0]],
            [[1]],
            [[2]]])
mask = L.eq(classes)  # a mask of shape [C, H, W]; stored densely, though mostly zeros
do_something(mask)
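A possible middle ground between the two, sketched under the assumption that do_something can be applied per class: materialize only a small chunk of classes at a time, keeping the broadcasting speed of #2 while bounding peak memory to chunk × H × W booleans. The chunk size here is a hypothetical tuning knob, and all_masks stands in for whatever do_something produces.

```python
import torch

C, H, W = 3, 4, 4
L = torch.randint(0, C, (H, W))
chunk = 2  # hypothetical tuning knob: how many classes to materialize at once

all_masks = []
for start in range(0, C, chunk):
    # broadcast against only `chunk` class ids at a time
    classes = torch.arange(start, min(start + chunk, C)).view(-1, 1, 1)
    masks = L.eq(classes)      # shape [len(classes), H, W]
    all_masks.append(masks)    # or: call do_something on each mask here
```

Concatenating all_masks along dim 0 recovers exactly the [C, H, W] tensor from #2, but no more than chunk masks ever exist at once.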
The problem I am dealing with has H=1,000, W=10,000, and C=5,000.
Using #1 is very slow since it has to call do_something 5,000 times.
On the other hand, #2 is almost prohibitive due to memory consumption: the [C, H, W] mask has 1,000 × 10,000 × 5,000 = 5 × 10^10 elements, i.e., about 50 GB even at one byte per element.
Is there any good solution for this?
P.S.: While writing this, I noticed that there might be a way to exploit the sparsity of the mask in #2 (each position is nonzero for exactly one class), but I do not have much knowledge about using sparse operations…
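Regarding the sparsity idea: one way to exploit it without torch.sparse at all is to sort the flattened labels once and split the resulting index array by per-class counts. This costs O(H·W) memory regardless of C. A minimal sketch, where the hypothetical do_something would consume index tensors instead of a dense [H, W] mask:

```python
import torch

C, W = 3, 4
L = torch.LongTensor([[1, 1, 0, 2],
                      [2, 0, 1, 0],
                      [0, 0, 0, 1],
                      [2, 1, 1, 2]])

flat = L.view(-1)
order = flat.argsort()                        # flat indices, grouped by class
counts = torch.bincount(flat, minlength=C)    # number of pixels per class
groups = torch.split(order, counts.tolist())  # one index tensor per class

for c, idx in enumerate(groups):
    rows, cols = idx // W, idx % W            # recover 2-D positions
    # do_something with (rows, cols) for class c, instead of a dense mask
```

The per-class loop remains, but each iteration now touches only that class's pixels, so total work across all classes is O(H·W) rather than O(C·H·W).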