Here is my custom loss function…
def mfe(dict_dataloader, y_hat, y, msfe=False):
#prepare and assign variables from the dataset
no_of_classes = dict_dataloader['n_class']
sample_sizes = dict_dataloader['n_per_class_train']
positive_class_sample_size = min(sample_sizes)
positive_class = sample_sizes.index(positive_class_sample_size)
#get y_hat labels
y_hat.requires_grad_(True)
_, y_hat_labels = torch.softmax(y_hat, dim=1).topk(1, dim=1)
y_hat_labels = y_hat_labels.squeeze()
#one-hot encode labels
target_labels_oh = F.one_hot(y, no_of_classes).float().requires_grad_()
y_hat_labels_oh = F.one_hot(y_hat_labels, no_of_classes).float().requires_grad_()
#get the rows of positive condition from target tensor
positive_class_rows = target_labels_oh[(target_labels_oh[:, positive_class]==1)].requires_grad_()
#get the indices of rows with real positive labels
pos_indices = (target_labels_oh[:, positive_class]==1).nonzero(as_tuple=True)[0]
#get the corresponding rows in prediction tensor
pos_rows_yhat = y_hat_labels_oh[pos_indices].requires_grad_()
#FNE equation
false_negative_error = (torch.tensor(1/positive_class_sample_size, requires_grad=True))*torch.sum(torch.sum(0.5*(positive_class_rows-pos_rows_yhat)**2, dim=1))
#FPE equation
neg_class_indices = [index for index, value in enumerate(sample_sizes) if value != positive_class_sample_size]
false_positive = []
for i in neg_class_indices:
negative_class_rows = target_labels_oh[(target_labels_oh[:, i]==1)].requires_grad_()
neg_indices = (target_labels_oh[:, i]==1).nonzero(as_tuple=True)[0]
neg_rows_yhat = y_hat_labels_oh[neg_indices].requires_grad_()
false_positive_error_per_class = (torch.tensor(1/sample_sizes[i], requires_grad=True))*torch.sum(torch.sum(0.5*(negative_class_rows-neg_rows_yhat)**2, dim=1))
false_positive.append(false_positive_error_per_class)
false_positive_error = sum(false_positive)
false_positive_error.requires_grad_(True)
false_negative_error.requires_grad_(True)
#final calculation
if msfe==False:
loss = false_negative_error + false_positive_error
else:
loss = (false_negative_error)**2 + (false_positive_error**2)
loss.requires_grad_(True)
return loss
When i print(loss.grad_fn), print(false_negative_error.grad_fn) and print(false_positive_error.grad_fn), it shows <AddBackward0 object at 0x7f11a561f790>, <MulBackward0 object at 0x7f11a561f670> and <AddBackward0 object at 0x7f11a561f790>, respectively. However, the loss does not decrease at all with each epoch, even when I decrease the learning rate. What’s the issue here?