If you don't need to backpropagate through it, you could just apply a threshold (e.g. 0.5) to the sigmoid output.
Do you just need the binary outputs for some accuracy calculation or visualization?
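Here is a minimal sketch of the thresholding approach for a metric or visualization. It assumes a 1-channel `nn.Conv2d` mask head like the one in the original code. The tensor shapes, the random `target` mask, and the pixel-accuracy metric are hypothetical placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)  # hypothetical input batch (N, C, H, W)
target = (torch.rand(4, 1, 8, 8) > 0.5).float()  # hypothetical ground-truth mask
model = nn.Conv2d(3, 1, 3, 1, 1)

with torch.no_grad():  # no autograd graph is needed for metrics or plots
    probs = torch.sigmoid(model(x))
    binary_mask = (probs > 0.5).float()  # hard 0/1 mask
    accuracy = (binary_mask == target).float().mean().item()  # pixel accuracy

print(binary_mask.shape, accuracy)
```

Mathematically, sigmoid(z) > 0.5 exactly when z > 0, so you could also threshold the raw logits with `model(x) > 0` and skip the sigmoid.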
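```python
import torch
import torch.nn as nn
def cal_loss(pred, target):  # hypothetical placeholder: elementwise squared error
    return (pred - target) ** 2
x = torch.randn(1, 3, 32, 32)  # hypothetical input batch (N, C, H, W)
model = nn.Conv2d(3, 1, 3, 1, 1)  # 3 -> 1 channel, 3x3 kernel, stride 1, padding 1
mask = torch.sigmoid(model(x))  # but I need mask to be binary instead of values between 0 and 1
masked_img = mask * x  # point-wise; the 1-channel mask broadcasts over the 3 input channels
loss = cal_loss(masked_img, x)
loss.mean().backward()
```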
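If the mask has to stay binary *and* the model still has to be trained through it, a hard threshold alone won't work, because the comparison returns a tensor that is detached from the autograd graph. One common workaround, not mentioned in this thread, is a straight-through estimator: the forward pass uses the thresholded values, and the backward pass uses the sigmoid's gradient. The following is a minimal sketch. The input shape and the squared-error loss standing in for `cal_loss` are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 16, 16)  # hypothetical input batch
model = nn.Conv2d(3, 1, 3, 1, 1)
probs = torch.sigmoid(model(x))

# A plain comparison produces a tensor with no gradient history.
hard = (probs > 0.5).float()
print(hard.requires_grad)  # False

# Straight-through estimator: the forward value equals `hard` (up to float rounding),
# while the gradient flows through `probs`, since the detached term contributes none.
mask = (hard - probs).detach() + probs

masked_img = mask * x
loss = ((masked_img - x) ** 2).mean()  # hypothetical stand-in for cal_loss
loss.backward()
print(model.weight.grad is not None)  # True
```

The straight-through gradient is a heuristic. It treats the threshold as the identity during backprop, so the gradients are an approximation rather than the true (zero almost everywhere) gradient of the step function.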
I think there might be a slight mistake in the example. It's not important, since the example is directionally correct, but I believe that instead of `loss.mean().backward()`
you should write: