Hi, I’m trying to implement an F-score loss function. I took the implementation from this GitHub gist, https://gist.github.com/SuperShinyEyes/dcc68a08ff8b615442e3bc6a9b55a354, and tweaked it a bit to suit my needs. But when I run the code, the model’s gradients are None, even though I’ve set requires_grad to True on the first variable, as suggested in https://discuss.pytorch.org/t/runtimeerror-element-0-of-variables-does-not-require-grad-and-does-not-have-a-grad-fn/11074.
Some specifics:
- I need the output from the network to be a binary vector, as it represents activities.
- The quantity I’m measuring is the prediction of actions, hence the need for an F-score loss.
Here is a basic example:
import torch


class FScoreLoss(torch.nn.Module):
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, y_true, y_pred, beta, grad=True):
        print(f'y_true = {y_true}')
        print(f'y_pred = {y_pred}')
        tp = (y_true * y_pred).sum().to(torch.float32)
        fn = ((1 - y_true) * y_pred).sum().to(torch.float32)
        fp = (y_true * (1 - y_pred)).sum().to(torch.float32)
        print(f'tp = {tp}, fn = {fn}, fp = {fp}')
        precision = tp / (tp + fp + self.eps)
        recall = tp / (tp + fn + self.eps)
        f_score_loss = (1 + beta ** 2) * (precision * recall) / ((beta**2)*precision + recall + self.eps)
        print(f'precision = {precision}, recall = {recall}, f_score = {f_score_loss}')
        return f_score_loss


f_score_loss_func = FScoreLoss()
model = torch.nn.Linear(10, 10)
x = torch.randn(1, 10).requires_grad_(True)
y_true = (torch.randn(1, 10)[0] > .5).float()
y_pred = (model(x)>.5).float().requires_grad_(True)
loss = f_score_loss_func(y_true, y_pred, beta=2., grad=True)
loss.backward()
print(f'y_pred.grad = {y_pred.grad}')
print(f'model.weight.grad = {model.weight.grad}')
The output of the last two lines is:
y_pred.grad = tensor([[-0.4082, 0.3061, -0.4082, -0.4082, -0.4082, -0.4082, 0.3061, -0.4082,
0.3061, -0.4082]])
model.weight.grad = None
So a gradient is computed for y_pred, but it does not propagate back to the model’s parameters: model.weight.grad stays None.
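While narrowing this down, I put together a small check. It is only a sketch and rests on assumptions: my guess is that the `> .5` comparison is where the graph gets cut, because it returns a bool tensor with no `grad_fn`, so calling `requires_grad_(True)` afterwards just creates a new leaf that is disconnected from the model. The soft variant below (using `torch.sigmoid` outputs as probabilities, a helper I named `soft_f_beta`, a hand-written `y_true`, and `1 - F` as the loss) is not from the gist as I used it. It also gives up the binary output during training, so thresholding would have to happen only at inference time.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the check is repeatable

model = torch.nn.Linear(10, 10)
x = torch.randn(1, 10)
# Hand-written binary target (illustrative values, not real data).
y_true = torch.tensor([[1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]])

logits = model(x)

# Hard thresholding: the comparison yields a bool tensor without a grad_fn,
# so anything computed from it is detached from the model's parameters.
hard = (logits > 0.5).float()
print(f'hard.grad_fn = {hard.grad_fn}')
print(f'hard.requires_grad = {hard.requires_grad}')


def soft_f_beta(y_true, y_prob, beta=2.0, eps=1e-7):
    """F-beta computed on probabilities instead of 0/1 predictions (hypothetical helper)."""
    tp = (y_true * y_prob).sum()
    fp = ((1 - y_true) * y_prob).sum()   # predicted positive, actually negative
    fn = (y_true * (1 - y_prob)).sum()   # predicted negative, actually positive
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall + eps)


# Soft path: sigmoid keeps the graph intact from the loss back to model.weight.
soft = torch.sigmoid(logits)
loss = 1 - soft_f_beta(y_true, soft)  # minimize 1 - F so that F is maximized
loss.backward()
print(f'model.weight.grad is None: {model.weight.grad is None}')
```

If this reasoning is right, the question becomes whether there is any way to keep a strictly binary output during training and still get gradients into the model.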
Appreciate any help. Thanks