Suppose I get labels from a superpixel algorithm and want to compute the mean feature of each superpixel (SP) region. The problem: how can I batch the per-label `masked_select` loop below instead of iterating over the labels one by one?
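```python
import torch

a = torch.tensor([[0,0,1,1,1],[0,0,1,1,2],[3,3,3,3,2]])  # (H, W) superpixel labels
c = 2
feat = torch.randn(c, *a.size())  # (C, H, W) features
num_sr = a.max().item()  # largest label index
output = []
# How to batch this op?
for i in range(num_sr + 1):
    # the (H, W) mask broadcasts over C; masked_select returns 1-D, so view back to (C, N_i)
    output.append(torch.masked_select(feat, a == i).view(c, -1).mean(-1, keepdim=True))  # (C, 1), mean of each SP
output = torch.stack(output)  # (num_sr + 1, C, 1)
```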
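One possible way to remove the loop, sketched here under the assumption that the labels are non-negative integers: flatten the spatial dimensions, sum the features of each label with `Tensor.index_add_`, count pixels per label with `torch.bincount`, and divide. The function name `superpixel_means` is hypothetical, and the output is arranged as `(num_sr + 1, C, 1)` to match the stacked loop result.

```python
import torch

def superpixel_means(feat: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean feature per superpixel label (hypothetical helper).

    feat:   (C, H, W) float tensor
    labels: (H, W) integer tensor, assumed to hold values in [0, K]
    returns (K + 1, C, 1), same layout as the stacked loop output
    """
    c = feat.size(0)
    flat_feat = feat.reshape(c, -1)          # (C, H*W)
    flat_lab = labels.reshape(-1).long()     # (H*W,)
    num_labels = int(flat_lab.max().item()) + 1

    # sums[:, k] = sum of flat_feat[:, j] over all pixels j with flat_lab[j] == k
    sums = torch.zeros(c, num_labels, dtype=feat.dtype, device=feat.device)
    sums.index_add_(1, flat_lab, flat_feat)

    # Pixel count per label; clamp(min=1) avoids dividing by zero for unused labels
    counts = torch.bincount(flat_lab, minlength=num_labels).clamp(min=1).to(feat.dtype)

    means = sums / counts                    # (C, K + 1), counts broadcast over C
    return means.t().unsqueeze(-1)           # (K + 1, C, 1)

a = torch.tensor([[0,0,1,1,1],[0,0,1,1,2],[3,3,3,3,2]])
feat = torch.randn(2, *a.size())
out = superpixel_means(feat, a)
print(out.shape)  # torch.Size([4, 2, 1])
```

One behavioral difference from the loop: a label index that never occurs in `a` yields a mean of 0 here (because of the clamp), whereas `mean` over an empty selection in the loop yields `nan`.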