Different gradient computations when I cache intermediate variables

In my computations I cache some intermediate values and reuse them to generate multiple losses. While this caching is correct (i.e. it produces the same losses as the non-cached version), I have noticed that the gradients are sometimes slightly off. The problem is worse in high dimensions (off by 1e-6 in some cases), and I cannot convince myself that this is working as intended.

Here’s a toy example with a single one-dimensional parameter. See the docstring before the grad() function for details:


import torch
from functools import partial

Var = torch.autograd.Variable
Param = torch.nn.Parameter
torch.manual_seed(3)

def myprint(prefix, orig, grad):
    # Hook callback: print the tensor's value and the gradient flowing into it.
    print(prefix, orig.data.view(1, -1).numpy().tolist(),
          ", gradient:", grad.contiguous().data.view(1, -1).numpy().tolist())

""" Loss = torch.sum(w*x) + torch.sum((w*x)^2)
    We compute this in two ways, which give identical results, but different gradients.

    A) f = w*x     # We cache f and reuse it in the two losses
        L1 = torch.sum(f)
        L2 = torch.sum(torch.pow(f, 2))
        Loss = L1 + L2

    B) f1 = w*x
        L1 = torch.sum(f1)
        f2 = w*x  # Compute again
        L2 = torch.sum(torch.pow(f2, 2))
        Loss = L1 + L2
 """

def grad():
    w = Param(torch.Tensor([0.8032760620117188]))
    x = Var(torch.Tensor([[0.17483338713645935]]))
    print("W", w.data.view(1, -1).numpy().tolist())
    print("X", x.data.view(1, -1).numpy().tolist())
    print("\n")

    # First way of computing the loss: cache f = w*x and reuse it.
    f = w * x
    f.register_hook(partial(myprint, "Method A: w*x", f))

    l1 = torch.sum(f)
    l1.register_hook(partial(myprint, "Method A: l1", l1))
    l2 = torch.sum(torch.pow(f, 2))
    l2.register_hook(partial(myprint, "Method A: l2", l2))
    loss = l1 + l2

    loss.backward()
    g1 = w.grad.data.clone()
    print("\n")

    # Second way of computing the loss: recompute w*x for each loss.
    w.grad.data.zero_()
    f1 = w * x
    f1.register_hook(partial(myprint, "Method B: w*x (1)", f1))
    l1 = torch.sum(f1)
    l1.register_hook(partial(myprint, "Method B: l1", l1))

    f2 = w * x
    f2.register_hook(partial(myprint, "Method B: w*x (2)", f2))
    l2 = torch.sum(torch.pow(f2, 2))
    l2.register_hook(partial(myprint, "Method B: l2", l2))

    loss = l1 + l2
    loss.backward()
    g2 = w.grad.data.clone()

    print("\n")
    print("Method A: gradient w.r.t. W=", g1.numpy().tolist())
    print("Method B: gradient w.r.t. W=", g2.numpy().tolist())

grad()

This produces:

W [[0.8032760620117188]]
X [[0.17483338713645935]]

Method A: l1 [[0.1404394805431366]] , gradient: [[1.0]]
Method A: l2 [[0.01972324773669243]] , gradient: [[1.0]]
Method A: w*x [[0.1404394805431366]] , gradient: [[1.280879020690918]]

Method B: l1 [[0.1404394805431366]] , gradient: [[1.0]]
Method B: l2 [[0.01972324773669243]] , gradient: [[1.0]]
Method B: w*x (1) [[0.1404394805431366]] , gradient: [[1.0]]
Method B: w*x (2) [[0.1404394805431366]] , gradient: [[0.2808789610862732]]

Method A: gradient w.r.t. W= [0.2239404171705246]
Method B: gradient w.r.t. W= [0.2239404022693634]

Note that the two gradients at the end are slightly different. As I said, the problem is worse in higher dimensions.
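A minimal way to see where a last-bit difference like this can come from, without PyTorch: emulate float32 arithmetic in plain Python and evaluate the gradient in the order each graph uses. This sketch assumes each float32 multiply and add is rounded once (no fused multiply-add), and that autograd sums the two incoming gradients at f in Method A but at w in Method B. `f32` is a hypothetical helper written for this example, not a PyTorch function.

```python
import struct


def f32(v):
    """Hypothetical helper: round a Python float (IEEE double) to the nearest float32."""
    return struct.unpack("f", struct.pack("f", v))[0]


# Inputs from the post; both are exactly representable in float32.
w = f32(0.8032760620117188)
x = f32(0.17483338713645935)

# Forward pass: f = w*x, rounded to float32.
f = f32(w * x)

# Method A (cached f): autograd first adds the gradients arriving at f
# (1 from l1 and 2f from l2), then multiplies the sum by x once.
# Assumes no fused multiply-add; 2*f itself needs no rounding.
grad_f = f32(1.0 + 2.0 * f)
grad_a = f32(grad_f * x)

# Method B (f recomputed): each copy of w*x sends its own gradient to w
# (1*x and 2f*x), and the two are added when accumulated into w.grad.
grad_b = f32(f32(1.0 * x) + f32(2.0 * f * x))

print("Method A:", grad_a)
print("Method B:", grad_b)
print("A - B   :", grad_a - grad_b)
```

If the emulation is faithful, it should reproduce the two values printed above. The key step is that 1 + 2f ≈ 1.2808789610862732 is not representable in float32, so Method A rounds it to 1.280879020690918 (the gradient printed for w*x) before multiplying by x, an extra rounding that Method B never performs.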

Do you have any proof that the problem is worse in high dimensions? I don’t think it will matter much. This is just a numerical precision issue.

No theoretical proof, just empirical observation. Here’s an illustration with the toy problem above, as I increase the dimensionality. I don’t think it is just a numerical precision issue, since the discrepancies are >1e-6 as you can see below.

dim=1,  torch.max(torch.abs(gradient1 - gradient2)) = 1.49011611938e-08
dim=10,  torch.max(torch.abs(gradient1 - gradient2)) = 1.49011611938e-08
dim=50,  torch.max(torch.abs(gradient1 - gradient2)) = 4.76837158203e-07
dim=100,  torch.max(torch.abs(gradient1 - gradient2)) = 9.53674316406e-07
dim=200,  torch.max(torch.abs(gradient1 - gradient2)) = 1.90734863281e-06
dim=500,  torch.max(torch.abs(gradient1 - gradient2)) = 3.81469726562e-06
dim=1000,  torch.max(torch.abs(gradient1 - gradient2)) = 1.90734863281e-06
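The post doesn't show the code behind this sweep, so the following is a hypothetical reproduction, not the original: it assumes the elementwise version of the toy loss, sum(w*x) + sum((w*x)**2), with w and x drawn from a standard normal, and emulates float32 rounding with a hypothetical `f32` helper (single rounding per op, no fused multiply-add). Besides the max, it reports the mean absolute difference and the fraction of elements that differ, since each element is computed independently and the max over more elements is expected to grow even if the per-element error does not. Its numbers will not match the table, because the inputs differ.

```python
import random
import struct


def f32(v):
    # Hypothetical helper: round a Python float to the nearest float32.
    return struct.unpack("f", struct.pack("f", v))[0]


def grads(w, x):
    """Per-element dL/dw for L = sum(w*x) + sum((w*x)**2), computed two ways."""
    f = f32(w * x)
    cached = f32(f32(1.0 + 2.0 * f) * x)       # Method A: sum gradients at f, one product
    recomputed = f32(x + f32(2.0 * f * x))     # Method B: two products, summed at w
    scale = abs(x) + abs(2.0 * f * x)          # magnitude of the terms being combined
    return cached, recomputed, scale


def sweep_stats(dim, rng):
    diffs, scales = [], []
    for _ in range(dim):
        w = f32(rng.gauss(0.0, 1.0))
        x = f32(rng.gauss(0.0, 1.0))
        a, b, scale = grads(w, x)
        diffs.append(abs(a - b))
        scales.append(scale)
    max_diff = max(diffs)
    mean_diff = sum(diffs) / dim
    frac_diff = sum(d > 0 for d in diffs) / dim
    return max_diff, mean_diff, frac_diff, diffs, scales


rng = random.Random(3)
results = {}
for dim in (1, 10, 50, 100, 200, 500, 1000):
    results[dim] = sweep_stats(dim, rng)
    max_diff, mean_diff, frac_diff, _, _ = results[dim]
    print(f"dim={dim:5d}  max={max_diff:.3e}  mean={mean_diff:.3e}  differing={frac_diff:.0%}")
```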

Unfortunately it does matter significantly. In my experiments on real datasets, the accuracies are off by 1-2% after just a few thousand rounds of training.

Well, you are taking the max… of course a tensor with a higher numel will have a higher max error. I doubt the mean error changes much with dimension.

Is one approach consistently better than the other? If so, which one?

You are right that the mean (as well as the median of the nonzero elements) of the gradient difference remains small, on the order of 1e-8, independent of the dimensionality. Unfortunately, roughly 30-40% of the elements in the gradient vector differ between the two approaches, so a lot of parameters start getting slightly different updates as training progresses.

In my case, the cached version gets consistently worse accuracies than the non-cached version (67.5% vs 69% on my dev set).

I did some more digging into the toy example at dimensionality 1.
As in the toy example, I took Loss = w*x + (w*x)^2.
So the gradient is dL/dw = x + 2wx^2.
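Writing the chain rule out for the two graphs shows where they part ways. The notation fl(·), meaning "round to the nearest float32", is introduced here for illustration; it assumes each float32 add and multiply is rounded once, with f standing for the float32 value of w*x. Doubling f needs no rounding, since it only changes the exponent.

```latex
\begin{aligned}
L(w) &= wx + (wx)^2, \qquad f = wx,\\
\frac{dL}{dw} &= \frac{\partial L}{\partial f}\,\frac{\partial f}{\partial w} = (1 + 2f)\,x = x + 2wx^2,\\
g_A &= \mathrm{fl}\big(\mathrm{fl}(1 + 2f)\cdot x\big) && \text{(A: cached } f\text{, gradients summed at } f\text{)},\\
g_B &= \mathrm{fl}\big(x + \mathrm{fl}(2f \cdot x)\big) && \text{(B: recomputed } f\text{, gradients summed at } w\text{)}.
\end{aligned}
```

In exact arithmetic g_A = g_B = x + 2wx^2, but the two are rounded at different points, so they can disagree in the last bit.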

I set:
w = 0.8032760620117188
x = 0.17483338713645935

With caching, I get gradient = 0.2239404171705246 from PyTorch
Without caching, I get gradient = 0.2239404022693634 from PyTorch
And if I compute the gradient myself in the Python interpreter (double precision), I get gradient = 0.22394040524488334
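To put these three numbers on a common scale, the short script below (standard library only) measures them in float32 ulps. It assumes, as the printouts suggest, that the two PyTorch values are float32 values printed as Python floats; in [0.125, 0.25) consecutive float32 values are 2**-26 apart.

```python
# Inputs and results quoted from the post.
w = 0.8032760620117188
x = 0.17483338713645935
cached = 0.2239404171705246       # Method A, reported by PyTorch
recomputed = 0.2239404022693634   # Method B, reported by PyTorch

# Double-precision reference for dL/dw = x + 2wx^2.
exact = x + 2.0 * w * x * x

# Spacing of float32 values in [0.125, 0.25): one ulp.
ulp = 2.0 ** -26

for name, value in (("cached", cached), ("recomputed", recomputed)):
    print(f"{name:>10}: {value!r}  error vs. double = {(value - exact) / ulp:+.2f} ulp")
print(f"cached - recomputed = {(cached - recomputed) / ulp:.0f} ulp")
```

Both PyTorch results are within one float32 ulp of the double-precision value, and they are adjacent float32 numbers. For this one input the non-cached result happens to be the float32 value nearest the reference; a single sample says nothing about which ordering is more accurate on average.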

So all three are different, and within about 1.5e-8 of each other (the two PyTorch results differ by exactly one float32 ulp, 2^-26).
If this is indeed due to numerical precision, these errors cascade as training goes on.
As I mentioned earlier, the discrepancy can be 1e-6 for some elements of the gradient, which can send the corresponding parameter along a significantly different optimization path.

I am not sure what the way ahead is at this point. I would love to use caching, but the loss in accuracy is quite significant.

@SimonW @grahul

We are facing the same issue. Any solutions for this one?

No I couldn’t find one. I ultimately disabled caching intermediate outputs since the lower accuracy of the cached version was not acceptable to me. In my case the computation overhead of the non-cached version was not too significant since my bottlenecks were elsewhere.