Masked_select per sample in batch

The input is an NCHW tensor, and I have a mask tensor of size N1HW (i.e. N×1×H×W) that gives each pixel's label (0 or 1). I'd like to group the pixels within each sample by this mask and then do some matrix operations on the groups, such as the inner product between group 0 and group 1. Is there a good way to do this grouping per sample across a batch?
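A minimal sketch of a few options, assuming `x` is a float tensor of shape (N, C, H, W) and `mask` holds 0/1 values with shape (N, 1, H, W). `torch.masked_select` always returns a flat 1D tensor, and the group sizes differ between samples, so an explicit per-sample grouping gives ragged results and needs a loop. If what you need reduces to sums over each group, multiplying by the mask keeps everything batched. "Inner product between groups" could mean several things; the two batched readings below (all pixel-pair inner products, or the inner product of the summed group features) are assumptions, and the helper names `group_features`, `pairwise_inner_products` and `group_sum_inner_product` are hypothetical.

```python
import torch


def group_features(x: torch.Tensor, mask: torch.Tensor):
    """Loop version. x: (N, C, H, W); mask: (N, 1, H, W) with 0/1 values.
    Returns a list of (feats0, feats1) per sample, each of shape (C, K_i),
    where K_i is the number of pixels in that group for sample i."""
    n, c = x.shape[:2]
    flat_x = x.reshape(n, c, -1)         # (N, C, H*W)
    flat_m = mask.reshape(n, -1).bool()  # (N, H*W)
    groups = []
    for xi, mi in zip(flat_x, flat_m):
        # Boolean indexing on the pixel dim keeps all channels of the selected pixels.
        groups.append((xi[:, ~mi], xi[:, mi]))
    return groups


def pairwise_inner_products(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Batched. Returns (N, H*W, H*W): entry [n, i, j] is the inner product of the
    C-dim feature vectors of pixels i and j, kept only where pixel i is in group 0
    and pixel j is in group 1 (all other entries are zero)."""
    n, c = x.shape[:2]
    flat_x = x.reshape(n, c, -1)                        # (N, C, HW)
    m = mask.reshape(n, -1).to(x.dtype)                 # (N, HW)
    gram = torch.bmm(flat_x.transpose(1, 2), flat_x)    # (N, HW, HW)
    pair_mask = (1 - m).unsqueeze(2) * m.unsqueeze(1)   # (N, HW, HW)
    return gram * pair_mask


def group_sum_inner_product(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Batched. Inner product between the summed group-0 and group-1 feature
    vectors of each sample, which by bilinearity equals the sum of all
    group-0/group-1 pixel-pair inner products. Returns shape (N,)."""
    m = mask.to(x.dtype)                    # (N, 1, H, W), broadcasts over C
    sum0 = (x * (1 - m)).sum(dim=(2, 3))    # (N, C)
    sum1 = (x * m).sum(dim=(2, 3))          # (N, C)
    return (sum0 * sum1).sum(dim=1)


if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 4)
    mask = (torch.rand(2, 1, 4, 4) > 0.5).long()
    for g0, g1 in group_features(x, mask):
        print(g0.shape, g1.shape)
    print(group_sum_inner_product(x, mask))
```

The loop version is the most flexible when you need arbitrary matrix operations on each group. The pairwise version materializes an (N, H·W, H·W) tensor, which gets large quickly for real image sizes, so prefer the mask-multiplication form whenever the operation can be written as a sum over group members.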