How to fix gradient at 0 for sqrt function

As we know, there is no finite gradient at 0 for the sqrt function (the derivative 1/(2*sqrt(x)) blows up as x approaches 0).
What are good methods to work around this?
I tried adding a small epsilon (1e-4), but the gradient of sqrt at 1e-4 is 50, which still seems like an evil thing for training.
Are there any good methods?

import torch

a = torch.FloatTensor([0.0, 1.0])
a.requires_grad = True

b = torch.sqrt(a)
b.sum().backward()
print("a has grad", a.requires_grad)
print("a grad", a.grad)  # the gradient at 0 is 1/(2*sqrt(0)) -> inf
With the epsilon added before taking the square root:

import torch

a = torch.FloatTensor([0.0, 1.0])
a += 1e-4  # shift the input away from 0
a.requires_grad = True

b = torch.sqrt(a)
b.sum().backward()
print("a has grad", a.requires_grad)
print("a grad", a.grad)

output:

a has grad True
a grad tensor([ 50.0000,   0.5000])

The output is correct:

a = (a1, a2) = (10^-4, 1 + 10^-4)
b = (sqrt(a1), sqrt(a2)) = (10^-2, 1 + O(10^-4))
sum(b) = sqrt(a1) + sqrt(a2)
grad = (1/(2*sqrt(a1)), 1/(2*sqrt(a2)))
     = (1/(2*10^-2), 1/(2*(1 + O(10^-4))))
     = (50, 0.5 + O(10^-4))

The second component is correct up to the written precision.
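A quick numerical check of the arithmetic (plain Python, not from the original posts):

import math

# gradient of sqrt at the shifted points a1 = 1e-4 and a2 = 1 + 1e-4
print(1 / (2 * math.sqrt(1e-4)))      # 50.0
print(1 / (2 * math.sqrt(1 + 1e-4)))  # 0.49997..., i.e. 0.5 up to the printed precision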

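Back to the original question about better methods: one common pattern is to keep the forward sqrt exact and bound only the backward pass with a custom autograd.Function. The sketch below is my own illustration, not code from this thread; the class name SqrtClampedGrad and the EPS choice are made up for the example.

import torch

class SqrtClampedGrad(torch.autograd.Function):
    """Exact sqrt in the forward pass; the backward pass clamps the
    denominator so the gradient never exceeds 1/(2*sqrt(EPS))."""
    EPS = 1e-2  # with EPS = 1e-2 the gradient is bounded at 5

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sqrt(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # 1/(2*sqrt(x)) with x clamped away from 0 to keep the gradient finite
        return grad_output / (2 * torch.sqrt(x.clamp(min=SqrtClampedGrad.EPS)))

a = torch.tensor([0.0, 1.0], requires_grad=True)
b = SqrtClampedGrad.apply(a)
b.sum().backward()
print(a.grad)  # tensor([5.0000, 0.5000]) -- finite at 0 instead of inf

Unlike adding epsilon to the input, this leaves the forward values unchanged and lets you choose the gradient bound directly; the trade-off is a biased gradient for inputs below EPS.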