The loss can't backpropagate to the model's parameters with my custom loss function

I designed a custom loss:

import torch


class CustomIndicesEdgeAccuracyLoss(torch.nn.Module):
    def __init__(self, num_classes: int, selected_indices: list):
        super(CustomIndicesEdgeAccuracyLoss, self).__init__()
        self.num_classes = num_classes
        self.selected_indices = selected_indices

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        batch_size, num_classes, feature_size = input.shape
        # keep only the selected feature positions
        selected_input = input[::, ::, self.selected_indices]
        selected_target = target[::, self.selected_indices]
        # hard class prediction per selected position
        selected_preds = torch.argmax(selected_input, dim=1)
        # edge accuracy = fraction of correctly predicted positions
        edge_acc = torch.eq(selected_preds, selected_target).sum() / torch.numel(selected_preds)
        loss = 1 - edge_acc
        loss.requires_grad = True

        return loss

But the loss doesn't backpropagate to the model's parameters; in other words, the gradients of the model's parameters are always 0, so the parameters never get updated.
What are the possible reasons, and how should I revise the code?

Here is some information about the local variables inside forward():

input.shape: torch.Size([64, 3, 5])
target.shape: torch.Size([64, 5])
selected_input.shape: torch.Size([64, 3, 2])
selected_target.shape: torch.Size([64, 2])

Indeed, the gradients are zero because there is not a single differentiable operation in your loss.
Actually, the only reason you don't get an exception is very likely this loss.requires_grad = True, because otherwise I doubt you even have a backprop graph.
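You can verify this yourself: the loss your forward() returns has no grad_fn and is a brand-new leaf tensor, i.e. it is not connected to the model output at all. A minimal check with made-up tensors matching your shapes (the index values here are just examples):

criterion = CustomIndicesEdgeAccuracyLoss(num_classes=3, selected_indices=[1, 3])
logits = torch.randn(64, 3, 5, requires_grad=True)   # stand-in for your model output
target = torch.randint(0, 3, (64, 5))

loss = criterion(logits, target)
print(loss.grad_fn)   # None  -> nothing to backpropagate through
print(loss.is_leaf)   # True  -> detached from `logits`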

Note that indexing (input[::, ::, self.selected_indices]) will just take some elements; that is, the gradients for all the non-selected elements will be zero.
argmax is even worse: it returns integer indices, so it is not differentiable at all (unlike max, where only the selected element, the max one, would receive a copy of the upstream gradient).
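A quick illustration of the difference with a tiny made-up tensor:

x = torch.tensor([0.1, 2.0, -1.0], requires_grad=True)

val, idx = torch.max(x, dim=0)   # val is differentiable, idx is an integer index
val.backward()
print(x.grad)            # tensor([0., 1., 0.]) -> only the max element receives gradient

i = torch.argmax(x)
print(i.requires_grad)   # False -> integer output, gradient flow stops here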
Finally, torch.eq is not a differentiable operation either; its output is boolean, so there is nothing to backpropagate through.
Also, loss.requires_grad = True does not reconnect anything: it only works because the loss is already a leaf tensor detached from the graph, and it merely silences the error that backward() would otherwise raise.
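If you want the model to actually train, keep your edge accuracy as a monitoring metric (computed under torch.no_grad()) and optimize a differentiable surrogate instead, for example cross-entropy restricted to the selected indices. A rough sketch (the class name and details are just an example; it assumes target holds class indices of dtype torch.long):

import torch

class CustomIndicesCELoss(torch.nn.Module):
    # Differentiable surrogate: cross-entropy computed only on the selected positions.
    def __init__(self, selected_indices: list):
        super().__init__()
        self.selected_indices = selected_indices
        self.ce = torch.nn.CrossEntropyLoss()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # input: (batch, num_classes, feature_size), target: (batch, feature_size)
        selected_input = input[:, :, self.selected_indices]   # (batch, num_classes, k)
        selected_target = target[:, self.selected_indices]    # (batch, k)
        # CrossEntropyLoss accepts (N, C, d) logits with (N, d) class-index targets
        return self.ce(selected_input, selected_target)

Minimizing the cross-entropy pushes the argmax of the selected logits towards the targets, so the accuracy you were computing will improve, while the gradient comes from the smooth log-softmax rather than from the hard comparison.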
