Conditional operation on an NxM tensor based on the values of another NxM tensor (labels)

Hi,
I have an NxM matrix (A) and another NxM matrix (B). The values of B are labels for the corresponding values of A.
Suppose there are C distinct labels in B. Is there a simple GPU-friendly operation that takes the average of all elements of A that share the same label in B, and then maps the results into a vector of size Cx1? Can I do this without loops?
A=[5 6 7 9 10]
B=[0 0 0 1 1]
(5+6+7)/3=6
(9+10)/2=9.5
C=[6 9.5]

Hi Yegane!

One approach is to extend both A and B along a new “class” dimension
(using expand() and one_hot(), respectively). The one-hotted version
of the labels tensor is used as a class mask for the values tensor so that
per-class sums and counts can be computed, and from them, averages.
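
To make the masking step concrete, here is a minimal sketch on the 1D example from the question. It relies on broadcasting in place of an explicit expand(), and the variable names (mask, masked, sums, cnts) are chosen just for this illustration:

```python
import torch

A = torch.tensor([5., 6., 7., 9., 10.])
B = torch.tensor([0, 0, 0, 1, 1])

# One column per class: mask[i, c] is 1 where B[i] == c, else 0.
mask = torch.nn.functional.one_hot(B)      # shape (5, 2)

# Broadcast A along the new class dimension and zero out other classes.
masked = A.unsqueeze(-1) * mask            # shape (5, 2)

sums = masked.sum(0)                       # per-class sums:   [18., 19.]
cnts = mask.sum(0)                         # per-class counts: [3, 2]
avg = sums / cnts                          # per-class means:  [6.0, 9.5]
print(avg)
```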

Here is such a function that computes the per-class averages, together
with some test cases:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

# could be simpler if data has a fixed dimensionality or shape

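def class_average (data, labels):
    # labels must be a non-negative integer (long) tensor with the same shape as data
    assert  data.shape == labels.shape
    nClass = int (labels.max()) + 1   # number of classes, inferred from the largest label
    dims = [-1] * data.dim() + [nClass]
    dex = data.unsqueeze (-1).expand (dims)   # data repeated along a new trailing class dimension
    lex = torch.nn.functional.one_hot (labels, num_classes = nClass)   # one-hot class mask, same shape as dex
    reduceDims = list (range (data.dim()))   # sum over all of the original data dimensions
    sums = (lex * dex).sum (reduceDims)
    cnts = lex.sum (reduceDims)
    return  sums / cnts   # nan for labels that don't occur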

A = torch.tensor ([5, 6, 7, 9, 10])
B = torch.tensor ([0, 0, 0, 1, 1])
clAB = class_average (A, B)

print ('A:', A)
print ('B:', B)
print ('class_average (A, B):', clAB)

m = 3
n = 5
nClass = 4

data = torch.randn (m, n)
labels = torch.randint (nClass, (m, n))
cldl = class_average (data, labels)

print ('class_average (data, labels):', cldl)

cl_missing_label = class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4]))
print ('class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])):', cl_missing_label)

Here is the test output:

1.10.2
A: tensor([ 5,  6,  7,  9, 10])
B: tensor([0, 0, 0, 1, 1])
class_average (A, B): tensor([6.0000, 9.5000])
class_average (data, labels): tensor([ 0.4032,  0.1719,  0.1518, -1.1578])
class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])): tensor([0., 1., 2., nan, 3.])
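
Note that the one-hot approach materializes a tensor with data.numel() * nClass elements, which may become costly when there are many classes. A possible alternative, sketched here under the same assumptions (non-negative integer labels of dtype long, same shape as the data), is to flatten both tensors and accumulate per-class sums and counts with index_add_(). The function name class_average_flat and its n_class parameter are hypothetical, introduced only for this sketch:

```python
import torch

def class_average_flat(data, labels, n_class=None):
    # Hypothetical helper: per-class mean via index_add_, no one-hot tensor.
    assert data.shape == labels.shape
    flat_data = data.reshape(-1).to(torch.float32)
    flat_labels = labels.reshape(-1)
    if n_class is None:
        # Infer the number of classes from the largest label.
        n_class = int(flat_labels.max()) + 1
    zeros = torch.zeros(n_class, dtype=flat_data.dtype, device=flat_data.device)
    # Accumulate each value (and a count of 1) into the slot of its label.
    sums = zeros.clone().index_add_(0, flat_labels, flat_data)
    cnts = zeros.clone().index_add_(0, flat_labels, torch.ones_like(flat_data))
    return sums / cnts   # nan for labels that don't occur

A = torch.tensor([5, 6, 7, 9, 10])
B = torch.tensor([0, 0, 0, 1, 1])
print(class_average_flat(A, B))
```

This version only allocates tensors of size data.numel() and n_class. On CUDA, index_add_() is not deterministic, so floating-point results may differ slightly from run to run.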

Best.

K. Frank
