How to use AMP with a BCE loss that does **not** directly operate on `sigmoid` outputs

In my model, I have:

```python
logits1, gate1, logits2, gate2 = net(x)
gated_prob1 = F.sigmoid(logits1) ** F.sigmoid(gate1)
gated_prob2 = F.sigmoid(logits2) ** F.sigmoid(gate2)
F.binary_cross_entropy(gated_prob1 * gated_prob2, gt)
```

The operation errors when enabling amp:

```text
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
```

What can I do, since I don't simply have a sigmoid layer right before the BCE?

Based on your posted code snippet, you are explicitly using `F.sigmoid` before passing these outputs to `F.binary_cross_entropy`, which is the functional form of `nn.BCELoss`:

```python
logits1, gate1, logits2, gate2 = net(x)
gated_prob1 = F.sigmoid(logits1) ** F.sigmoid(gate1)
gated_prob2 = F.sigmoid(logits2) ** F.sigmoid(gate2)
F.binary_cross_entropy(gated_prob1 * gated_prob2, gt)
```

Thanks ptrblck,
However, I cannot fold `F.sigmoid` and the BCE into `binary_cross_entropy_with_logits`, since I apply other operations between them.

In that case you might want to disable autocast for these operations, but note that even without amp, `nn.BCELoss` is less numerically stable than `nn.BCEWithLogitsLoss`.
See e.g. this recent topic.
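Disabling autocast for just these operations could look like the sketch below. The `gated_bce` helper and the explicit `.float()` casts are my own assumptions, not code from the thread; the point is that `torch.autocast(..., enabled=False)` opens a float32 region inside an otherwise mixed-precision forward pass, so `F.binary_cross_entropy` never sees half-precision inputs:

```python
import torch
import torch.nn.functional as F

def gated_bce(logits1, gate1, logits2, gate2, gt):
    # Run the sigmoid gating and the BCE outside of autocast, in float32.
    # Casting the (possibly half-precision) activations back to float32
    # is required, since disabling autocast does not change their dtype.
    with torch.autocast(device_type=logits1.device.type, enabled=False):
        l1, g1 = logits1.float(), gate1.float()
        l2, g2 = logits2.float(), gate2.float()
        gated_prob1 = torch.sigmoid(l1) ** torch.sigmoid(g1)
        gated_prob2 = torch.sigmoid(l2) ** torch.sigmoid(g2)
        return F.binary_cross_entropy(gated_prob1 * gated_prob2, gt.float())
```

You would call this helper from inside your autocast region; since both sigmoids lie in (0, 1), the gated product is a valid probability for `F.binary_cross_entropy`.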

Thanks ptrblck,
After some trial and error, I am now excluding the sigmoid post-processing (the code snippet above) as well as the last linear layer from autocast, and it seems to work.
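That setup can be sketched roughly as follows. The `Net` module, its layer sizes, and the use of `unbind` to produce the four outputs are all hypothetical stand-ins for the real model; the sketch only illustrates running the backbone under autocast while keeping the last linear layer and the sigmoid gating in float32:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 16)   # runs under autocast
        self.head = torch.nn.Linear(16, 4)       # kept out of autocast

    def forward(self, x):
        # Mixed-precision region: only the backbone runs here.
        with torch.autocast(device_type=x.device.type):
            feats = torch.relu(self.backbone(x))
        # Float32 region: last linear layer plus the sigmoid post-processing.
        with torch.autocast(device_type=x.device.type, enabled=False):
            out = self.head(feats.float())
            logits1, gate1, logits2, gate2 = out.unbind(dim=-1)
            gated_prob1 = torch.sigmoid(logits1) ** torch.sigmoid(gate1)
            gated_prob2 = torch.sigmoid(logits2) ** torch.sigmoid(gate2)
            return gated_prob1 * gated_prob2
```

The returned probabilities are then float32 and can be passed straight to `F.binary_cross_entropy` without tripping the autocast safety check.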