Compute the weighted average in PyTorch

Assuming we have a tensor A with shape C x H,

and a tensor W, representing the weights, with shape H.

Tensor A holds a C-dimensional feature vector for each of the H (flattened) spatial locations, and W holds one weight per spatial location. Instead of computing the mean via:

torch.mul(A, W).mean(1) 

How can we compute the weighted average instead?
The output should have size C.
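For reference, here is the quantity being asked for, written per channel c, next to what `torch.mul(A, W).mean(1)` actually computes. The two agree only when the weights sum to H:

```latex
\bar{A}_c = \frac{\sum_{h=1}^{H} W_h \, A_{c,h}}{\sum_{h=1}^{H} W_h},
\qquad
\operatorname{mean}(A \odot W)_c = \frac{1}{H} \sum_{h=1}^{H} W_h \, A_{c,h},
\qquad c = 1, \dots, C
```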

Would it be:

import torch
Z = torch.mul(A, W)  # W (H,) broadcasts across the C rows of A (C, H)
weighted_average = torch.sum(Z, dim=1) / torch.sum(W)  # shape (C,)
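A quick sanity check of this formula on a tiny tensor (C = 2, H = 3; the values are made up for illustration). It compares the result against an explicit loop and also shows that `.mean(1)` divides by H rather than by the sum of the weights. It assumes the weights are non-negative with a non-zero sum:

```python
import torch

# Hypothetical example: C = 2 feature channels, H = 3 flattened spatial locations.
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # shape (C, H)
W = torch.tensor([2.0, 1.0, 1.0])     # shape (H,), deliberately not normalized (sums to 4)

Z = torch.mul(A, W)                                     # (C, H)
weighted_average = torch.sum(Z, dim=1) / torch.sum(W)  # (C,)

# Explicit loop over channels and locations, for comparison.
manual = torch.stack([
    sum(A[c, h] * W[h] for h in range(A.shape[1])) / W.sum()
    for c in range(A.shape[0])
])

# The original .mean(1) divides the weighted sum by H = 3, not by sum(W) = 4.
plain_mean = torch.mul(A, W).mean(1)

print(weighted_average)  # tensor([1.7500, 4.7500])
```

Because the result is divided by `W.sum()`, scaling all weights by the same constant leaves the weighted average unchanged, so W does not need to be normalized beforehand.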

Yes, that’s correct. A shorter way to write it uses a matrix-vector product:

weighted_average = (A @ W) / W.sum()  # (C, H) @ (H,) -> (C,)
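A sketch to confirm the short form matches the longer one, using random tensors (the sizes C = 4, H = 6 are arbitrary, and W is drawn non-negative so its sum is positive). `torch.einsum` is included only as a third, equivalent spelling; it is an alternative, not part of the original answer:

```python
import torch

torch.manual_seed(0)
C, H = 4, 6                  # arbitrary example sizes
A = torch.randn(C, H)        # features, shape (C, H)
W = torch.rand(H)            # non-negative weights, shape (H,)

# Matrix-vector product: each output entry is sum_h A[c, h] * W[h].
short = (A @ W) / W.sum()

# Longer form from the question: elementwise product, then sum over H.
long = torch.sum(torch.mul(A, W), dim=1) / torch.sum(W)

# Same contraction written with einsum.
via_einsum = torch.einsum("ch,h->c", A, W) / W.sum()

print(short.shape)  # torch.Size([4])
```

`A @ W` works here because `torch.matmul` treats a 2-D by 1-D product as a matrix-vector product, so the result already has shape (C,).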