I have a tensor of size BxCxHxW. It contains valid values, which are greater than 0, and invalid values, which are equal to 0. I want to compute the mean of the valid values over the C, H and W dimensions, so the result has size Bx1: the i-th element is the mean of the valid values in the i-th CxHxW slice.

For example, if I ignore C and have the 2x1x3 tensor [[[0,1,2]], [[5,7,0]]] (i.e. B=2, H=1, W=3), the desired result is the 2x1 tensor [1.5, 6], where 1.5 is the average of 1 and 2, and 6 is the average of 5 and 7. If I use tensor.mean(-1).mean(-1) directly, the zeros are counted in the average, which is not what I want.
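One common way to do this is to build a mask of the valid entries, sum the masked values and divide by the number of valid entries per batch item. The sketch below assumes PyTorch and a floating-point input; `masked_mean` is a hypothetical helper name, not a library function. It reduces over every dimension except the first, so it works for both the BxCxHxW case and the 3-D example above.

```python
import torch

def masked_mean(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean of the entries > 0 over all non-batch dims, shape (B, 1)."""
    mask = (x > 0).to(x.dtype)           # 1.0 where valid, 0.0 where invalid
    dims = tuple(range(1, x.dim()))      # every dim except the batch dim B
    total = (x * mask).sum(dim=dims)     # sum of the valid values per batch entry
    count = mask.sum(dim=dims)           # number of valid values per batch entry
    return (total / count).unsqueeze(1)  # reshape from (B,) to (B, 1)

# The example from the question (C dropped): B=2, H=1, W=3
x = torch.tensor([[[0., 1., 2.]], [[5., 7., 0.]]])
print(masked_mean(x).squeeze(1).tolist())  # [1.5, 6.0]
```

Note that if a batch entry has no valid values at all, `count` is 0 for that entry and its result is NaN (0/0).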

If it can happen that everything is masked out (e.g. if you have multiple batches and some batch entries are completely masked out), you might want to do it like this:
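A minimal sketch of one way to handle fully masked entries (not necessarily the code the answer originally showed), again assuming PyTorch: clamp the per-entry count of valid values to at least 1, so an entry with no valid values yields 0 instead of NaN. `masked_mean_safe` is a hypothetical name, and returning 0 for an empty entry is an assumption you may want to change.

```python
import torch

def masked_mean_safe(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: like a masked mean, but fully masked entries give 0, not NaN."""
    mask = x > 0                                 # boolean mask of valid entries
    dims = tuple(range(1, x.dim()))              # reduce over every dim except B
    total = torch.where(mask, x, torch.zeros_like(x)).sum(dim=dims)
    count = mask.sum(dim=dims).clamp(min=1)      # avoid 0/0 when nothing is valid
    return (total / count).unsqueeze(1)          # shape (B, 1)

# Second batch entry is completely masked out
x = torch.tensor([[[0., 1., 2.]], [[0., 0., 0.]]])
print(masked_mean_safe(x).squeeze(1).tolist())  # [1.5, 0.0]
```

Because the sum over a fully masked entry is already 0, clamping only the denominator changes nothing for entries that have at least one valid value.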