 Groupby aggregate mean in pytorch

(Yaser Martinez Palenzuela) #1

I have a 2D tensor:

```python
samples = torch.Tensor([
    [0.1, 0.1],    # -> group / class 1
    [0.2, 0.2],    # -> group / class 2
    [0.4, 0.4],    # -> group / class 2
    [0.0, 0.0]     # -> group / class 0
])
```

and a label for each sample corresponding to a class:

`labels = torch.LongTensor([1, 2, 2, 0])`

so `len(samples) == len(labels)`. Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1 and 2), the final tensor should have shape `[n_classes, samples.shape[1]]`, i.e. `[3, 2]`. So the expected result is:

```python
result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3],  # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])
```

Question: Can this be done in pure PyTorch (i.e. no NumPy, so that I can use autograd) and ideally without for loops?

#2

You could use `scatter_add_` and `torch.unique` to get a similar result.
However, the result tensor will be sorted according to the class index:

```python
samples = torch.Tensor([
    [0.1, 0.1],    # -> group / class 1
    [0.2, 0.2],    # -> group / class 2
    [0.4, 0.4],    # -> group / class 2
    [0.0, 0.0]     # -> group / class 0
])

labels = torch.LongTensor([1, 2, 2, 0])
# Expand labels to the same shape as samples so it can act as the
# index tensor for scatter_add_
labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

# Sum the samples of each class, then divide by the per-class counts
res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
res = res / labels_count.float().unsqueeze(1)
```
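For reference, here is the same approach as a self-contained sketch with a check against the expected means. Note that the rows come out sorted by class index, so class 0 is first rather than the order the classes appear in `samples`:

```python
import torch

samples = torch.Tensor([
    [0.1, 0.1],    # -> group / class 1
    [0.2, 0.2],    # -> group / class 2
    [0.4, 0.4],    # -> group / class 2
    [0.0, 0.0]     # -> group / class 0
])
labels = torch.LongTensor([1, 2, 2, 0])

# Expand labels so each sample row scatters into the row of its class
idx = labels.view(-1, 1).expand(-1, samples.size(1))
unique_labels, labels_count = idx.unique(dim=0, return_counts=True)

# Per-class sums via scatter_add_, then divide by the class counts
res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, idx, samples)
res = res / labels_count.float().unsqueeze(1)

print(res)
# tensor([[0.0000, 0.0000],   <- class 0
#         [0.1000, 0.1000],   <- class 1
#         [0.3000, 0.3000]])  <- class 2
```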
(Yaser Martinez Palenzuela) #3

Worked like a charm, thanks a lot! I also posted on Stack Overflow and got the following alternative, which I leave here for future reference:

```python
M = torch.zeros(labels.max() + 1, len(samples))
M[labels, torch.arange(len(samples))] = 1  # one-hot class-membership matrix
M = torch.nn.functional.normalize(M, p=1, dim=1)  # rows sum to 1
torch.mm(M, samples)  # matrix product averages the samples per class
```
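The trick here is that `M[c, i] == 1` iff sample `i` belongs to class `c`; L1-normalizing each row turns the matrix product from a per-class sum into a per-class mean. A runnable sketch, which also checks that gradients flow back to `samples` since everything is differentiable matrix arithmetic:

```python
import torch

samples = torch.tensor([[0.1, 0.1],
                        [0.2, 0.2],
                        [0.4, 0.4],
                        [0.0, 0.0]], requires_grad=True)
labels = torch.LongTensor([1, 2, 2, 0])

# Indicator matrix: M[c, i] == 1 iff sample i belongs to class c
M = torch.zeros(labels.max() + 1, len(samples))
M[labels, torch.arange(len(samples))] = 1
# Row-normalize so each row sums to 1 -> matmul averages instead of summing
M = torch.nn.functional.normalize(M, p=1, dim=1)

result = torch.mm(M, samples)  # rows are the class means, sorted by class index
result.sum().backward()        # autograd works: samples.grad is populated
```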
#4

That’s also an interesting approach! Thanks for sharing.