Calculate the mean of a torch tensor, ignoring the zero paddings

Hi all,

I am trying to calculate the mean of a tensor column-wise (along dim 0) while ignoring the zero paddings. Is this possible in PyTorch?

For example, if I have the following tensor:

import torch

a = torch.tensor([[1, 3, 6, 0, 9],
                  [1, 2, 3, 4, 0]], dtype=torch.float32)

torch.mean(a, dim=0)

# result
tensor([1.0000, 2.5000, 4.5000, 2.0000, 4.5000])

The result I need:

tensor([1.0000, 2.5000, 4.5000, 4.0000, 9.0000])

Any help is greatly appreciated.
Ian

a.sum(0) / (a != 0).sum(0)
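A slightly more defensive version of the one-liner above, as a sketch. It divides the column sums by the count of non-zero entries, but clamps that count to at least 1, so a column that is entirely padding averages to 0 instead of 0/0 = nan. It assumes padding is exactly 0, which means genuine zero values in the data are also ignored. The helper names `masked_mean` and `masked_mean_nan` are hypothetical. The second helper swaps in a different technique: it replaces padding with NaN and uses `torch.nanmean`, which requires PyTorch 1.10 or later.

```python
import torch

def masked_mean(a: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Boolean mask of the non-padding entries (assumes padding is exactly 0).
    mask = a != 0
    # Zeros contribute nothing to the sum, so a plain sum is enough.
    total = a.sum(dim)
    # Number of non-zero entries along `dim`; clamp to 1 so an all-zero
    # column yields 0 instead of nan.
    count = mask.sum(dim).clamp(min=1)
    return total / count

def masked_mean_nan(a: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Alternative: turn padding into NaN and let torch.nanmean skip it.
    # Unlike masked_mean, an all-zero column gives nan here.
    return torch.nanmean(a.masked_fill(a == 0, float("nan")), dim=dim)

a = torch.tensor([[1., 3., 6., 0., 9.],
                  [1., 2., 3., 4., 0.]])
print(masked_mean(a))      # tensor([1.0000, 2.5000, 4.5000, 4.0000, 9.0000])
print(masked_mean_nan(a))  # tensor([1.0000, 2.5000, 4.5000, 4.0000, 9.0000])
```

Which helper fits better depends on what an all-padding column should mean downstream: 0 from `masked_mean`, or nan from `masked_mean_nan`, which makes such columns easy to detect later.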