Hi,
I have an NxM matrix (A) and another NxM matrix (B). The values of B are labels for the values of A.
If there are C distinct labels in B, is there a simple GPU-friendly operation that takes the average of all elements of A that share the same label in B and maps them into a vector of size Cx1? Can I do it without loops?
```
A = [5 6 7 9 10]
B = [0 0 0 1 1]
(5+6+7)/3 = 6
(9+10)/2 = 9.5
C = [6 9.5]
```
Hi Yegane!
One approach is to extend both A and B along a new "class" dimension (using `expand()` and `one_hot()`, respectively). The one-hot version of the labels tensor is used as a class mask for the values tensor so that per-class sums and counts can be computed, and from them, averages.

Here is such a function that computes the per-class averages, together with some test cases:
```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

# could be simpler if data has a fixed dimensionality or shape
def class_average (data, labels):
    assert data.shape == labels.shape
    nClass = labels.max() + 1
    dims = [-1] * data.dim() + [nClass]
    dex = data.unsqueeze (-1).expand (dims)
    lex = torch.nn.functional.one_hot (labels)
    sums = (lex * dex).sum (list (range (data.dim())))
    cnts = lex.sum (list (range (data.dim())))
    return sums / cnts   # nan for labels that don't occur

A = torch.tensor ([5, 6, 7, 9, 10])
B = torch.tensor ([0, 0, 0, 1, 1])
clAB = class_average (A, B)
print ('A:', A)
print ('B:', B)
print ('class_average (A, B):', clAB)

m = 3
n = 5
nClass = 4
data = torch.randn (m, n)
labels = torch.randint (nClass, (m, n))
cldl = class_average (data, labels)
print ('class_average (data, labels):', cldl)

cl_missing_label = class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4]))
print ('class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])):', cl_missing_label)
```
Here is the test output:
```
1.10.2
A: tensor([ 5,  6,  7,  9, 10])
B: tensor([0, 0, 0, 1, 1])
class_average (A, B): tensor([6.0000, 9.5000])
class_average (data, labels): tensor([ 0.4032,  0.1719,  0.1518, -1.1578])
class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])): tensor([0., 1., 2., nan, 3.])
```
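As an aside (not part of the recipe above), for integer labels the same per-class averages can also be sketched with `torch.bincount()`, which sums optional float weights per label. This avoids materializing the extended class dimension; for an NxM matrix you would first `flatten()` both tensors, since `bincount()` only accepts 1-D input:

```python
import torch

# Sketch of an alternative: per-class averages via bincount() rather than
# one_hot() / expand(). bincount (B, weights = A) sums the elements of A
# that fall in each label bin; bincount (B) counts them.
A = torch.tensor ([5., 6., 7., 9., 10.])   # values (float, to use as weights)
B = torch.tensor ([0, 0, 0, 1, 1])         # labels

sums = torch.bincount (B, weights = A)     # per-class sums:   [18., 19.]
cnts = torch.bincount (B)                  # per-class counts: [3, 2]
means = sums / cnts                        # per-class averages

print (means)   # tensor([6.0000, 9.5000])
```

For NxM inputs, `torch.bincount (B.flatten(), weights = A.flatten())` gives the same per-class sums. Like the `one_hot()` version, labels that never occur divide by a zero count and come out as `nan`.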
Best.
K. Frank