I have a 2D tensor:
import torch
samples = torch.Tensor([
[0.1, 0.1], #-> group / class 1
[0.2, 0.2], #-> group / class 2
[0.4, 0.4], #-> group / class 2
[0.0, 0.0] #-> group / class 0
])
and a label for each sample corresponding to a class:
labels = torch.LongTensor([1, 2, 2, 0])
so len(samples) == len(labels). Now I want to calculate the mean of the samples in each class / label. Because there are 3 classes (0, 1 and 2), the result should have shape [n_classes, samples.shape[1]].
So the expected result is:
result == torch.Tensor([
[0.1, 0.1],
[0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
[0.0, 0.0]
])
Question: Can this be done in pure PyTorch (i.e. without NumPy, so that autograd still works) and ideally without for loops?