Batch masked select implementation

Suppose I get labels from a superpixel algorithm and want to calculate the mean feature of each superpixel (SP) region. Here is the problem: how do I batch the masked_select loop below?

import torch

a = torch.tensor([[0,0,1,1,1],[0,0,1,1,2],[3,3,3,3,2]])
c = 2
feat = torch.randn(c, *a.size())

num_sr = a.max().long()

output = []
# How to batch this op?
for i in range(num_sr+1):
    output.append(torch.masked_select(feat, a==i).view(c, -1).mean(-1, keepdim=True)) # shape (c, 1): per-channel mean of SP i

output = torch.stack(output)
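One way to remove the loop, sketched under the assumption that labels are non-negative integers starting at 0 (as in `a` above): sum the features per label with `Tensor.index_add_`, count pixels per label with `torch.bincount`, and divide. The function name `superpixel_means` is hypothetical, and the result is shaped (num_labels, c, 1) to match `torch.stack(output)`.

```python
import torch

def superpixel_means(feat: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-channel mean of `feat` over each label region.

    feat:   (C, H, W) float tensor
    labels: (H, W) int64 tensor with values in [0, K-1]
    returns (K, C, 1), the same layout as torch.stack(output) in the loop
    """
    c = feat.size(0)
    num_labels = int(labels.max()) + 1
    flat_labels = labels.reshape(-1)          # (H*W,)
    flat_feat = feat.reshape(c, -1)           # (C, H*W)

    # Accumulate every pixel's features into the column of its label.
    sums = torch.zeros(c, num_labels, dtype=feat.dtype, device=feat.device)
    sums.index_add_(1, flat_labels, flat_feat)  # (C, K)

    # Number of pixels per label.
    counts = torch.bincount(flat_labels, minlength=num_labels).to(feat.dtype)  # (K,)

    means = sums / counts                     # (C, K); 0/0 -> NaN for empty labels
    return means.t().unsqueeze(-1)            # (K, C, 1)

a = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [3, 3, 3, 3, 2]])
feat = torch.randn(2, *a.size())
print(superpixel_means(feat, a).shape)  # torch.Size([4, 2, 1])
```

Labels with no pixels come out as NaN, just as `mean()` over an empty selection does in the loop. For a batch of label maps, one option (not tested here) is to offset each image's labels by `b * num_labels` before flattening so every image gets its own slots.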

@ptrblck Can you help me with this one? It is a bottleneck in my model.

@Zhaoyi-Yan @ptrblck Is there any solution? Thank you!

I want to get the masked elements of each row of a 2D tensor, packed to the front of the row, but my current method is slow:

x = torch.tensor([[1., 2., 2., 2., 3.],
        [1., 2., 4., 3., 2.]])

ones = torch.tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

masks = torch.tensor([[ True, False, False, False,  True],
        [ True, False,  True,  True, False]])

for i in range(x.size(0)):
    mask = masks[i]
    s = torch.masked_select(x[i], mask)
    ones[i][:s.size(0)] = s

print(f'ones:{ones}')

#output: 
tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])
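A possible loop-free version, assuming PyTorch 1.9 or later for `torch.sort(..., stable=True)`: a stable sort on the inverted mask moves the kept positions of each row to the front in their original order, `torch.gather` pulls those values, and every slot past the row's kept count is padded. `masked_compact_rows` and its `fill` parameter are hypothetical names.

```python
import torch

def masked_compact_rows(x: torch.Tensor, masks: torch.Tensor, fill: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: move the masked elements of each row to the front,
    keep their order, and pad the rest of the row with `fill`."""
    # Key 0 for kept elements, 1 for dropped ones; a stable sort keeps the
    # original left-to-right order inside each group.
    keys = (~masks).to(torch.int64)
    _, order = torch.sort(keys, dim=1, stable=True)
    gathered = torch.gather(x, 1, order)  # kept values first in every row

    # The first `count` slots of each row hold kept values; the rest get `fill`.
    counts = masks.sum(dim=1, keepdim=True)                       # (B, 1)
    positions = torch.arange(x.size(1), device=x.device).expand_as(x)
    return torch.where(positions < counts, gathered, torch.full_like(x, fill))

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])
masks = torch.tensor([[True, False, False, False, True],
                      [True, False, True, True, False]])
print(masked_compact_rows(x, masks))
```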


How can I batch this for loop?

@melike @Zhaoyi-Yan @ptrblck Is there any solution? Thank you!

@Daniel_Wang Does x always contain values greater than or equal to 1?

@melike Yes. In this case I want to remove the ‘2’ entries from x; ‘1’ is the mask ID.

x = torch.tensor([[1., 2., 2., 2., 3.],
        [1., 2., 4., 3., 2.]])

#convert to:

tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])
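Given that clarification, the mask can be built directly as `x != 2`. The sketch below uses a different technique from a sort: each kept element's destination column is its running count (`cumsum`) within the row, dropped elements are routed into a throwaway extra column, and a single `scatter_` writes everything at once. `drop_value_and_compact`, `drop` and `fill` are hypothetical names, and the exact `!=` comparison assumes the values are exactly representable floats, as they are here.

```python
import torch

def drop_value_and_compact(x: torch.Tensor, drop: float = 2.0, fill: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: remove every `drop` value from each row, shift the
    remaining values left (order preserved), and pad with `fill`."""
    b, n = x.shape
    mask = x != drop                           # True for elements to keep

    # Destination column of a kept element = number of kept elements before it.
    rank = mask.long().cumsum(dim=1) - 1
    # Dropped elements all go to scratch column n, which is sliced off below,
    # so their (undefined) write order does not matter.
    dest = torch.where(mask, rank, torch.full_like(rank, n))

    out = torch.full((b, n + 1), fill, dtype=x.dtype, device=x.device)
    out.scatter_(1, dest, x)
    return out[:, :n]

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])
print(drop_value_and_compact(x))
```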