Gradients differ between single and double precision

I am implementing my own gradient-check function because autograd’s built-in one is slow for large input and output tensors. I pick a few random elements of the input tensor, perturb them to compute finite differences, and then compare the results with the gradient from autograd. Here is my implementation:

import numpy as np
import torch
import torch.nn as nn

def grad_check(model, data, target, loss_fn, eps=1e-3, control_num=3, x_list=None, y_list=None):

    num_grads = []
    orig_grads = []
    output_original = model(data)
    loss_original = loss_fn(torch.squeeze(output_original), torch.squeeze(target))
    print('Loss original: ', loss_original)
    loss_original.backward()
    grad_original = torch.squeeze(data.grad)
    (shape_x, shape_y) = torch.squeeze(data).shape[0], torch.squeeze(data).shape[1]
    if x_list is None and y_list is None:
        x_list = np.random.randint(0, shape_x, size=control_num)
        y_list = np.random.randint(0, shape_y, size=control_num)
    with torch.no_grad():
        for i in range(control_num):
            data_copy = data.clone()
            data_copy[x_list[i],y_list[i]] = data_copy[x_list[i],y_list[i]] + eps
            output_ptb = model(data_copy)
            loss_ptb = loss_fn(torch.squeeze(output_ptb), torch.squeeze(target))
            print('Loss numeric: ', loss_ptb)
            grad_num = (loss_ptb-loss_original)/eps
            num_grads.append(grad_num)
            orig_grads.append(grad_original[x_list[i],y_list[i]])
    print('Numerical grad: ', num_grads)
    print('Original grad: ', orig_grads)
    print('Ratio: ', np.array(num_grads)/np.array(orig_grads))
    return x_list, y_list

Then I run this function with the following simple model:

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        out = (x+1)/2
        return out

When I check the gradients of this simple model (really just a function without learnable parameters, used for demonstration), the numerical and autograd gradients don’t match in single precision, but they do match in double precision:

# Double calculations
loss_function = nn.MSELoss(reduction="sum")
mymodel = TestModel().double()
input_tensor = torch.randn(140, 120, dtype=torch.double, requires_grad=True)
target_tensor = torch.randn(140, 120, dtype=torch.double, requires_grad=True)

x_list, y_list = grad_check(mymodel, input_tensor, target_tensor, loss_function, eps=1e-3, control_num=5)

Results with double precision:

Numerical grad:  [tensor(-0.4685, dtype=torch.float64), tensor(0.6474, dtype=torch.float64), tensor(-0.5744, dtype=torch.float64), tensor(-0.3790, dtype=torch.float64)]
Original grad:  [tensor(-0.4687, dtype=torch.float64), tensor(0.6472, dtype=torch.float64), tensor(-0.5746, dtype=torch.float64), tensor(-0.3792, dtype=torch.float64)]
Ratio:  [0.99946664 1.00038628 0.99956495 0.9993408 ]

Single precision model:

# Single precision calculations
mymodel = TestModel()
input_tensor = torch.randn(140, 120, requires_grad=True)
target_tensor = torch.randn(140, 120, requires_grad=True)

grad_check(mymodel, input_tensor, target_tensor, loss_function, eps=1e-3, control_num=5, x_list=x_list,
           y_list=y_list)

Results with single precision:

Numerical grad:  [tensor(-1.9531), tensor(0.), tensor(1.9531), tensor(1.9531)]
Original grad:  [tensor(-0.4842), tensor(1.2964), tensor(2.9638), tensor(1.1156)]
Ratio:  [4.033863  0.        0.6590007 1.7507753]

What is the reason for this difference? Should I stop relying on the finite-difference check to verify that my gradients are correct, or could this difference cause a problem when training my model?

Hi Ahmet!

The lesson here is that calculating gradients numerically can be nuanced and
requires thoughtful attention to floating-point arithmetic.

Two things to pay attention to:

Your data (and target and the output of your model) are of order 1, and
your individual squared-error values are of order 1. However, the number
of elements in data is of order 1.e4, so your summed loss function
(MSELoss(reduction="sum")) is also of order 1.e4.

Your eps is 1.e-3, and the change it produces in an individual contribution
to your loss is of order 1.e-3, so it produces a relative change in your loss
function of 1.e-3 / 1.e4 = 1.e-7. This is about the precision of a single-precision
floating-point number, so floating-point round-off error could easily introduce
a 100% error into the finite difference you calculate.

So, using single-precision arithmetic, you get numerical gradients that are
basically completely wrong.

On the other hand, the precision of double-precision arithmetic is about
1.e-15, which is precise enough to do a solid job of computing your
numerical gradients with acceptable accuracy.
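
To make the round-off argument concrete, here is a minimal sketch using NumPy's np.spacing, which returns the gap between a number and the next representable one. The loss magnitude of 25200 is an assumption (140 * 120 elements times an expected squared error of about 1.5 for standard-normal data and target), not a value taken from the runs in the question.

```python
import numpy as np

# Assumed magnitude of the summed loss: 140 * 120 elements at roughly 1.5 each.
loss_scale = 140 * 120 * 1.5  # 25200.0
eps = 1e-3

# Gap between loss_scale and the next representable number in each precision.
gap32 = float(np.spacing(np.float32(loss_scale)))
gap64 = float(np.spacing(np.float64(loss_scale)))

print(gap32)                  # 0.001953125, i.e. 2**-9
print(gap64)                  # 2**-38, roughly 3.6e-12
print(round(gap32 / eps, 4))  # 1.9531
```

Under this assumption, a single-precision loss can only move in steps of about 0.00195, so with eps = 1e-3 the finite difference can only come out as a multiple of roughly 1.95. That would be consistent with the values 1.9531, -1.9531 and 0 in the single-precision results posted in the question.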

The short answer – for the use case you posted – is to use double-precision
for your numerical gradients.

If you want to use single-precision, you will need to increase eps (which will
introduce larger non-round-off errors into your gradient computation), or
reduce the number of elements in data (so that the perturbation from any
one element doesn’t get swamped by the total summed value of all the others).
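
The trade-off between eps, tensor size and precision can be explored with a small NumPy sketch. It re-creates the model and loss from the question, sum(((x + 1) / 2 - t) ** 2), outside of PyTorch; the helper name fd_ratio, the seed, and the choice to pin element (0, 0) so that its analytic gradient is exactly 1 are all assumptions made for illustration.

```python
import numpy as np

def fd_ratio(shape, eps, dtype, seed=0):
    """Ratio of forward-difference gradient to analytic gradient at element (0, 0)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape).astype(dtype)
    t = rng.standard_normal(shape).astype(dtype)
    # Pin one element so that its analytic gradient is exactly 1.
    x[0, 0], t[0, 0] = 1.0, 0.0

    def loss(inp):
        # Same model and loss as in the question, evaluated entirely in `dtype`.
        return np.sum(((inp + dtype(1)) / dtype(2) - t) ** 2)

    x_ptb = x.copy()
    x_ptb[0, 0] += dtype(eps)
    numeric = (loss(x_ptb) - loss(x)) / dtype(eps)
    analytic = (x[0, 0] + 1) / 2 - t[0, 0]  # d(loss)/dx = (x + 1) / 2 - t
    return float(numeric / analytic)

print(fd_ratio((140, 120), 1e-3, np.float64))  # close to 1 + eps / 4
print(fd_ratio((140, 120), 1e-3, np.float32))  # dominated by round-off
print(fd_ratio((140, 120), 1e-1, np.float32))  # larger eps: less round-off, more truncation error
print(fd_ratio((4, 4), 1e-2, np.float32))      # fewer elements: close to 1 again
```

The exact single-precision numbers depend on the random data and on the summation order, so treat them as indicative only. The double-precision ratio differs from 1 by about eps / 4, which is the truncation error of the forward difference for this quadratic loss.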

Best.

K. Frank


Got it, thanks for your detailed answer! I have an additional question:

I see that, because of the limited precision of single-precision floating-point numbers, numerical gradients can be wrong. Would this also affect autograd’s gradient calculations? For this kind of data and loss function, should I always use double precision?

Hi Ahmet!

No*, because autograd does not use numerical differentiation to calculate
gradients.

In more detail, when an autograd-supporting function is written, its derivative
is calculated analytically, and programmed as a companion “backward”
function. When .backward() is called to compute gradients, that analytic
derivative is evaluated numerically, and then numerically chained together
with the rest of the computation graph by autograd using calculus’s chain
rule.

By way of example:

The analytic derivative of x**p is p * x**(p - 1). When torch.pow() was
written, not only was the “forward” function that computes x**p programmed,
but the companion “backward” function that computes p * x**(p - 1) was
programmed, as well.

Consider:

>>> import torch
>>> torch.__version__
'1.10.2'
>>> x = torch.tensor ([2.0], requires_grad = True)
>>> y = torch.pow (x, 3)
>>> y.backward()
>>> x.grad
tensor([12.])

When y.backward() is called, autograd uses pow()'s backward function
to evaluate the analytic derivative, 3 * 2.0**(3 - 1), numerically, but the
differentiation itself is not performed numerically.
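
The forward/backward pairing can be sketched with torch.autograd.Function. The class name Cube is hypothetical and this is not how torch.pow() itself is implemented; it only illustrates a hand-written analytic derivative being evaluated numerically during the backward pass.

```python
import torch

class Cube(torch.autograd.Function):
    """Hypothetical example op: forward computes x**3, backward its analytic derivative."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep the input for use in backward
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Analytic derivative 3 * x**2, chained with the upstream gradient.
        return grad_output * 3 * x ** 2

x = torch.tensor([2.0], requires_grad=True)
y = Cube.apply(x)
y.backward()
print(x.grad)  # tensor([12.])
```

No finite difference appears anywhere: the only floating-point work in the backward pass is evaluating 3 * x**2 and multiplying by the incoming gradient.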

*) For any given use case you may need more or less precision to perform
the forward pass adequately, and you may need more or less precision to
perform the backward pass – autograd’s gradient computation – adequately.

For many use cases, single-precision is fine. Some use cases might require
double-precision (but not because autograd performs numerical differentiation,
because it doesn’t). Some use cases might require only half-precision, and
for some use cases, Nvidia’s reduced-precision (and misleadingly named)
TensorFloat-32 might suffice.
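
For the model and loss from the question, one way to check whether single precision is adequate for the backward pass is to compare autograd's gradient in both precisions on the same data. This is a minimal sketch; the helper name autograd_grad and the seed are assumptions, and the model is written inline as (data + 1) / 2.

```python
import torch

torch.manual_seed(0)
data64 = torch.randn(140, 120, dtype=torch.double)
target64 = torch.randn(140, 120, dtype=torch.double)

def autograd_grad(data, target):
    # Gradient of the summed squared error with respect to the input, via autograd.
    data = data.clone().requires_grad_(True)
    loss = torch.nn.functional.mse_loss((data + 1) / 2, target, reduction="sum")
    loss.backward()
    return data.grad

grad64 = autograd_grad(data64, target64)
grad32 = autograd_grad(data64.float(), target64.float())

# Largest elementwise disagreement between the two precisions.
max_diff = (grad32.double() - grad64).abs().max().item()
print(max_diff)
```

The difference should be on the order of single-precision round-off for values of order 1, which is tiny compared with the 100% errors seen in the single-precision finite differences. Each gradient element depends only on its own input and target, so the size of the summed loss does not enter.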

Best.

K. Frank
