Convolutional Autoencoder, Autocast, and loss of information from BCEWithLogits

I have a convolutional autoencoder, where the output layer is:

        self.t_conv5 = nn.ConvTranspose2d(128, 3, 3, stride=2, padding=1, output_padding=1)  # in __init__
        x = torch.sigmoid(self.t_conv5(x))  # in forward

and the loss criterion is nn.BCELoss(). When I try to use autocast for fp16 mixed precision, it raises an error saying that BCELoss is unsafe to autocast and suggests using nn.BCEWithLogitsLoss instead, which combines the final sigmoid layer and the BCE loss. Keeping the sigmoid in the model would then apply a second sigmoid on top of the desired output. If I instead remove the sigmoid activation from the model and re-apply it to the model output (i.e. output = model(images) becomes output = torch.sigmoid(model(images))), the loss after 20 epochs of training remains the same, but the output images are drastically different (the colours are off).

But since the data are the same and shuffling is disabled, moving the sigmoid shouldn't change the output like this. Is some relevant information in the weights lost by moving the sigmoid this way, and is there a way to counteract it? As far as I know, autocasting should not reduce the final accuracy.
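A minimal sketch of the setup described above, assuming PyTorch 1.10 or newer (for torch.autocast): the model returns raw logits and nn.BCEWithLogitsLoss applies the sigmoid internally. `AEHead` is a hypothetical stand-in that contains only the final layer from the post, and the random tensors stand in for real data.

```python
import torch
import torch.nn as nn

class AEHead(nn.Module):
    """Hypothetical stand-in for the decoder's last layer: returns raw logits, no sigmoid."""
    def __init__(self):
        super().__init__()
        self.t_conv5 = nn.ConvTranspose2d(128, 3, 3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        return self.t_conv5(x)  # logits; the sigmoid is folded into BCEWithLogitsLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # fp16 autocast + loss scaling on GPU only; plain fp32 on CPU
amp_dtype = torch.float16 if use_amp else torch.bfloat16  # ignored when autocast is disabled

model = AEHead().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # guards fp16 gradients against underflow

def train_step(features, targets):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        logits = model(features)
        loss = criterion(logits, targets)  # on CUDA, autocast runs this loss in float32
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

features = torch.randn(2, 128, 8, 8, device=device)
targets = torch.rand(2, 3, 16, 16, device=device)  # image targets in [0, 1]
loss = train_step(features, targets)
```

With stride 2, padding 1 and output_padding 1, the 8x8 feature maps become 16x16 images, so the targets must have that shape.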

If I understand the workflow correctly, you are now feeding the raw logits to nn.BCEWithLogitsLoss, which yields the same loss as the float32 training, and are using the output = torch.sigmoid(model(images)) tensor only to visualize the samples?
If so, do you see any difference if you convert the output to float32 via .float() before passing it to torch.sigmoid?
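A minimal sketch of that suggestion, assuming the model returns raw logits under autocast: upcast the logits to float32 before applying torch.sigmoid for visualization. `reconstruct` is a hypothetical helper, and the single ConvTranspose2d layer is a placeholder for the trained model.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16  # ignored when autocast is disabled

# Hypothetical stand-in for the trained autoencoder: any module that returns raw logits.
model = nn.ConvTranspose2d(128, 3, 3, stride=2, padding=1, output_padding=1).to(device).eval()

def reconstruct(images):
    """Return reconstructions in [0, 1] as float32 tensors, ready for plotting."""
    with torch.no_grad(), torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        logits = model(images)  # float16 on GPU under autocast
    # Upcast before the sigmoid so it is evaluated in float32, not float16.
    return torch.sigmoid(logits.float())

images = torch.randn(2, 128, 8, 8, device=device)
output = reconstruct(images)
```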

That recovers the colours. Do you know why that is?

The mixed-precision training utility (amp) raises the error when you use sigmoid + nn.BCELoss, as this combination was determined to be numerically unstable in float16. Passing the raw logits to nn.BCEWithLogitsLoss is therefore the right approach, since, if I'm not mistaken, it applies log_sigmoid internally.
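A small NumPy illustration of why the fused version is safer (a sketch, not PyTorch's actual kernel): sigmoid followed by BCE, computed entirely in float16, versus the log-sum-exp form of BCE on raw logits, max(x, 0) - x*y + log(1 + exp(-|x|)). That formula is a standard stable formulation; that nn.BCEWithLogitsLoss uses exactly this internally is an assumption here.

```python
import numpy as np

def naive_bce(logits, targets, dtype):
    """Sigmoid, then binary cross-entropy, all in `dtype` (mimics sigmoid + BCE in low precision)."""
    x = np.asarray(logits, dtype=dtype)
    y = np.asarray(targets, dtype=dtype)
    p = 1 / (1 + np.exp(-x))  # in float16, 1 + exp(-x) rounds to 1.0 for large x, so p == 1.0
    with np.errstate(divide="ignore"):  # log(0) -> -inf instead of a warning
        return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def stable_bce_with_logits(logits, targets, dtype):
    """BCE computed directly from logits via the log-sum-exp form; never takes log(0)."""
    x = np.asarray(logits, dtype=dtype)
    y = np.asarray(targets, dtype=dtype)
    return np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))

logits = [2.0, 12.0]
targets = [1.0, 0.0]
naive_fp16 = naive_bce(logits, targets, np.float16)
stable_fp16 = stable_bce_with_logits(logits, targets, np.float16)
naive_fp32 = naive_bce(logits, targets, np.float32)
```

For the logit 12 with target 0, the float16 sigmoid rounds to exactly 1.0, so log(1 - p) is log(0) and the naive loss is inf, while the logit-based form stays finite (about 12). PyTorch's nn.BCELoss clamps its log outputs at -100, so there the term would be capped rather than infinite, but the gradient information is lost either way.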
I suspected that applying torch.sigmoid to the float16 output might cause exactly the same issue when predicting the outputs, and that indeed seems to be the case.

By the same problem, you mean numerical instability?

Yes: numerical instability, rounding errors, etc., which might be caused by passing float16 values directly to sigmoid.