Hi,

I have an NxM matrix (A) and another NxM matrix (B). The values of B are labels for values of A.

If I have C labels in B, is there a simple GPU-friendly operation that takes the average of all elements of A that share a label in B and maps them into a vector of size Cx1? Can I do it without loops?

A=[5 6 7 9 10]

B=[0 0 0 1 1]

(5+6+7)/3=6

(9+10)/2=9.5

C=[6 9.5]

Hi Yegane!

One approach is to extend both `A` and `B` along a new "class" dimension (using `expand()` and `one_hot()`, respectively). The one-hotted version of the labels tensor is used as a class mask for the values tensor so that per-class sums and counts can be computed, and from them, averages.

Here is such a function that computes the per-class averages, together with some test cases:

```
import torch

print (torch.__version__)

_ = torch.manual_seed (2022)

# could be simpler if data has a fixed dimensionality or shape
def class_average (data, labels):
    assert data.shape == labels.shape
    nClass = labels.max() + 1
    dims = [-1] * data.dim() + [nClass]
    dex = data.unsqueeze (-1).expand (dims)
    lex = torch.nn.functional.one_hot (labels)
    sums = (lex * dex).sum (list (range (data.dim())))
    cnts = lex.sum (list (range (data.dim())))
    return sums / cnts   # nan for labels that don't occur

A = torch.tensor ([5, 6, 7, 9, 10])
B = torch.tensor ([0, 0, 0, 1, 1])
clAB = class_average (A, B)
print ('A:', A)
print ('B:', B)
print ('class_average (A, B):', clAB)

m = 3
n = 5
nClass = 4
data = torch.randn (m, n)
labels = torch.randint (nClass, (m, n))
cldl = class_average (data, labels)
print ('class_average (data, labels):', cldl)

cl_missing_label = class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4]))
print ('class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])):', cl_missing_label)
```

Here is the test output:

```
1.10.2
A: tensor([ 5, 6, 7, 9, 10])
B: tensor([0, 0, 0, 1, 1])
class_average (A, B): tensor([6.0000, 9.5000])
class_average (data, labels): tensor([ 0.4032, 0.1719, 0.1518, -1.1578])
class_average (torch.arange (4), torch.tensor ([0, 1, 2, 4])): tensor([0., 1., 2., nan, 3.])
```
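For what it's worth, here is a minimal sketch (not from the reply above) of the same mechanics on the small example, together with `torch.bincount`, which computes the same per-class sums and counts without materializing the expanded tensor (a possible memory saving when the number of classes is large):

```
import torch

A = torch.tensor ([5., 6., 7., 9., 10.])
B = torch.tensor ([0, 0, 0, 1, 1])

# one_hot (B) has one column per class; it masks which
# elements of A contribute to each class
mask = torch.nn.functional.one_hot (B).to (A.dtype)   # shape (5, 2)
sums = (A.unsqueeze (-1) * mask).sum (dim = 0)        # tensor([18., 19.])
cnts = mask.sum (dim = 0)                             # tensor([3., 2.])
print (sums / cnts)                                   # tensor([6.0000, 9.5000])

# same sums and counts, no expanded tensor
sums2 = torch.bincount (B, weights = A)
cnts2 = torch.bincount (B)
print (sums2 / cnts2)
```

Note that `bincount` wants 1-d input, so an NxM `A` and `B` would need to be `flatten()`ed first.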

Best.

K. Frank
