I think this link could give us a hint. In your case,
mask = torch.relu(torch.sign(torch.sigmoid(model(x))-0.5))
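should return a mask with elements ∈ {0,1}: sigmoid(.) - 0.5 lies in (-0.5, 0.5), sign(.) maps that to {-1, 0, 1}, and relu(.) clamps -1 to 0.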
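A minimal, self-contained check of the one-liner above. It assumes a hypothetical `nn.Linear(4, 3)` as a stand-in for `model` and random inputs for `x`; any module producing real-valued outputs would do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for `model`: any module producing real-valued logits.
model = nn.Linear(4, 3)
x = torch.randn(5, 4)

# sigmoid(.) - 0.5 lies in (-0.5, 0.5); sign(.) maps it to {-1, 0, 1};
# relu(.) then clamps -1 to 0, leaving only {0, 1}.
mask = torch.relu(torch.sign(torch.sigmoid(model(x)) - 0.5))

print(mask)
print(torch.unique(mask))  # only 0. and/or 1. can appear
```

Since sigmoid(z) > 0.5 exactly when z > 0, this is (up to floating-point rounding very close to 0.5) the same mask as `(model(x) > 0).float()`. The comparison form is cheaper, but neither form passes a useful gradient back to `model`.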
Besides, as discussed in link, the derivative of sign(.) is 0 everywhere except at 0, where it is undefined (autograd simply returns 0 there too). Now suppose y = M(x1) * H(x2), where
- M(): mask layer
- H(): some hidden layer
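Then, by the product rule, the gradients are

$$\frac{\partial y}{\partial x_1} = H(x_2)\,\frac{\partial M(x_1)}{\partial x_1} = 0, \qquad \frac{\partial y}{\partial x_2} = M(x_1)\,\frac{\partial H(x_2)}{\partial x_2}$$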
Note that since M(x1) ∈ {0,1}, only the elements where the mask is 1 pass gradients back through H(x2); masked-out elements contribute nothing to back-propagation.
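To see both gradient expressions in action, here is a small autograd sketch. `M` reuses the mask expression from this thread; `H` is a hypothetical elementwise square standing in for "some hidden layer", and the input values are made up so that two elements get masked out.

```python
import torch

# Made-up inputs: x1 feeds the mask layer M, x2 feeds the hidden layer H.
x1 = torch.tensor([-2.0, 0.3, 1.5, -0.1], requires_grad=True)
x2 = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)

def M(t):
    # Mask layer from the post: 1 where sigmoid(t) > 0.5 (i.e. t > 0), else 0.
    return torch.relu(torch.sign(torch.sigmoid(t) - 0.5))

def H(t):
    # Hypothetical hidden layer: elementwise square, so dH/dt = 2 * t.
    return t ** 2

mask = M(x1)
y = mask * H(x2)
y.sum().backward()

print(mask.detach())  # tensor([0., 1., 1., 0.])
print(x1.grad)        # tensor([0., 0., 0., 0.]): sign(.) blocks the gradient
print(x2.grad)        # M(x1) * 2 * x2 = tensor([0., 4., 6., 0.])
```

The gradient w.r.t. x1 is identically zero, so whatever produces the mask is not trained through this path, while x2 only receives gradient at the positions where the mask is 1.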
BTW, how do you insert equations in the PyTorch forum?