I am trying to create a custom loss function to train an autoencoder for image generation. In particular, I want to symmetrize BCELoss, i.e. use BCE(input, target) + BCE(target, input) as the loss. My attempt is as follows:
```python
import torch.nn.functional as F
from torch import Tensor, nn


class symmBCELoss(nn.BCELoss):
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # BCE(input, target) + BCE(target, input), reusing BCELoss's weight/reduction
        return F.binary_cross_entropy(
            input, target, weight=self.weight, reduction=self.reduction
        ) + F.binary_cross_entropy(
            target, input, weight=self.weight, reduction=self.reduction
        )
```
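Written out elementwise, with x the prediction and y the target (both in [0, 1]), the loss this class is meant to compute, and the derivative its second term requires, are:

$$
\ell(x, y) = -\left[\, y \log x + (1 - y) \log(1 - x) \,\right], \qquad
\ell_{\text{symm}}(x, y) = \ell(x, y) + \ell(y, x)
$$

$$
\frac{\partial\, \ell(y, x)}{\partial x} = -\left[\log y - \log(1 - y)\right] = \log\frac{1 - y}{y}
$$

The second line is a derivative with respect to the *target* slot of the BCE formula. It is simple and finite for 0 < y < 1, but autograd only has it if the BCE implementation provides it.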
However, when I try to use this as a loss function to train my network I get the error:
RuntimeError: the derivative for 'target' is not implemented
I assume this is because F.binary_cross_entropy() only implements the gradient with respect to its first argument (input), not its second (target). In the second term I pass the network output as the target, so autograd has no derivative to backpropagate through. Is there a preferred way to set up this custom loss class that maximizes inheritance from the existing one? From the way the source code is laid out, the current BCE implementation seems to contain useful optimizations, and I would like to take advantage of them if possible.
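One possible workaround, sketched under stated assumptions rather than offered as the canonical fix: write the BCE formula with elementwise tensor ops so autograd can differentiate through both arguments, while still inheriting the `weight` and `reduction` attributes from `nn.BCELoss`. The `eps` parameter and the probability clamping are assumptions of this sketch, not `BCELoss` options (the built-in BCE instead clamps its log outputs at -100), so values can differ from `F.binary_cross_entropy` when an argument is exactly 0 or 1. The trade-off is that this path does not go through the native BCE kernel, so it gives up those optimizations in exchange for gradients with respect to both arguments.

```python
import torch
from torch import Tensor, nn


class symmBCELoss(nn.BCELoss):
    """Symmetrized BCE: BCE(input, target) + BCE(target, input).

    Built from plain elementwise tensor ops, so autograd can differentiate
    with respect to *both* arguments. `weight` and `reduction` come from
    nn.BCELoss; `eps` is an assumption of this sketch, not a BCELoss option.
    """

    def __init__(self, weight=None, reduction: str = "mean", eps: float = 1e-7):
        super().__init__(weight=weight, reduction=reduction)
        self.eps = eps

    def _bce(self, p: Tensor, q: Tensor) -> Tensor:
        # Elementwise -[q*log(p) + (1-q)*log(1-p)], with p playing the "prediction" role.
        # Clamping p to [eps, 1 - eps] keeps the logs and their gradients finite.
        p = p.clamp(min=self.eps, max=1 - self.eps)
        return -(q * torch.log(p) + (1 - q) * torch.log1p(-p))

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        loss = self._bce(input, target) + self._bce(target, input)
        if self.weight is not None:
            loss = loss * self.weight  # elementwise rescaling, broadcast like BCELoss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # reduction == "none"
```

It is used like any other criterion, e.g. `loss = symmBCELoss()(recon, images)`, where `recon` (a hypothetical name) is the decoder's sigmoid output in [0, 1].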