Got "Element 0 of tensors does not require grad" error when using torch.where() in a custom loss

Hi all, I tried to build a custom loss function and found that when I use torch.where() in the loss, the error "Element 0 of tensors does not require grad and does not have a grad_fn" is raised as soon as I start training the model. If I just remove torch.where(), training works.

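import torch
import torch.nn as nn

class Custom_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_hat, y):
        # Relative error of each prediction.
        prob = torch.abs((y_hat - y) / y)
        t = torch.ones_like(prob)
        f = torch.zeros_like(prob)
        # 1 where the relative error is below 10%, 0 otherwise.
        prob = torch.where(prob < 0.1, t, f)
        return torch.sum(prob)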

Is there a way to get around it?

Thanks!
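A minimal, self-contained reproduction of this error might look like the sketch below. The linear model and the random data are hypothetical stand-ins (assumptions, not the poster's setup); the loss is the one from the question.

```python
import torch
import torch.nn as nn

class Custom_loss(nn.Module):
    def forward(self, y_hat, y):
        prob = torch.abs((y_hat - y) / y)
        t = torch.ones_like(prob)
        f = torch.zeros_like(prob)
        prob = torch.where(prob < 0.1, t, f)
        return torch.sum(prob)

torch.manual_seed(0)
model = nn.Linear(3, 1)       # hypothetical stand-in for the real model
x = torch.randn(8, 3)         # hypothetical inputs
y = torch.rand(8, 1) + 1.0    # hypothetical targets, kept away from 0

loss = Custom_loss()(model(x), y)
print(loss.requires_grad, loss.grad_fn)  # False None: the graph was cut

error_message = None
try:
    loss.backward()
except RuntimeError as e:
    error_message = str(e)
print(error_message)
```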

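I guess it comes from the fact that t and f don’t require grad: torch.where() only copies values from t or f, and the boolean condition (prob < 0.1) carries no gradient, so the result is disconnected from the graph.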

import torch
prob = torch.randn(5)
prob.requires_grad_(True)
t = torch.ones_like(prob)
f = torch.zeros_like(prob)
p = torch.where(prob < 0.3, t, f)
print(p.requires_grad) # prints False

However, if you set requires_grad to True for t and f, then:

import torch
prob = torch.randn(5)
prob.requires_grad_(True)
t = torch.ones_like(prob)
f = torch.zeros_like(prob)
t.requires_grad_(True)
f.requires_grad_(True)
p = torch.where(prob < 0.3, t, f)
print(p.requires_grad) # prints True

PS: I assume you meant f = torch.zeros_like(prob) instead of f = torch.ones_like(prob).


Thank you @LeviViana, the model can be trained after setting requires_grad to True for t and f, but the loss doesn’t change; I will look into this. By the way, should I put a negative sign in front of torch.sum(prob) to turn this into a minimization problem?
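On the sign question: torch.optim optimizers always minimize, so a quantity you want to maximize is usually negated before calling backward(). A toy sketch follows; the objective being maximized is hypothetical. Note that negating does not help when the gradient is zero: with the hard 0/1 count in this thread, the negated loss would still give the model no signal.

```python
import torch

# Hypothetical objective to MAXIMIZE: score(x) = -(x - 3)^2, which peaks at x = 3.
x = torch.tensor(0.0, requires_grad=True)
opt = torch.optim.SGD([x], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    score = -(x - 3.0) ** 2
    loss = -score          # minimizing -score is the same as maximizing score
    loss.backward()
    opt.step()

print(round(x.item(), 3))  # → 3.0
```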

The problem with this snippet is that there is no gradient for the variable prob. The first problem was that your loss didn’t require grad, and it was fixed by enabling gradients for t and f. The new problem is that gradients now flow only into t and f, which are constants, so nothing reaches prob (or your model’s parameters) and the loss never changes.

import torch
prob = torch.rand(5)
prob.requires_grad_(True)
t = torch.ones_like(prob)
f = torch.zeros_like(prob)
t.requires_grad_(True)
f.requires_grad_(True)
p = torch.where(prob < 0.3, t, f)
print(p.requires_grad) # prints True
loss = (p ** 2).sum() # Meaningless loss
loss.backward()
print(prob.grad) # No gradients
print(t.grad) # Gradients
print(f.grad) # Gradients
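For contrast, here is a small sketch (not from the original reply) where the branches of torch.where() are themselves functions of prob. Gradients then flow back to prob elementwise through whichever branch was selected, while the boolean condition itself still contributes nothing.

```python
import torch

prob = torch.tensor([0.1, 0.5, 0.2, 0.9], requires_grad=True)
# Select prob**2 where prob < 0.3, otherwise 3 * prob.
p = torch.where(prob < 0.3, prob ** 2, 3 * prob)
p.sum().backward()
# Gradient is 2 * prob on the first branch and 3 on the second.
print(prob.grad)  # tensor([0.2000, 3.0000, 0.4000, 3.0000])
```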

Maybe the following hack will work better for you:

import torch
prob = torch.rand(5)
prob.requires_grad_(True)
max_prob = torch.max(prob)
t = torch.ones_like(prob) * max_prob
f = torch.zeros_like(prob)
p = torch.where(prob < 0.3, t, f) / max_prob
print(p.requires_grad) # prints True
loss = (p ** 2).sum() # Meaningless loss
loss.backward()
print(prob.grad) # Not None, but zero (up to rounding): max_prob / max_prob cancels
print(t.grad) # No gradients
print(f.grad) # No gradients
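A quick way to check what this hack actually passes to prob is to look at the magnitude of prob.grad, not just whether it is None. A sketch of the same toy setup, with a seed added for reproducibility:

```python
import torch

torch.manual_seed(0)
prob = torch.rand(5, requires_grad=True)
max_prob = torch.max(prob)
t = torch.ones_like(prob) * max_prob
f = torch.zeros_like(prob)
# Every selected entry equals max_prob / max_prob = 1, so its derivative
# with respect to max_prob (and therefore prob) cancels out.
p = torch.where(prob < 0.3, t, f) / max_prob
loss = (p ** 2).sum()
loss.backward()

print(prob.grad)                        # a tensor, not None
print(prob.grad.abs().max().item())     # largest gradient magnitude reaching prob
```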

@LeviViana, this hack worked a few times, but unfortunately I didn’t set a seed, so I could never reproduce the success. When I used torch.mean(prob) as a substitute, the loss always decreased gradually over the iterations. So I think it may simply be too hard to train the model when the loss is turned into a 0/1 type.

Thank you again for everything you’ve done!
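One common way around a 0/1 objective (a standard relaxation, not something proposed in this thread) is to replace the hard threshold with a smooth one such as a sigmoid. In the sketch below, `SoftHitRateLoss`, `threshold`, `temperature` and `eps` are hypothetical names and values. As temperature goes to 0 the sigmoid approaches the original 0/1 step; for a positive temperature the loss has a nonzero gradient with respect to y_hat.

```python
import torch
import torch.nn as nn

class SoftHitRateLoss(nn.Module):
    """Hypothetical smooth surrogate for 'fraction of predictions within threshold relative error'."""

    def __init__(self, threshold=0.1, temperature=0.02, eps=1e-8):
        super().__init__()
        self.threshold = threshold
        self.temperature = temperature
        self.eps = eps  # assumption: guards against y == 0 (the original divides by y directly)

    def forward(self, y_hat, y):
        rel_err = torch.abs(y_hat - y) / (y.abs() + self.eps)
        # Close to 1 when rel_err < threshold, close to 0 otherwise, differentiable everywhere.
        soft_hit = torch.sigmoid((self.threshold - rel_err) / self.temperature)
        # Negated so that minimizing the loss maximizes the (soft) hit rate.
        return -soft_hit.mean()

torch.manual_seed(0)
y = torch.rand(16) + 1.0                  # hypothetical targets
y_hat = (y * 1.05).requires_grad_(True)   # hypothetical predictions, 5% too high
loss = SoftHitRateLoss()(y_hat, y)
loss.backward()
print(loss.item(), y_hat.grad.abs().sum().item() > 0)
```

Here the gradient is positive for every prediction, so a gradient step pushes each y_hat back toward its target, which is the signal the hard 0/1 version cannot provide.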


Try analyzing the magnitude of your gradients; maybe your loss is suffering from numerical problems. Then, try changing the formula of your loss without changing its objective. For instance, it is well known that it’s generally better to use nn.LogSoftmax than nn.Softmax followed by a log. This thread discusses this observed property.
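A small illustration of both suggestions, with made-up logits and a hypothetical model and loss. For a large gap between logits, softmax underflows to exactly 0 in float32, so taking its log gives -inf, while log_softmax computes the same quantity stably. The loop at the end is one simple way to inspect per-parameter gradient magnitudes after backward().

```python
import torch
import torch.nn as nn

# Numerical stability: log(softmax(x)) vs log_softmax(x).
logits = torch.tensor([[0.0, 200.0]])
naive = torch.log(torch.softmax(logits, dim=1))   # exp(-200) underflows to 0, log(0) = -inf
stable = torch.log_softmax(logits, dim=1)         # stays finite: about [-200, 0]
print(naive)
print(stable)

# Inspecting gradient magnitudes (hypothetical model and loss).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(32, 4)
loss = model(x).pow(2).mean()
loss.backward()
for name, param in model.named_parameters():
    print(f"{name}: grad norm = {param.grad.norm().item():.3e}")
```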