Groupby aggregate mean in pytorch

I have a 2D tensor:

    samples = torch.Tensor([
                     [0.1, 0.1],    #-> group / class 1
                     [0.2, 0.2],    #-> group / class 2
                     [0.4, 0.4],    #-> group / class 2
                     [0.0, 0.0]     #-> group / class 0
              ])

and a label for each sample corresponding to a class:

    labels = torch.LongTensor([1, 2, 2, 0])

so len(samples) == len(labels). Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1 and 2), the result should have shape [n_classes, samples.shape[1]], so the expected solution is:

    result == torch.Tensor([
                     [0.1, 0.1],
                     [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
                     [0.0, 0.0]
              ])

Question: Can this be done in pure PyTorch (i.e., without NumPy, so that autograd still works) and ideally without for loops?

You could use scatter_add_ and torch.unique to get a similar result.
However, the result tensor will be sorted according to the class index:

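    import torch

    samples = torch.Tensor([
        [0.1, 0.1],    #-> group / class 1
        [0.2, 0.2],    #-> group / class 2
        [0.4, 0.4],    #-> group / class 2
        [0.0, 0.0]     #-> group / class 0
    ])

    labels = torch.LongTensor([1, 2, 2, 0])
    # repeat each label across the feature dimension: shape (n_samples, n_features)
    labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

    # one row per class (sorted ascending) plus the number of samples per class;
    # assumes every label from 0 to labels.max() occurs at least once
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

    # res[labels[i][j]][j] += samples[i][j], i.e. per-class sums
    res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
    # divide the sums by the class sizes to get the means
    res = res / labels_count.float().unsqueeze(1)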
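Since the question is about autograd, here is a small sketch (not part of the original answer) that runs the same scatter_add_ recipe on samples created with requires_grad=True and backpropagates through the sum of the class means. If the recipe is differentiable, each sample should receive a gradient of 1 / (size of its class).

```python
import torch

samples = torch.tensor([[0.1, 0.1],
                        [0.2, 0.2],
                        [0.4, 0.4],
                        [0.0, 0.0]], requires_grad=True)
labels = torch.tensor([1, 2, 2, 0])

# same recipe as the answer: expand the labels to the shape of samples
index = labels.view(-1, 1).expand(-1, samples.size(1))
unique_labels, labels_count = index.unique(dim=0, return_counts=True)
res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, index, samples)
res = res / labels_count.float().unsqueeze(1)

# backpropagate through the sum of all class means
res.sum().backward()
print(res)           # class means, sorted by class 0, 1, 2
print(samples.grad)  # 1.0 for the single-sample classes 0 and 1, 0.5 for both class-2 samples
```

The gradient flows back into samples through scatter_add_, so no NumPy round trip or Python loop is needed.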

Worked like a charm, thanks a lot! I also posted on Stack Overflow and got the following alternative, which I'm leaving here for future reference:

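    # labels is the original 1-D tensor([1, 2, 2, 0]), not the expanded version from above
    M = torch.zeros(int(labels.max()) + 1, len(samples))  # M[c, i] = 1 if sample i has class c
    M[labels, torch.arange(len(samples))] = 1
    M = torch.nn.functional.normalize(M, p=1, dim=1)       # each row now holds weights 1/count
    torch.mm(M, samples)                                   # (n_classes, n_features) class means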
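To see why the matrix product yields the class means, here is a sketch that builds the same averaging matrix M for the example and compares it with a version built from torch.nn.functional.one_hot (that second construction is an addition, not from the thread). After L1 normalization, row c of M holds 1/count in the columns of the samples of class c, so M @ samples averages exactly those rows.

```python
import torch
import torch.nn.functional as F

samples = torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]])
labels = torch.tensor([1, 2, 2, 0])
n_classes = int(labels.max()) + 1

# construction from the post: one-hot rows, then L1 normalization of each row
M = torch.zeros(n_classes, len(samples))
M[labels, torch.arange(len(samples))] = 1
M = F.normalize(M, p=1, dim=1)
# rows: class 0 -> [0, 0, 0, 1], class 1 -> [1, 0, 0, 0], class 2 -> [0, 0.5, 0.5, 0]

# alternative construction: transposed one-hot encoding divided by its row sums
one_hot = F.one_hot(labels, num_classes=n_classes).T.float()  # (n_classes, n_samples)
M_alt = one_hot / one_hot.sum(dim=1, keepdim=True)

means = M @ samples
print(means)
```

One reason to prefer F.normalize: it clamps the denominator at a small eps, so a class with no samples gets a zero row, while the plain division in M_alt would produce NaN for such a row.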

That’s also an interesting approach! Thanks for sharing.

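If your labels are sparse, e.g. [1, 2, 2, 2] (label 0 is missing), the first solution doesn't work: unique() returns fewer rows than the largest label index, so scatter_add_ indexes out of range. The second one works but also outputs a row of zeros for the missing label 0. Here is my fix: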


    def mean_by_label(samples, labels):
        ''' select mean(samples), count() from samples group by labels order by labels asc '''
        # labels is expected to live on the same device as samples
        weight = torch.zeros(int(labels.max()) + 1, samples.shape[0],
                             dtype=samples.dtype, device=samples.device)  # L, N
        weight[labels, torch.arange(samples.shape[0], device=samples.device)] = 1
        label_count = weight.sum(dim=1)
        weight = torch.nn.functional.normalize(weight, p=1, dim=1)  # l1 normalization
        mean = torch.mm(weight, samples)  # L, F
        present = label_count > 0  # drop rows of labels that never occur
        return mean[present], label_count[present]
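mean_by_label returns the means and counts of the labels that occur, but not the labels themselves. A minimal sketch of the sparse case from the post, assuming you also want those labels: the positions where the count is non-zero are exactly the labels present. Counting with torch.bincount here is a substitution for summing the weight matrix; both give the same counts.

```python
import torch

samples = torch.tensor([[1.0], [2.0], [4.0], [6.0]])
labels = torch.tensor([1, 2, 2, 2])  # label 0 never occurs

counts = torch.bincount(labels)  # one count per label 0..max: [0, 1, 3]
weight = torch.zeros(len(counts), len(samples))
weight[labels, torch.arange(len(samples))] = 1
means = torch.nn.functional.normalize(weight, p=1, dim=1) @ samples
# means[0] is a zero row for the absent label 0, indistinguishable from a real mean of 0

present = counts > 0
present_labels = torch.nonzero(present).squeeze(1)  # labels that actually occur: [1, 2]
print(present_labels, means[present], counts[present])
```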

Hi, thanks for sharing your code! Do you know how to calculate the variance for each label in a similar way?
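One possible approach, not from the thread and only a sketch: compute the per-label means first, then average the squared deviation of each sample from its own group mean. It uses index_add_ and torch.bincount instead of the scatter_add_ / normalize constructions above (same per-group sums), assumes the labels are integers 0..K-1 on the same device as samples, and the helper name groupby_mean_var is hypothetical. The two-pass form avoids the cancellation that E[x^2] - E[x]^2 can suffer in float32.

```python
import torch


def groupby_mean_var(samples: torch.Tensor, labels: torch.Tensor, unbiased: bool = False):
    """Per-label mean and variance for labels in 0..K-1 (K = max label + 1).

    Labels that never occur get zero rows in both outputs.
    """
    n_groups = int(labels.max()) + 1
    counts = torch.bincount(labels, minlength=n_groups).to(samples.dtype)  # (K,)

    # first pass: per-group sums -> means
    sums = torch.zeros(n_groups, samples.size(1), dtype=samples.dtype, device=samples.device)
    sums.index_add_(0, labels, samples)
    mean = sums / counts.clamp(min=1).unsqueeze(1)

    # second pass: squared deviation of each sample from its own group's mean
    sq_dev = (samples - mean[labels]) ** 2
    sq_sums = torch.zeros_like(sums).index_add_(0, labels, sq_dev)

    # divide by n (population) or n - 1 (unbiased); the clamp turns a 0 denominator into 1,
    # so a single-sample group gets variance 0 instead of NaN
    denom = (counts - 1 if unbiased else counts).clamp(min=1)
    return mean, sq_sums / denom.unsqueeze(1)


samples = torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.4, 0.4], [0.0, 0.0]])
labels = torch.tensor([1, 2, 2, 0])
mean, var = groupby_mean_var(samples, labels)
print(var)  # only class 2 has more than one sample, so only its row is non-zero
```

Everything stays in tensor ops, so gradients flow back to samples just as with the mean.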

As the previous solutions do not handle sparse groups well (i.e., not all groups are present in the data), I made one that also returns the label of each row:

    from typing import Tuple

    import torch

    def groupby_mean(value: torch.Tensor, labels: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Group-wise average for (sparse) grouped tensors

        Args:
            value (torch.Tensor): values to average (# samples, latent dimension)
            labels (torch.LongTensor): labels for embedding parameters (# samples,)

        Returns:
            result (torch.Tensor): (# unique labels, latent dimension)
            new_labels (torch.LongTensor): (# unique labels,)

        Examples:
            >>> samples = torch.Tensor([
            ...     [0.15, 0.15, 0.15],    #-> group / class 1
            ...     [0.2, 0.2, 0.2],       #-> group / class 5
            ...     [0.4, 0.4, 0.4],       #-> group / class 5
            ...     [0.0, 0.0, 0.0]        #-> group / class 0
            ... ])
            >>> labels = torch.LongTensor([1, 5, 5, 0])
            >>> result, new_labels = groupby_mean(samples, labels)

            >>> result
            tensor([[0.0000, 0.0000, 0.0000],
                    [0.1500, 0.1500, 0.1500],
                    [0.3000, 0.3000, 0.3000]])

            >>> new_labels
            tensor([0, 1, 5])
        """
        uniques = labels.unique().tolist()
        labels = labels.tolist()

        # map each original label to a dense index 0..K-1, and back
        key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
        val_key = {val: key for key, val in zip(uniques, range(len(uniques)))}

        # move the remapped labels to value's device so scatter_add_ also works on GPU
        labels = torch.LongTensor(list(map(key_val.get, labels))).to(value.device)

        labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))

        unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
        result = torch.zeros_like(unique_labels, dtype=value.dtype).scatter_add_(0, labels, value)
        result = result / labels_count.to(value.dtype).unsqueeze(1)
        new_labels = torch.LongTensor(list(map(val_key.get, unique_labels[:, 0].tolist())))
        return result, new_labels
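For reference, torch.unique can do the label remapping itself: with return_inverse=True it returns, for every sample, the index of its label among the sorted unique labels, which is what the key_val dictionary computes above. A sketch of a variant along those lines (the name groupby_mean_inverse is hypothetical) that avoids the round trip through Python lists; it assumes labels lives on the same device as value, and index_add_ replaces the expanded scatter_add_ while accumulating the same per-group sums.

```python
from typing import Tuple

import torch


def groupby_mean_inverse(value: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Group-wise mean of the rows of `value`; returns (means, sorted unique labels)."""
    # uniques: sorted distinct labels; inverse[i]: row of uniques for labels[i]; counts: group sizes
    uniques, inverse, counts = torch.unique(labels, return_inverse=True, return_counts=True)
    sums = torch.zeros(uniques.numel(), value.size(1), dtype=value.dtype, device=value.device)
    sums.index_add_(0, inverse, value)  # sums[inverse[i]] += value[i]
    return sums / counts.unsqueeze(1).to(value.dtype), uniques


samples = torch.tensor([[0.15, 0.15, 0.15],
                        [0.2, 0.2, 0.2],
                        [0.4, 0.4, 0.4],
                        [0.0, 0.0, 0.0]])
labels = torch.tensor([1, 5, 5, 0])
result, new_labels = groupby_mean_inverse(samples, labels)
print(result)      # rows for labels 0, 1, 5
print(new_labels)  # tensor([0, 1, 5])
```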