# Groupby aggregate mean in pytorch

I have a 2D tensor:

``````
samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])
``````

and a label for each sample corresponding to a class:

`labels = torch.LongTensor([1, 2, 2, 0])`

so `len(samples) == len(labels)`. Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1 and 2), the final tensor should have shape `[n_classes, samples.shape[1]]`, so the expected result is:

``````
result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])
``````

Question: Can this be done in pure pytorch (i.e. no numpy so that I can autograd) and ideally without for loops?

You could use `scatter_add_` and `torch.unique` to get a similar result.
However, the result tensor will be sorted according to the class index:

``````
import torch

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

labels = torch.LongTensor([1, 2, 2, 0])
# Repeat each label across the feature dimension so it can serve as a scatter index
labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

# Sum the samples of each class into its row, then divide by the class size
res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
res = res / labels_count.float().unsqueeze(1)
``````
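Since the question asks for autograd support, here is a minimal sketch that checks gradients flow through the `scatter_add` approach. The helper name `groupby_mean_dense` is made up for illustration, and it assumes every label in `0..labels.max()` occurs at least once (an empty class would give 0/0 = NaN).

```python
import torch

def groupby_mean_dense(samples, labels):
    # Hypothetical helper; assumes labels cover 0..labels.max() with no gaps.
    n_classes = int(labels.max()) + 1
    index = labels.view(-1, 1).expand(-1, samples.size(1))
    # Out-of-place scatter_add keeps the operation differentiable w.r.t. samples
    sums = torch.zeros(n_classes, samples.size(1), dtype=samples.dtype).scatter_add(0, index, samples)
    counts = torch.bincount(labels, minlength=n_classes).to(samples.dtype)
    return sums / counts.unsqueeze(1)

samples = torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]], requires_grad=True)
labels = torch.tensor([1, 2, 2, 0])

means = groupby_mean_dense(samples, labels)
means.sum().backward()
# Each sample's gradient is 1 / (size of its group)
print(samples.grad)
```

Row `c` of `means` holds the mean of class `c`, so the rows come out ordered by class index rather than by first appearance.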

Worked like a charm, thanks a lot! I also posted on Stack Overflow and got the following alternative, which I leave here for future reference:

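``````
# labels is the original 1-D label tensor here
M = torch.zeros(labels.max()+1, len(samples))
M[labels, torch.arange(len(samples))] = 1  # M[c, i] = 1 if sample i belongs to class c
M = torch.nn.functional.normalize(M, p=1, dim=1)  # each row now sums to 1
torch.mm(M, samples)
``````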
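To see why the matrix version yields means, it can help to print the weight matrix. The sketch below builds the same kind of matrix with `torch.nn.functional.one_hot` instead of index assignment (that substitution is mine, not part of the Stack Overflow answer): row `c` ends up holding `1 / n_c` for every sample in class `c`, so the matrix product averages each class.

```python
import torch
import torch.nn.functional as F

samples = torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]])
labels = torch.tensor([1, 2, 2, 0])

# one_hot gives (n_samples, n_classes); transpose so row c marks the members of class c
M = F.one_hot(labels).T.float()
# L1-normalizing a row divides it by the class size, turning it into averaging weights
M = F.normalize(M, p=1, dim=1)
print(M)

means = M @ samples
print(means)
```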

That’s also an interesting approach! Thanks for sharing.

If your labels are sparse, e.g. `[1, 2, 2, 2]` (where 0 is missing), the first solution doesn’t work. The second one works but also outputs a row for the missing label 0. Here is my fix:

``````
def mean_by_label(samples, labels):
    ''' select mean(samples), count() from samples group by labels order by labels asc '''
    weight = torch.zeros(labels.max()+1, samples.shape[0]).to(samples.device) # L, N
    weight[labels, torch.arange(samples.shape[0], device=samples.device)] = 1
    label_count = weight.sum(dim=1)
    weight = torch.nn.functional.normalize(weight, p=1, dim=1) # l1 normalization
    mean = torch.mm(weight, samples) # L, F
    # keep only the labels that actually occur in the data
    index = torch.arange(mean.shape[0], device=mean.device)[label_count > 0]
    return mean[index], label_count[index]
``````

Hi, thanks for sharing your code! Do you know how to calculate the variance for each label (in a similar way)?
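One possible way to get a per-label variance in the same spirit is to reuse the normalized weight matrix and apply Var[x] = E[x²] − (E[x])². This is only a rough sketch: the name `var_by_label` is made up, it computes the biased (population) variance, and the subtraction can lose precision when the variance is small relative to the mean.

```python
import torch

def var_by_label(samples, labels):
    # Hypothetical helper: biased per-label variance via E[x^2] - (E[x])^2
    n_labels = int(labels.max()) + 1
    n_samples = samples.shape[0]
    weight = torch.zeros(n_labels, n_samples, dtype=samples.dtype, device=samples.device)
    weight[labels, torch.arange(n_samples, device=samples.device)] = 1
    label_count = weight.sum(dim=1)
    weight = torch.nn.functional.normalize(weight, p=1, dim=1)  # rows become averaging weights
    mean = weight @ samples            # E[x] per label
    mean_sq = weight @ samples.pow(2)  # E[x^2] per label
    # clamp guards against tiny negative values caused by rounding
    var = (mean_sq - mean.pow(2)).clamp(min=0)
    present = label_count > 0          # drop labels that never occur
    return var[present], label_count[present]

samples = torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]])
labels = torch.tensor([1, 2, 2, 2])
var, count = var_by_label(samples, labels)
print(var, count)
```

For the unbiased estimate, multiplying each row by `n / (n - 1)` should work for groups with more than one sample.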

As the previous solutions do not work for sparse groups (i.e., when not all groups are present in the data), I made one:

``````
def groupby_mean(value:torch.Tensor, labels:torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
    """Group-wise average for (sparse) grouped tensors

    Args:
        value (torch.Tensor): values to average (# samples, latent dimension)
        labels (torch.LongTensor): labels for embedding parameters (# samples,)

    Returns:
        result (torch.Tensor): (# unique labels, latent dimension)
        new_labels (torch.LongTensor): (# unique labels,)

    Examples:
        >>> samples = torch.Tensor([
        ...     [0.15, 0.15, 0.15],    #-> group / class 1
        ...     [0.2, 0.2, 0.2],    #-> group / class 5
        ...     [0.4, 0.4, 0.4],    #-> group / class 5
        ...     [0.0, 0.0, 0.0]     #-> group / class 0
        ... ])
        >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels)

        >>> result
        tensor([[0.0000, 0.0000, 0.0000],
                [0.1500, 0.1500, 0.1500],
                [0.3000, 0.3000, 0.3000]])

        >>> new_labels
        tensor([0, 1, 5])
    """
    uniques = labels.unique().tolist()
    labels = labels.tolist()

    # map the original labels to dense indices 0..K-1 and back
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
    val_key = {val: key for key, val in zip(uniques, range(len(uniques)))}

    labels = torch.LongTensor(list(map(key_val.get, labels)))

    labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))

    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, value)
    result = result / labels_count.float().unsqueeze(1)
    new_labels = torch.LongTensor(list(map(val_key.get, unique_labels[:, 0].tolist())))
    return result, new_labels
``````
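The label remapping can probably also be done without Python dicts by letting `unique` return the inverse indices directly. The following is a sketch of that idea, not a drop-in replacement; the name `groupby_mean_inverse` is made up, and the function assumes `value` is 2-D and `labels` is 1-D.

```python
import torch

def groupby_mean_inverse(value, labels):
    # Hypothetical variant: stays on value's device and avoids .tolist() round trips
    new_labels, inverse, counts = labels.unique(return_inverse=True, return_counts=True)
    # inverse[i] is the dense group index (0..K-1) of sample i
    sums = torch.zeros(new_labels.numel(), value.size(1), dtype=value.dtype, device=value.device)
    sums.index_add_(0, inverse, value)  # accumulate each sample into its group's row
    result = sums / counts.to(value.dtype).unsqueeze(1)
    return result, new_labels

samples = torch.tensor([
    [0.15, 0.15, 0.15],  # label 1
    [0.2, 0.2, 0.2],     # label 5
    [0.4, 0.4, 0.4],     # label 5
    [0.0, 0.0, 0.0],     # label 0
])
labels = torch.tensor([1, 5, 5, 0])
result, new_labels = groupby_mean_inverse(samples, labels)
print(result, new_labels)
```

Because only labels that actually occur get a row, no division by zero can happen here.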