Groupby aggregate mean in PyTorch

(Yaser Martinez Palenzuela) #1

I have a 2D tensor:

    samples = torch.Tensor([
                     [0.1, 0.1],    #-> group / class 1
                     [0.2, 0.2],    #-> group / class 2
                     [0.4, 0.4],    #-> group / class 2
                     [0.0, 0.0]     #-> group / class 0
              ])

and a label for each sample corresponding to a class:

    labels = torch.LongTensor([1, 2, 2, 0])

so len(samples) == len(labels). Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1, and 2), the final tensor should have shape [n_classes, samples.shape[1]]. So the expected result is:

    result == torch.Tensor([
                     [0.1, 0.1],
                     [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
                     [0.0, 0.0]
              ])

Question: Can this be done in pure PyTorch (i.e., no NumPy, so that autograd still works), and ideally without for loops?

#2

You could use scatter_add_ and torch.unique to get a similar result.
However, the result tensor will be sorted according to the class index:

    samples = torch.Tensor([
                     [0.1, 0.1],    #-> group / class 1
                     [0.2, 0.2],    #-> group / class 2
                     [0.4, 0.4],    #-> group / class 2
                     [0.0, 0.0]     #-> group / class 0
              ])

    labels = torch.LongTensor([1, 2, 2, 0])
    # Expand labels to [n_samples, n_features] so scatter_add_ can
    # route every element of samples into its class row.
    labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

    # Sum the samples per class, then divide by the per-class counts
    # to get the means.
    res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
    res = res / labels_count.float().unsqueeze(1)

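For reference, here is a self-contained version of the snippet above with a quick check against the expected per-class means (a sketch using the same toy data; as noted, the rows come out sorted by class index):

```python
import torch

samples = torch.Tensor([[0.1, 0.1],
                        [0.2, 0.2],
                        [0.4, 0.4],
                        [0.0, 0.0]])
labels = torch.LongTensor([1, 2, 2, 0])

# Expand labels to [n_samples, n_features] so each element of
# samples is routed to the row of its class.
index = labels.view(-1, 1).expand(-1, samples.size(1))
unique_labels, counts = index.unique(dim=0, return_counts=True)

res = torch.zeros_like(unique_labels, dtype=torch.float)
res.scatter_add_(0, index, samples)      # per-class sums
res = res / counts.float().unsqueeze(1)  # per-class means

# Rows are sorted by class index: class 0, 1, 2.
expected = torch.Tensor([[0.0, 0.0],
                         [0.1, 0.1],
                         [0.3, 0.3]])
assert torch.allclose(res, expected)
```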
(Yaser Martinez Palenzuela) #3

Worked like a charm, thanks a lot! I also posted on Stack Overflow and got the following alternative, which I leave here for future reference:

    M = torch.zeros(labels.max() + 1, len(samples))
    M[labels, torch.arange(len(samples))] = 1        # one-hot class-indicator matrix
    M = torch.nn.functional.normalize(M, p=1, dim=1) # each 1 becomes 1 / class count
    torch.mm(M, samples)                             # per-class means, sorted by class index
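The trick here is that M is a one-hot indicator matrix with M[c, i] = 1 iff sample i belongs to class c; L1-normalizing its rows replaces each 1 with 1/count, so a single matrix product yields the per-class means. A self-contained sketch of this with a quick check (same toy data):

```python
import torch

samples = torch.Tensor([[0.1, 0.1],
                        [0.2, 0.2],
                        [0.4, 0.4],
                        [0.0, 0.0]])
labels = torch.LongTensor([1, 2, 2, 0])

n_classes = int(labels.max()) + 1

# Indicator matrix: M[c, i] == 1 iff sample i belongs to class c.
M = torch.zeros(n_classes, len(samples))
M[labels, torch.arange(len(samples))] = 1

# L1-normalize each row: the 1s become 1/count, so M @ samples
# averages the samples within each class.
M = torch.nn.functional.normalize(M, p=1, dim=1)
result = M @ samples

# Rows are sorted by class index: class 0, 1, 2.
assert torch.allclose(result, torch.Tensor([[0.0, 0.0],
                                            [0.1, 0.1],
                                            [0.3, 0.3]]))
```

Since the whole pipeline is differentiable (indexing, normalize, matmul), gradients flow back to samples through autograd.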
#4

That’s also an interesting approach! Thanks for sharing.