I am trying to write a masked median filter/pool based on this code:
PyTorch MedianPool (MedianFilter) (github.com)
The mask is simply all nonzero values in the input data.
How can I apply the mask after unfold, i.e. along the last dimension of
x.contiguous().view(x.size()[:4] + (-1,))
so that the median is taken only over the non-masked values?
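One loop-free approach, offered as a sketch rather than a definitive answer: mark the masked (zero) entries as NaN before unfolding and replace .median(dim=-1) with torch.nanmedian(dim=-1), which ignores NaNs along the reduced dimension. The function name masked_median_pool2d, the square kernel, the NaN padding and the choice to output 0 for fully masked windows are all assumptions for illustration, not part of the original MedianPool code.

```python
import torch
import torch.nn.functional as F


def masked_median_pool2d(x, kernel_size=3, stride=1, padding=1):
    """Median pool over (N, C, H, W) that ignores zero (masked) entries.

    Hypothetical helper: square kernel, symmetric padding, floating-point x.
    Windows with no nonzero values produce 0.
    """
    # Mark masked (zero) entries as NaN so torch.nanmedian skips them.
    x = x.masked_fill(x == 0, float("nan"))
    # Pad with NaN so padded border positions are ignored as well.
    x = F.pad(x, (padding, padding, padding, padding), mode="constant", value=float("nan"))
    # Sliding windows: (N, C, H_out, W_out, k, k), as in the MedianPool code.
    x = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # Flatten each window into the last dimension: (N, C, H_out, W_out, k*k).
    x = x.contiguous().view(x.size()[:4] + (-1,))
    # Median of the non-NaN values per window (lower median for even counts);
    # an all-NaN window yields NaN.
    med = torch.nanmedian(x, dim=-1).values
    # Replace NaN results (fully masked windows) with 0.
    return med.masked_fill(torch.isnan(med), 0.0)
```

Note that any NaNs already present in the input would also be treated as masked. torch.nanmedian is only available in recent PyTorch releases.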
I could gather, take the median, and scatter along the last dimension, but wouldn't that require looping over the other dimensions?
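Gathering along the last dimension does not by itself require a Python loop, because torch.gather works on all leading dimensions at once. A minimal sketch, assuming a boolean mask with the same shape as x: push masked entries to +inf, sort along the last dimension, count the valid entries per row, and gather the element at the lower-median position. The helper name masked_median_lastdim and the 0 result for rows with no valid entries are hypothetical choices.

```python
import torch


def masked_median_lastdim(x, mask):
    """Hypothetical helper: lower median of x where mask is True, along dim -1.

    Works for any leading shape without Python loops; rows with no valid
    entries return 0.
    """
    # Push masked-out entries to +inf so sorting moves them to the end.
    filled = x.masked_fill(~mask, float("inf"))
    sorted_vals, _ = torch.sort(filled, dim=-1)
    # Number of valid entries per row, shape (..., 1), dtype int64.
    count = mask.sum(dim=-1, keepdim=True)
    # Position of the lower median among the valid (sorted-first) entries.
    idx = torch.div((count - 1).clamp(min=0), 2, rounding_mode="floor")
    med = torch.gather(sorted_vals, -1, idx).squeeze(-1)
    # Rows with no valid entries would give +inf; return 0 instead.
    return med.masked_fill(count.squeeze(-1) == 0, 0.0)
```

In the MedianPool forward this could replace the final .median(dim=-1)[0], e.g. with mask = x != 0 computed on the unfolded and flattened tensor.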
Thanks!