Masked median filter/pool

I am trying to write a masked median filter/pool based on this code:
PyTorch MedianPool (MedianFilter)

The mask is simply the set of nonzero values in the input data (i.e. x != 0).

The question is how to apply the mask after the unfold, i.e. along the last dimension of
x.contiguous().view(x.size()[:4] + (-1,))
so that the median is taken over only the non-masked values?

I could gather, take the median, and scatter along the last dimension, but wouldn't that mean looping over the other dimensions?
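
One loop-free possibility, sketched here under assumptions rather than taken from the linked code: replace the masked (zero) entries in each unfolded window with NaN and reduce with torch.nanmedian, which ignores NaN along the chosen dimension. The function name masked_median_pool2d and its parameters are hypothetical. The sketch assumes a floating-point input of shape (N, C, H, W), a square kernel, and zero padding (so padded pixels count as masked), and it returns 0 for windows that contain no unmasked value.

```python
import torch
import torch.nn.functional as F


def masked_median_pool2d(x, kernel_size=3, stride=1, padding=0):
    """Median over each k x k window, ignoring zero (masked) entries.

    Hypothetical helper: assumes a float tensor of shape (N, C, H, W).
    """
    k = kernel_size
    # Zero padding, so padded pixels are treated as masked as well.
    x = F.pad(x, (padding,) * 4, mode="constant", value=0.0)
    # Unfold height, then width: (N, C, H_out, W_out, k, k).
    windows = x.unfold(2, k, stride).unfold(3, k, stride)
    # Flatten each window into the last dimension: (N, C, H_out, W_out, k*k).
    windows = windows.contiguous().view(windows.size()[:4] + (-1,))
    # Mark masked entries as NaN so nanmedian skips them.
    windows = windows.masked_fill(windows == 0, float("nan"))
    med = torch.nanmedian(windows, dim=-1).values
    # Windows with no unmasked value come back as NaN; map them to 0.
    return torch.nan_to_num(med, nan=0.0)


x = torch.tensor([[[[1., 0., 3.],
                    [0., 5., 0.],
                    [7., 0., 9.]]]])
out = masked_median_pool2d(x, kernel_size=3)
print(out.item())  # 5.0, the median of the nonzero values 1, 3, 5, 7, 9
```

Because the reduction runs over the last dimension for all windows at once, no Python loop over the batch, channel, or spatial dimensions is needed. If genuine zeros in the data must be distinguished from masked values, a separate boolean mask could be unfolded the same way and passed to masked_fill instead of the windows == 0 test.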