Compute the weighted average in PyTorch

Yes, that’s correct. A more concise way to write it:

weighted_average = (A @ W) / W.sum()
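Here is a minimal sketch that checks the one-liner against the longer broadcast-and-sum form. The shapes are an assumption, since the question doesn't state them: `A` has one sample per column, shape (D, N), and `W` has one weight per sample, shape (N,). The tensor values are made up for illustration.

```python
import torch

# Assumed layout: A is (D, N) with one sample per column, W is (N,) weights.
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
W = torch.tensor([1.0, 1.0, 2.0])

# Shorthand: the matrix-vector product gives the weighted sum over the N samples,
# and dividing by the total weight normalizes it into an average.
weighted_average = (A @ W) / W.sum()

# Longer equivalent: broadcast W across each row, then sum along the sample axis.
explicit = (A * W).sum(dim=1) / W.sum()

# If samples are stored as rows instead (shape (N, D)), put the weights on the left.
A_rows = A.T
rows_average = (W @ A_rows) / W.sum()

print(weighted_average)  # tensor([2.2500, 5.2500])
```

`A @ W` performs the multiply-and-sum in one call, so it avoids building the intermediate `A * W` tensor. Which side `W` goes on depends only on whether the samples sit along the columns or the rows of `A`.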