Masked median filter/pool

I am trying to write a masked median filter/pool based on this code:
PyTorch MedianPool (MedianFilter) (github.com)
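For reference, here is a sketch of the unfold-based median pooling that the linked code builds on, mainly to show the tensor shapes. It is not the linked code itself. `median_pool2d` is a hypothetical name, and the sketch assumes an (N, C, H, W) input with no padding, so each output pixel is the median of one k x k window.

```python
import torch


def median_pool2d(x: torch.Tensor, k: int = 3, stride: int = 1) -> torch.Tensor:
    """Unmasked median pooling over k x k windows with no padding.

    Follows the unfold -> view -> median chain. x has shape (N, C, H, W).
    """
    # (N, C, H_out, W_out, k, k): one k x k patch per output position
    patches = x.unfold(2, k, stride).unfold(3, k, stride)
    # Flatten each patch so the window is the last dim: (N, C, H_out, W_out, k*k)
    patches = patches.contiguous().view(patches.size()[:4] + (-1,))
    # torch.median along a dim returns (values, indices); keep the values
    return patches.median(dim=-1)[0]
```

Every masking question below is about that last dimension of size k*k.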

The mask is simply the set of nonzero values in the input data (i.e. x != 0).

The question is how to apply the mask after unfold, i.e. along the last dimension of
x.contiguous().view(x.size()[:4] + (-1,))
so that the median is taken over only the non-masked values.
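One way to do this is to replace the masked entries with NaN and call `torch.nanmedian` over the last dimension, since it ignores NaN values (it is available in PyTorch 1.8 and later). Below is a minimal sketch under these assumptions: the input is floating point, zero means "masked", and padding is constant zero so padded pixels are masked too. `masked_median_pool2d` and its `fill` parameter are hypothetical names.

```python
import torch
import torch.nn.functional as F


def masked_median_pool2d(x: torch.Tensor, k: int = 3, stride: int = 1,
                         fill: float = 0.0) -> torch.Tensor:
    """Median over k x k windows using only the nonzero entries.

    Assumes a floating-point (N, C, H, W) input where zero means "masked".
    Windows with no nonzero entries are set to `fill` (hypothetical parameter).
    """
    pad = k // 2
    # Zero padding: padded pixels are zero, so the mask excludes them as well.
    x = F.pad(x, (pad, pad, pad, pad), mode="constant", value=0.0)
    patches = x.unfold(2, k, stride).unfold(3, k, stride)
    patches = patches.contiguous().view(patches.size()[:4] + (-1,))
    # Build the mask on the unfolded patches so it lines up element for element.
    mask = patches != 0
    # Masked entries become NaN, which torch.nanmedian skips.
    patches = patches.masked_fill(~mask, float("nan"))
    med = torch.nanmedian(patches, dim=-1).values
    # All-masked windows come back as NaN; replace them with `fill`.
    return med.masked_fill(~mask.any(dim=-1), fill)
```

Computing the mask after unfolding, rather than unfolding a separate mask tensor, keeps values and mask aligned without extra bookkeeping. With reflect padding (the linked code's default), the padded pixels would be reflections of real values and would take part in the median.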

I could gather, take the median, and scatter along the last dimension, but wouldn't that mean looping over the other dimensions?
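Gathering does not by itself require a loop, because `torch.gather` works on all leading dimensions at once as long as the index tensor has the same number of dimensions. Below is a sketch that avoids both loops and `nanmedian`, assuming the same nonzero mask. Masked entries are pushed to +inf, each window is sorted, and a single gather picks the median of the valid entries. `masked_median_last_dim` is a hypothetical helper. For even counts it returns the lower median, which matches the convention `torch.median` uses.

```python
import torch


def masked_median_last_dim(patches: torch.Tensor, mask: torch.Tensor,
                           fill: float = 0.0) -> torch.Tensor:
    """Lower median of `patches` along the last dim, using only entries where mask is True.

    Handles any number of leading dims in one call (no Python loops).
    """
    # Push masked entries to +inf so they sort to the end of each window.
    vals = patches.masked_fill(~mask, float("inf"))
    vals, _ = vals.sort(dim=-1)
    # Number of valid entries per window, kept as shape (..., 1) for gather.
    count = mask.sum(dim=-1, keepdim=True)
    # Position of the lower median among the valid entries (all sorted first).
    idx = torch.div((count - 1).clamp(min=0), 2, rounding_mode="floor")
    med = vals.gather(-1, idx).squeeze(-1)
    # Windows with no valid entries would yield +inf; replace them with `fill`.
    return med.masked_fill(count.squeeze(-1) == 0, fill)
```

This can stand in for the `.median(dim=-1)[0]` step. If `patches` is the (N, C, H_out, W_out, k*k) tensor from the view, `masked_median_last_dim(patches, patches != 0)` gives the masked median for every window, with no scatter needed.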

Thanks!