Global average pooling excluding zeros

I am trying to compute the average of a tensor of shape (b, c, h, w) over its spatial dimensions, reducing it to (b, c, 1, 1). This is the same as global average pooling and can be done with adaptive average pooling (nn.AdaptiveAvgPool2d(1)).
However, I want to exclude all zeros from the average. Is there any way to do this?
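For reference, the plain (zero-including) global average pooling described in the question can be written in two equivalent ways in PyTorch. The small sketch below uses an arbitrary random input, chosen only for illustration, and checks that nn.AdaptiveAvgPool2d(1) matches a mean over the height and width dimensions:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 5)                  # example input of shape (b, c, h, w)
pooled = nn.AdaptiveAvgPool2d(1)(x)          # global average pooling -> (b, c, 1, 1)
manual = x.mean(dim=(2, 3), keepdim=True)    # mean over h and w, keeping 4 dims
print(pooled.shape)                          # torch.Size([2, 3, 1, 1])
print(torch.allclose(pooled, manual, atol=1e-6))
```

Both versions count every element, zeros included, which is why a masked variant is needed.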

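import torch
b, c, h, w = tensor.shape
tensor = tensor.reshape(b, c, -1)         # (b, c, h*w); reshape also works on non-contiguous tensors
mask = tensor != 0                        # nonzero entries; use > 0 only if negatives should be dropped too
nelements = mask.sum(dim=-1).clamp(min=1) # avoid 0/0 = NaN for all-zero channels
pooling = tensor.sum(dim=-1) / nelements  # zeros add nothing to the sum, so only the count changes
pooling = pooling.view(b, c, 1, 1)        # back to (b, c, 1, 1)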
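The same idea can be packaged as a reusable function that keeps the 4D output shape directly by summing over both spatial dimensions. This is a minimal sketch; the function name nonzero_avg_pool2d is hypothetical, and it assumes that "zero" means an exact 0 and that a channel consisting entirely of zeros should pool to 0 rather than NaN (a design choice, not something the thread specifies):

```python
import torch

def nonzero_avg_pool2d(x: torch.Tensor) -> torch.Tensor:
    """Average over (h, w) ignoring exact zeros; returns shape (b, c, 1, 1)."""
    mask = (x != 0).to(x.dtype)                 # 1.0 where nonzero, 0.0 elsewhere
    total = x.sum(dim=(2, 3), keepdim=True)     # zeros contribute nothing to the sum
    count = mask.sum(dim=(2, 3), keepdim=True)  # number of nonzero entries per channel
    return total / count.clamp(min=1)           # all-zero channels give 0 instead of NaN

# Example: channel 0 has nonzeros 2 and 4 (mean 3); channel 1 is all zeros.
x = torch.tensor([[[[0.0, 2.0],
                    [4.0, 0.0]],
                   [[0.0, 0.0],
                    [0.0, 0.0]]]])              # shape (1, 2, 2, 2)
out = nonzero_avg_pool2d(x)
print(out.shape)      # torch.Size([1, 2, 1, 1])
print(out.flatten())  # tensor([3., 0.])
```

A plain AdaptiveAvgPool2d(1) on the same input would give 1.5 for channel 0, since it divides by all four elements.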

Wow, thank you very much!