Custom loss function is not decreasing

Here is my custom loss function…

        import torch
        import torch.nn.functional as F

        def mfe(dict_dataloader, y_hat, y, msfe=False):

            # prepare and assign variables from the dataset
            no_of_classes = dict_dataloader['n_class']
            sample_sizes = dict_dataloader['n_per_class_train']
            positive_class_sample_size = min(sample_sizes)
            positive_class = sample_sizes.index(positive_class_sample_size)

            # get y_hat labels
            _, y_hat_labels = torch.softmax(y_hat, dim=1).topk(1, dim=1)
            y_hat_labels = y_hat_labels.squeeze()

            # one-hot encode labels
            target_labels_oh = F.one_hot(y, no_of_classes).float().requires_grad_()
            y_hat_labels_oh = F.one_hot(y_hat_labels, no_of_classes).float().requires_grad_()

            # get the rows of positive condition from target tensor
            positive_class_rows = target_labels_oh[(target_labels_oh[:, positive_class]==1)].requires_grad_()

            # get the indices of rows with real positive labels
            pos_indices = (target_labels_oh[:, positive_class]==1).nonzero(as_tuple=True)[0]

            # get the corresponding rows in prediction tensor
            pos_rows_yhat = y_hat_labels_oh[pos_indices].requires_grad_()

            # FNE equation
            false_negative_error = (torch.tensor(1/positive_class_sample_size, requires_grad=True))*torch.sum(torch.sum(0.5*(positive_class_rows-pos_rows_yhat)**2, dim=1))

            # FPE equation
            neg_class_indices = [index for index, value in enumerate(sample_sizes) if value != positive_class_sample_size]
            false_positive = []

            for i in neg_class_indices:
                negative_class_rows = target_labels_oh[(target_labels_oh[:, i]==1)].requires_grad_()
                neg_indices = (target_labels_oh[:, i]==1).nonzero(as_tuple=True)[0]
                neg_rows_yhat = y_hat_labels_oh[neg_indices].requires_grad_()
                false_positive_error_per_class = (torch.tensor(1/sample_sizes[i], requires_grad=True))*torch.sum(torch.sum(0.5*(negative_class_rows-neg_rows_yhat)**2, dim=1))
                false_positive.append(false_positive_error_per_class)

            false_positive_error = sum(false_positive)

            # final calculation
            if msfe==False:
                loss = false_negative_error + false_positive_error
            else:
                loss = (false_negative_error)**2 + (false_positive_error**2)
            return loss

When I print loss.grad_fn, false_negative_error.grad_fn and false_positive_error.grad_fn, I get <AddBackward0 object at 0x7f11a561f790>, <MulBackward0 object at 0x7f11a561f670> and <AddBackward0 object at 0x7f11a561f790>, respectively. However, the loss does not decrease at all from epoch to epoch, even when I lower the learning rate. What’s the issue here?

You are detaching the output of your model (y_hat) in several places and then trying to fix it by calling .requires_grad_(), which won’t re-attach the tensor to the computation graph but will instead create a new leaf tensor without any Autograd history.

For example, this is already visible in the first lines of your code:

import torch
import torch.nn.functional as F

y_hat = torch.randn(10, 10, requires_grad=True)

# y_hat_labels is detached as it's an integer type
_, y_hat_labels = torch.softmax(y_hat, dim=1).topk(1, dim=1)
print(y_hat_labels.grad_fn)
# None

# calling .requires_grad_() will create a new leaf
y_hat_labels_oh = F.one_hot(y_hat_labels, 10).float().requires_grad_()
print(y_hat_labels_oh.grad_fn)
# None
print(y_hat_labels_oh.is_leaf)
# True
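The grad_fn objects printed in the question do not contradict this: a non-None grad_fn only means the loss sits on top of some graph, and here that graph is rooted at the new leaf tensors, not at y_hat. The following minimal sketch uses random stand-in tensors (y_hat, y, pred_oh and target_oh are illustrative names) to reproduce the situation and check where the gradient ends up after backward().

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y_hat = torch.randn(8, 3, requires_grad=True)  # stand-in for the model logits
y = torch.randint(0, 3, (8,))                   # stand-in targets

# Hard predictions from topk indices: integer, so no autograd history
_, idx = torch.softmax(y_hat, dim=1).topk(1, dim=1)
pred_oh = F.one_hot(idx.squeeze(1), 3).float().requires_grad_()  # a new leaf
target_oh = F.one_hot(y, 3).float().requires_grad_()             # also a leaf

loss = 0.5 * ((target_oh - pred_oh) ** 2).sum()
print(loss.grad_fn)              # a backward node such as <MulBackward0 ...>

loss.backward()
print(y_hat.grad)                # None: nothing flowed back to the logits
print(pred_oh.grad is not None)  # True: the gradient stops at the new leaf
```

So the optimizer never receives a gradient for the model parameters, which matches a loss that does not move regardless of the learning rate.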

So what do you recommend I do here? Should I remove y_hat.requires_grad_(True)? I tried running the model without it and still had the same issue.

You would have to rewrite the loss function so that it only uses floating-point tensors with differentiable operations.
As shown in my code snippet, the indices returned by the topk operation are not differentiable, since integer types do not support gradients.
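A quick way to see this, sketched on a random stand-in tensor (logits is a hypothetical name): topk returns differentiable float values but integer indices, argmax likewise returns integers, and an integer tensor refuses requires_grad_() outright.

```python
import torch

logits = torch.randn(4, 3, requires_grad=True)  # stand-in for model output

values, indices = logits.topk(1, dim=1)
print(values.requires_grad)                  # True: float values keep their history
print(indices.dtype, indices.requires_grad)  # torch.int64 False
print(logits.argmax(dim=1).requires_grad)    # False as well

# Integer tensors cannot be made to require gradients at all
raised = False
try:
    torch.tensor([0, 1, 2]).requires_grad_()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```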

I added this line:

y_hat_labels = y_hat_labels.squeeze().float().requires_grad_().to(device)

but it still didn’t solve the issue. Since y_hat_labels is now a float tensor and F.one_hot expects an integer tensor, I also changed this line to accommodate that:

y_hat_labels_oh = F.one_hot(y_hat_labels.type(torch.int64), no_of_classes).float()

but this did not solve the issue either.

As already mentioned, calling requires_grad_() on an already detached tensor won’t re-attach it to the computation graph, so you have to avoid breaking the graph in the first place.
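Tracing the two changes from the thread on a random stand-in tensor shows where the history is lost (the variable names below are illustrative): the float cast plus requires_grad_() only produces a fresh leaf unrelated to y_hat, and casting back to int64 for F.one_hot drops autograd again, so the one-hot prediction ends up as a constant.

```python
import torch
import torch.nn.functional as F

y_hat = torch.randn(6, 4, requires_grad=True)  # stand-in for the model logits

_, idx = torch.softmax(y_hat, dim=1).topk(1, dim=1)

# First change: cast the indices to float and call requires_grad_()
labels = idx.squeeze(1).float().requires_grad_()
print(labels.is_leaf, labels.grad_fn)  # True None: a new leaf, not connected to y_hat

# Second change: cast back to int64 so F.one_hot accepts it
labels_oh = F.one_hot(labels.type(torch.int64), 4).float()
print(labels_oh.requires_grad)         # False: the integer cast drops autograd entirely
```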

So you are saying I need a different method than topk() to get the indices, since topk() breaks the computation graph?

No, since integer values will always break the computation graph, no matter which function produces them.
Using the indices returned by topk, or e.g. argmax, detaches the result from the computation graph, so you would need to use the “soft” outputs (e.g. the softmax probabilities) instead and, e.g., multiply the target with these instead of indexing.
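A minimal sketch of the "multiply instead of index" idea, assuming random stand-in logits and targets, and using the batch count of class c as the normalizer for simplicity (the original code divides by the training-set counts from n_per_class_train instead). The softmax probabilities take the place of the one-hot prediction, and a 0/1 column of the target one-hot selects the samples of class c.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
y_hat = torch.randn(10, num_classes, requires_grad=True)  # stand-in logits
y = torch.tensor([0, 1, 2, 1, 1, 2, 2, 2, 1, 2])           # stand-in targets

probs = torch.softmax(y_hat, dim=1)            # "soft" predictions, differentiable
target_oh = F.one_hot(y, num_classes).float()  # constant target, no requires_grad needed
sq_err = 0.5 * (probs - target_oh) ** 2        # per-sample, per-class squared error

c = 0                                # class of interest
mask = target_oh[:, c:c + 1]         # shape (N, 1): 1 for samples of class c, else 0
n_c = mask.sum()                     # number of class-c samples in this batch
err_c = (mask * sq_err).sum() / n_c  # multiply by the mask instead of indexing

err_c.backward()
print(y_hat.grad is not None)        # True: gradients now reach the logits
```

Selecting rows of probs by indexing with the target would also be differentiable, since only the index tensor itself carries no gradient; the multiplication form simply keeps everything as element-wise float operations.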

I still don’t get it. Could you please write sample code showing how to generate y_hat_labels without using topk() or producing integers? I need it to compute 1/sample_size_per_class * (y_hat_labels_one_hot - target_one_hot)**2.
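One possible way to write the whole loss without integer predictions is sketched below. It is an assumption-laden rewrite, not an approach confirmed in the thread: it replaces y_hat_labels_oh with the softmax probabilities, keeps the dict_dataloader keys and the 1/sample_size normalization from the original mfe(), and picks the negative classes by index rather than by sample size. The optional straight_through flag swaps in a different, well-known technique, the straight-through estimator, which uses the hard one-hot prediction in the forward pass and the softmax gradient in the backward pass. The function name mfe_soft and the example class counts are hypothetical.

```python
import torch
import torch.nn.functional as F


def mfe_soft(dict_dataloader, y_hat, y, msfe=False, straight_through=False):
    """Sketch of a differentiable MFE/MSFE loss built on softmax outputs.

    Assumes dict_dataloader['n_per_class_train'] is a list of per-class
    training counts and that the smallest class is the positive class,
    as in the original mfe().
    """
    num_classes = dict_dataloader['n_class']
    sample_sizes = dict_dataloader['n_per_class_train']
    positive_class = sample_sizes.index(min(sample_sizes))

    # "Soft" labels: class probabilities instead of topk indices
    probs = torch.softmax(y_hat, dim=1)
    if straight_through:
        # Straight-through estimator (not from the thread): the forward value is
        # the hard one-hot prediction, the gradient is that of the softmax
        hard = F.one_hot(probs.argmax(dim=1), num_classes).float()
        y_hat_labels_oh = hard + probs - probs.detach()
    else:
        y_hat_labels_oh = probs

    target_oh = F.one_hot(y, num_classes).float()  # constant target, no grad needed
    # Per-sample error: sum over classes of 0.5 * (target - prediction)^2
    per_sample = 0.5 * ((target_oh - y_hat_labels_oh) ** 2).sum(dim=1)

    def class_error(c):
        # Keep only the samples of class c by multiplying with a 0/1 mask
        mask = target_oh[:, c]
        return (mask * per_sample).sum() / sample_sizes[c]

    false_negative_error = class_error(positive_class)
    false_positive_error = sum(
        class_error(c) for c in range(num_classes) if c != positive_class
    )

    if not msfe:
        return false_negative_error + false_positive_error
    return false_negative_error ** 2 + false_positive_error ** 2


# Example usage with made-up class counts and random stand-in tensors
torch.manual_seed(0)
info = {'n_class': 3, 'n_per_class_train': [50, 200, 120]}
logits = torch.randn(16, 3, requires_grad=True)
targets = torch.randint(0, 3, (16,))

loss = mfe_soft(info, logits, targets, msfe=True)
loss.backward()
print(loss.item(), logits.grad is not None)
```

With straight_through=False the loss compares probabilities with the one-hot target, so it is a smooth surrogate of the count-based error rather than its exact hard-label value; the straight-through variant keeps the hard-label value in the forward pass, but its gradient is only an approximation.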