Binary mask output by network

The code below creates a mask for x:

import torch
import torch.nn as nn
import torch.nn.functional as F

class get_mask(nn.Module):
  def __init__(self, in_c):
    super().__init__()
    self.conv1 = nn.Conv2d(in_c, 1, 3, 1)
    self.norm = nn.BatchNorm2d(1)

  def forward(self, x):
    p = 1
    p2d = (p, p, p, p)
    x = self.conv1(F.pad(x, p2d))
    x = self.norm(x)
    return torch.sigmoid(x)

How can I modify the network to output a binary mask (zeros and ones) instead of the sigmoid output?

If you don’t need to backpropagate through it, you could just apply a threshold of e.g. 0.5 on the sigmoid output.
Do you just need the binary outputs for some accuracy calculation or visualization?
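For the non-differentiable case, a minimal sketch of that thresholding (the example tensor values are made up):

```python
import torch

# Hypothetical sigmoid outputs (values in (0, 1))
probs = torch.tensor([[0.2, 0.7], [0.5, 0.9]])

# Threshold at 0.5 to get a hard binary mask; no gradient flows through the comparison
binary_mask = (probs > 0.5).float()
# binary_mask is tensor([[0., 1.], [0., 1.]])
```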

No, I need to backprop. I want to use this mask to mask out invalid pixels in x while calculating the loss.

Ah OK, then I misunderstood your question, sorry.
In that case you could create your mask vector (with zeros and ones) and multiply it with the loss.

Something like this should work:

model = nn.Conv2d(3, 1, 3, 1, 1)
x = torch.randn(1, 3, 5, 5)
mask = torch.empty(1, 1, 5, 5).random_(2)
target = torch.empty(1, 1, 5, 5).random_(2)
criterion = nn.BCELoss(reduction='none')

output = torch.sigmoid(model(x))
loss = criterion(output, target)
loss = loss * mask

Sorry, maybe I confused you.

I need something like this:

model = nn.Conv2d(3, 1, 3, 1, 1)
mask = torch.sigmoid(model(x)) # but I need mask to be binary instead of values between 0 and 1
masked_img = mask * x # point-wise
loss = cal_loss(masked_img, x)

I think this link could give us a hint. In your case,

mask = torch.relu(torch.sign(torch.sigmoid(model(x))-0.5))

should return mask with elements ∈ {0,1}.
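A quick sanity check of that trick (random logits here are just an example):

```python
import torch

logits = torch.randn(1, 1, 5, 5)

# sign(sigmoid(z) - 0.5) is -1, 0, or 1; relu clips the -1s to 0
mask = torch.relu(torch.sign(torch.sigmoid(logits) - 0.5))

# every element is exactly 0.0 or 1.0
print(((mask == 0) | (mask == 1)).all())  # tensor(True)
```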

Besides, as discussed in link, the derivative of sign(.) is always 0.

Suppose y = M(x1) * H(x2), where

  • M(): mask layer
  • H(): some hidden layer

Then

dy/dx1 = M'(x1) * H(x2) = 0, since sign'(.) = 0 wherever it is defined
dy/dx2 = M(x1) * H'(x2)

Note that, since M(x1) ∈ {0,1}, only the positive mask layer outputs take part in back-propagation.
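You can see the second derivative in action with a tiny example (the mask and h values are made up):

```python
import torch

# Fixed binary mask, as produced by the mask layer M above
mask = torch.tensor([[1., 0., 1.]])
# Output of the hidden branch H, which we want gradients for
h = torch.tensor([[2., 3., 4.]], requires_grad=True)

y = mask * h
y.sum().backward()

# Gradient reaches h only where the mask is 1
print(h.grad)  # tensor([[1., 0., 1.]])
```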

BTW, how do you insert equations in the PyTorch forum?


Thanks for this, exactly what I was looking for!


I think there might be a slight mistake in the example. Not important because it is directionally correct, but the masked loss should be normalized by the number of valid pixels rather than by the total number of elements, i.e. you should write:

loss = loss / torch.sum(mask)
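Putting the pieces together, a self-contained sketch of the normalized masked loss (the `clamp(min=1)` guard against an all-zero mask is my addition):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 1, 3, 1, 1)
x = torch.randn(1, 3, 5, 5)
mask = torch.empty(1, 1, 5, 5).random_(2)
target = torch.empty(1, 1, 5, 5).random_(2)

criterion = nn.BCELoss(reduction='none')
output = torch.sigmoid(model(x))

# Average only over the valid (mask == 1) pixels
loss = (criterion(output, target) * mask).sum() / mask.sum().clamp(min=1)
loss.backward()
```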