Suppose I get labels from a superpixel algorithm and want to compute the mean feature of each superpixel (SP) region. The problem: how can I batch the `masked_select` loop below?

``````
import torch

a = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [3, 3, 3, 3, 2]])  # superpixel labels, (H, W)
c = 2
feat = torch.randn(c, *a.size())  # features, (c, H, W)

num_sr = int(a.max())  # largest label id

output = []
# How to batch this op?
for i in range(num_sr + 1):
    # (c, 1): mean of superpixel i
    output.append(torch.masked_select(feat, a == i).view(c, -1).mean(-1, keepdim=True))

output = torch.stack(output)  # (num_sr + 1, c, 1)
``````

@ptrblck Can you help me with this one? It is a bottleneck in my model.
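One possible way to remove the loop is to flatten the labels and accumulate per-label sums with `index_add_`, then divide by per-label pixel counts from `torch.bincount`. This is only a sketch under the assumptions that labels are non-negative integers and `feat` has shape `(c, H, W)`; the helper name `superpixel_mean` is made up for illustration and is not a PyTorch function.

```python
import torch

def superpixel_mean(feat: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean of feat (c, H, W) per label in labels (H, W).

    Returns a (K, c) tensor, where K = labels.max() + 1.
    """
    c = feat.size(0)
    flat_labels = labels.reshape(-1).long()   # (H*W,)
    flat_feat = feat.reshape(c, -1)           # (c, H*W)
    num_regions = int(flat_labels.max()) + 1

    # Add every pixel's feature vector into the column of its label.
    sums = torch.zeros(c, num_regions, dtype=feat.dtype, device=feat.device)
    sums.index_add_(1, flat_labels, flat_feat)

    # Pixels per label; clamp so an unused label id does not divide by zero.
    counts = torch.bincount(flat_labels, minlength=num_regions).clamp(min=1)

    return (sums / counts).t()                # (K, c)

a = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [3, 3, 3, 3, 2]])
feat = torch.randn(2, *a.size())
batched = superpixel_mean(feat, a)            # shape (4, 2)
```

Calling `batched.unsqueeze(-1)` gives the same `(K, c, 1)` layout as the stacked loop output. One difference to be aware of: a label id that never occurs yields a zero row here, whereas the loop version would produce NaN for it.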

@Zhaoyi-Yan @ptrblck Is there any solution? Thank you!

I want to get the masked elements of each row of a 2D tensor, packed to the left, but my current method loops over the rows and is slow:

``````
import torch

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])

ones = torch.tensor([[1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1.]])

masks = torch.tensor([[ True, False, False, False,  True],
                      [ True, False,  True,  True, False]])

for i in range(x.size(0)):
    # `s` was undefined in the original snippet; this definition is inferred
    # from the output shown below.
    s = torch.masked_select(x[i], masks[i])
    ones[i][:s.size(0)] = s

print(f'ones:{ones}')

# output:
# tensor([[1., 3., 1., 1., 1.],
#         [1., 4., 3., 1., 1.]])
``````

How can I batch this for loop?

@melike @Zhaoyi-Yan @ptrblck Is there any solution? Thank you!

@Daniel_Wang Does `x` always have values equal to or greater than `1.`?

@melike Yes. In this case I want to remove the `2`s from `x`; `1` is the mask ID.

``````
x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])

# convert to:

tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])
``````
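One way this conversion might be batched is to compute, for every kept element, its destination column with a cumulative sum over the mask, and then write all rows at once with advanced indexing. The sketch below assumes the mask is `x != 2` as described above and uses `1` as the fill value; `compact_rows` is a hypothetical helper name, not a PyTorch API.

```python
import torch

def compact_rows(x: torch.Tensor, mask: torch.Tensor, fill_value: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: pack the masked elements of each row to the left."""
    out = torch.full_like(x, fill_value)

    # Destination column of each kept element: number of kept elements
    # up to and including it in the same row, minus one.
    dest = mask.long().cumsum(dim=1) - 1

    # Row index of every position, same shape as x.
    rows = torch.arange(x.size(0), device=x.device).unsqueeze(1).expand_as(x)

    # Scatter all kept elements in one indexing operation.
    out[rows[mask], dest[mask]] = x[mask]
    return out

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])

result = compact_rows(x, x != 2)
print(result)
# tensor([[1., 3., 1., 1., 1.],
#         [1., 4., 3., 1., 1.]])
```

The same function should also work with an explicit boolean `masks` tensor in place of `x != 2`, since it only relies on the mask having the same shape as `x`.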