Compute the weighted average in PyTorch

Assume we have a tensor `A` with shape `CxH`, and a tensor `W` holding the weights, with shape `H`.

Tensor `A` holds features of dimension `C` for each spatial location in `H` (flattened), and `W` holds the weight for each spatial location. Instead of computing the mean via:

```
torch.mul(A, W).mean(1)
```

How can we compute the weighted average?
The output should have shape `C`.
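Written out, with $A_{c,h}$ the feature of channel $c$ at location $h$ and $W_h$ the weight of location $h$, the weighted average normalizes by the total weight, whereas `torch.mul(A, W).mean(1)` (written $m_c$ here) normalizes by the number of locations $H$:

$$
\bar{A}_c = \frac{\sum_{h=1}^{H} W_h A_{c,h}}{\sum_{h=1}^{H} W_h},
\qquad
m_c = \frac{1}{H} \sum_{h=1}^{H} W_h A_{c,h},
\qquad c = 1, \dots, C.
$$

The two coincide only when the weights sum to $H$.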

Would it be:

```
Z = torch.mul(A, W)
weighted_average = torch.sum(Z, dim=1) / torch.sum(W)
```

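Yes, that's correct. Since `A @ W` is a matrix-vector product that already sums `A[c, h] * W[h]` over `h`, you can write it more concisely as: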

```
weighted_average = (A @ W) / W.sum()
```
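A quick sanity check, using small made-up tensors (the values are purely illustrative), that the explicit and compact forms agree, and that both differ from the plain mean when the weights do not sum to `H`:

```python
import torch

# Toy inputs (illustrative values): C = 2 feature channels, H = 3 flattened locations.
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])  # shape (C, H)
W = torch.tensor([0.5, 1.0, 2.5])    # shape (H,), sums to 4.0

# Explicit form: weight every location, sum over H, divide by the total weight.
explicit = torch.sum(torch.mul(A, W), dim=1) / torch.sum(W)

# Compact form: A @ W is a matrix-vector product, i.e. sum over h of A[c, h] * W[h].
compact = (A @ W) / W.sum()

# Plain mean of the weighted product divides by H = 3 instead of sum(W) = 4.
plain_mean = torch.mul(A, W).mean(1)

print(explicit)    # holds [2.5, 5.5]
print(compact)     # same values as explicit
print(plain_mean)  # holds [10/3, 22/3], not a weighted average
```

Both forms assume `A` and `W` share a floating-point dtype (`@` generally does not promote mismatched operands such as float32 and float64), and that the weights do not sum to zero; otherwise the division yields `inf` or `nan`.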