Thresholding a greyscale output into a binary tensor: how to write the forward pass

Say the output of my forward pass is a greyscale image:

x = torch.rand(2, 2)

The gold label is a binary image:

gold = torch.tensor([[0, 1], [1, 0]])

I want to define my loss by comparing the thresholded (binary) version of this image with the gold label. How do I write this so that the computation graph is defined correctly? Here is what I came up with, but my loss doesn’t decrease:

# min-max normalise the output to [0, 1]
x = (x - torch.min(x)) / (torch.max(x) - torch.min(x))
# hard threshold at 0.5
x[x < 0.5] = 0
x[x >= 0.5] = 1

The loss is defined as:

loss = criterion(x, gold)

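Thresholding is a step function: its derivative is zero almost everywhere (and undefined at the threshold itself). In your code every element of x is overwritten with a constant 0 or 1, so no gradient flows back to the model, which is most likely why it isn’t training.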
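To see this concretely, here is a minimal sketch that runs the question's normalisation and thresholding on a random stand-in output and inspects the resulting gradient, with and without the threshold. The `grad_of_loss` helper is hypothetical, and the question does not say which `criterion` is used, so `F.mse_loss` with a float-valued `gold` is an assumption for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Float targets (assumption): most PyTorch criteria expect float targets here.
gold = torch.tensor([[0.0, 1.0], [1.0, 0.0]])


def grad_of_loss(threshold: bool) -> torch.Tensor:
    """Hypothetical helper: gradient of the loss w.r.t. a random stand-in output."""
    out = torch.rand(2, 2, requires_grad=True)  # stands in for the model output
    x = (out - out.min()) / (out.max() - out.min())  # min-max normalisation
    if threshold:
        # Same in-place masked assignments as in the question.
        x[x < 0.5] = 0
        x[x >= 0.5] = 1
    loss = F.mse_loss(x, gold)
    loss.backward()
    return out.grad


print(grad_of_loss(threshold=True))   # every entry is zero
print(grad_of_loss(threshold=False))  # non-zero entries
```

With the threshold, every element is replaced by a constant, so autograd propagates only zeros; without it, the normalised values still depend on `out` and the gradient is non-zero.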
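The usual loss functions expect the model output to contain logits, log-probabilities, etc., and apply a smooth function internally: for example, nn.CrossEntropyLoss applies F.log_softmax, and nn.BCEWithLogitsLoss applies a sigmoid. Pass the raw (unthresholded) output to such a loss, and apply the threshold only when you need a binary prediction.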
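A minimal sketch of that approach, assuming a per-pixel binary task: a learnable 2x2 tensor of logits stands in for the model's forward pass (a hypothetical setup, not the asker's network), `nn.BCEWithLogitsLoss` is the criterion, and the 0.5 threshold is applied only after training to produce the binary prediction. The learning rate and step count are arbitrary choices for this toy example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the model: raw, unbounded scores (logits).
logits = nn.Parameter(torch.randn(2, 2))
gold = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # BCE targets must be float

criterion = nn.BCEWithLogitsLoss()  # applies a sigmoid internally
optimizer = torch.optim.SGD([logits], lr=1.0)

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = criterion(logits, gold)  # loss on the raw, unthresholded output
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Threshold only when a binary prediction is needed (evaluation/inference).
with torch.no_grad():
    pred = (torch.sigmoid(logits) >= 0.5).float()

print(losses[0], losses[-1])  # the loss goes down
print(pred)
```

Because `sigmoid(z) >= 0.5` exactly when `z >= 0`, thresholding the logits at 0 gives the same prediction without computing the sigmoid.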