Binary mask output by network

I think this link could give us a hint. In your case,

mask = torch.relu(torch.sign(torch.sigmoid(model(x))-0.5))

should return a mask whose elements are all in {0, 1}.
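A quick sanity check of that one-liner, assuming a hypothetical `nn.Linear(4, 3)` as a stand-in for `model` and random inputs (neither is from the original question):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the network in the question.
model = nn.Linear(4, 3)
x = torch.randn(5, 4)

# sigmoid(.) - 0.5 lies in (-0.5, 0.5); sign(.) maps it to {-1, 0, 1};
# relu(.) then turns every -1 into 0, leaving only 0s and 1s.
mask = torch.relu(torch.sign(torch.sigmoid(model(x)) - 0.5))

print(mask)
print(torch.unique(mask))  # only 0s and/or 1s
```

Since sigmoid(z) > 0.5 exactly when z > 0, `(model(x) > 0).float()` builds the same mask more directly (up to floating-point rounding for logits extremely close to 0).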

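Besides, as discussed in the link, the derivative of sign(·) is 0 almost everywhere (and PyTorch's autograd returns 0 for it everywhere), so no gradient flows back through the mask into the layers that produce it.

Suppose y = M(x1) * H(x2), where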

  • M(): mask layer
  • H(): some hidden layer

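Then, by the chain rule,

$$\frac{\partial y}{\partial x_1} = \frac{\partial M(x_1)}{\partial x_1} \, H(x_2) = 0,
\qquad
\frac{\partial y}{\partial x_2} = M(x_1) \, \frac{\partial H(x_2)}{\partial x_2}.$$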

Note that since M(x1) ∈ {0, 1}, only the positions where the mask is 1 pass gradient back into H(x2); wherever the mask is 0, the gradient is zeroed out.
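A minimal sketch to check both gradients numerically. It assumes hypothetical stand-ins for the two layers: `M_logits` and `H` are plain `nn.Linear` modules chosen only for illustration, and M(x1) is built with the same relu/sign/sigmoid expression as above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical layers: M_logits produces the mask logits, H is "some hidden layer".
M_logits = nn.Linear(4, 6)
H = nn.Linear(4, 6)

x1 = torch.randn(1, 4, requires_grad=True)
x2 = torch.randn(1, 4, requires_grad=True)

# M(x1) in {0, 1}, built with the relu/sign/sigmoid expression.
m = torch.relu(torch.sign(torch.sigmoid(M_logits(x1)) - 0.5))
h = H(x2)
h.retain_grad()  # keep dy/dH(x2) so it can be inspected after backward()

y = m * h
y.sum().backward()

print(m)        # the binary mask
print(h.grad)   # equals m: zero wherever the mask is 0
print(x1.grad)  # all zeros: sign(.) passes no gradient back to x1
```

Here `h.grad` should equal the mask itself and `x1.grad` should be all zeros, which is exactly the point that only the unmasked elements take part in back-propagation.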

By the way, how do I insert equations in the PyTorch forum?
