How to compute the mean under a mask in a memory-efficient way?

Suppose I have a mask tensor M (values 1 or 0) of shape [N, H, W] and a value tensor P of shape [H, W, C]. For each of the N masks, I want to take the values of P at all locations where the mask is 1 and average them, giving a result of shape [N, C]. For now, the only way I can do this is to expand M to MM of shape [N, H, W, 1] and P to PP of shape [1, H, W, C], compute MM*PP to get a tensor V of shape [N, H, W, C], and then reduce V over H and W to get the mean.


But this takes too much GPU memory because N is too large, so I can only run it on the CPU, which is extremely slow. Can someone help?
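
One possible direction, as a rough sketch only: because M contains only 0s and 1s, the masked sum over H and W can be written as a matrix product between M flattened to [N, H*W] and P flattened to [H*W, C], so the [N, H, W, C] intermediate is never built. The code below assumes PyTorch (the post does not name the framework), and `masked_mean` is a hypothetical helper name. Masks that are entirely 0 are assumed to give a mean of 0.

```python
import torch

def masked_mean(M, P):
    """Mean of P under each mask in M, without an [N, H, W, C] intermediate.

    M: [N, H, W] tensor of 0s and 1s; P: [H, W, C] tensor. Returns [N, C].
    """
    N, H, W = M.shape
    C = P.shape[-1]
    # Flatten the spatial dimensions; cast the mask to P's dtype for matmul.
    M_flat = M.reshape(N, H * W).to(P.dtype)   # [N, H*W]
    P_flat = P.reshape(H * W, C)               # [H*W, C]
    # Sum of P over the 1-locations of each mask.
    sums = M_flat @ P_flat                     # [N, C]
    # Number of 1-locations per mask; clamp so empty masks divide by 1.
    counts = M_flat.sum(dim=1, keepdim=True).clamp_min(1.0)  # [N, 1]
    return sums / counts

# Small usage example with made-up shapes.
M = (torch.rand(8, 16, 16) > 0.5).float()
P = torch.randn(16, 16, 3)
result = masked_mean(M, P)  # shape [8, 3]
```

The largest tensors here are the flattened M and the [N, C] output, so peak memory should scale with N*H*W rather than N*H*W*C. If that is still too much, the same function could be applied to chunks of M along N and the results concatenated.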