# Binary mask output by network

The code below will create a mask for `x`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class get_mask(nn.Module):
    def __init__(self, in_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, 1, 3, 1)
        self.norm = nn.BatchNorm2d(1)  # conv1 outputs a single channel

    def forward(self, x):
        p = 1
        p2d = (p, p, p, p)
        x = F.pad(x, p2d)
        x = self.norm(self.conv1(x))
        return torch.sigmoid(x)
```

How can I modify the network to output a binary mask (zeros and ones) instead of `F.sigmoid(x)`?

If you don't need to backpropagate through it, you could just apply a threshold of e.g. `0.5` on the sigmoid output.
Do you just need the binary outputs for some accuracy calculation or visualization?

No, I need to backprop. I want to use this `mask` to mask out invalid pixels in `x` while calculating the loss.

Ah OK, then I misunderstood your question, sorry.
In that case you could create your mask vector (with zeros and ones) and multiply it with the loss.

Something like this should work:

```python
model = nn.Conv2d(3, 1, 3, 1, 1)
x = torch.randn(1, 3, 5, 5)
mask = torch.empty(1, 1, 5, 5).random_(2)
target = torch.empty(1, 1, 5, 5).random_(2)
criterion = nn.BCELoss(reduction='none')

output = torch.sigmoid(model(x))
loss = criterion(output, target)
loss = loss * mask
loss.mean().backward()
```

Sorry, maybe I confused you.

I need something like this:

```python
model = nn.Conv2d(3, 1, 3, 1, 1)
mask = torch.sigmoid(model(x))  # but I need mask to be binary instead of values between 0 and 1
masked_img = mask * x  # point-wise
loss = cal_loss(masked_img, x)
loss.mean().backward()
```

I think this link could give us a hint. In your case,

```python
mask = torch.relu(torch.sign(torch.sigmoid(model(x)) - 0.5))
```

should return a mask with elements ∈ {0, 1}.
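As a quick sanity check (a minimal sketch; the input shape and the 3-channel conv are arbitrary assumptions), the expression above can be verified to produce only zeros and ones:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 1, 3, 1, 1)
x = torch.randn(1, 3, 5, 5)

# sigmoid(z) - 0.5 is positive iff z > 0, so sign(...) is in {-1, 0, 1},
# and relu(...) clamps that to {0, 1}
mask = torch.relu(torch.sign(torch.sigmoid(model(x)) - 0.5))
print(mask.unique())  # subset of {0., 1.}
```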

Besides, as discussed in link, the derivative of sign(.) is always 0.

Suppose y = M(x1) * H(x2), where

• M(): mask layer
• H(): some hidden layer

Since sign(.) has zero derivative everywhere, dy/dx1 = 0, while

dy/dx2 = M(x1) * dH(x2)/dx2

Note that, since M(x1) ∈ {0, 1}, only the positive mask layer outputs take part in back-propagation.
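This can be checked directly. In the sketch below (the vector shapes and the `x2 ** 2` hidden layer are just illustrative assumptions), no gradient flows back into the mask branch, while the hidden branch only receives gradient where the mask is 1:

```python
import torch

x1 = torch.randn(4, requires_grad=True)
x2 = torch.randn(4, requires_grad=True)

M = torch.relu(torch.sign(torch.sigmoid(x1) - 0.5))  # binary mask M(x1)
H = x2 ** 2                                          # toy hidden layer H(x2)
y = (M * H).sum()
y.backward()

print(x1.grad)  # all zeros: sign(.) has derivative 0 everywhere
print(x2.grad)  # equals 2 * x2 where M == 1, and 0 where M == 0
```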

By the way, how do you insert equations in the PyTorch forum?


Thanks for this, exactly what I was looking for!


I think there might be a slight mistake in the example. It's not important, since the example is directionally correct, but I believe that instead of:
`loss.mean().backward()`
you should write:

```python
loss = loss.sum() / torch.sum(mask)
loss.backward()
```

so that the loss is averaged only over the valid pixels (note that `backward()` needs a scalar, hence the `sum()`).