As we know, the `sqrt` function has no finite gradient at 0, since its derivative 1/(2·sqrt(x)) diverges as x → 0.

What are some good ways to fix this?

I tried adding a small epsilon of 1e-4, but the gradient of `sqrt` at 1e-4 is 50 (i.e. 1/(2·sqrt(1e-4))), which still seems dangerously large.

Are there any good methods?

```
import torch
a = torch.tensor([0.0, 1.0], requires_grad=True)
b = torch.sqrt(a)
b.sum().backward()
print("a has grad", a.requires_grad)
print("a grad", a.grad)  # gradient at 0 is inf
```

```
import torch
a = torch.tensor([0.0, 1.0])
a += 1e-4  # shift away from zero before enabling grad
a.requires_grad_()
b = torch.sqrt(a)
b.sum().backward()
print("a has grad", a.requires_grad)
print("a grad", a.grad)
```

Output:

```
a has grad True
a grad tensor([50.0000,  0.5000])
```