In my computations I cache some intermediate values and reuse them to generate multiple losses. While the caching is correct (i.e. it produces losses identical to the non-cached version), I have noticed that the gradients are sometimes slightly off. The problem is worse in high dimensions (off by about 1e-6 in some cases). I am not able to convince myself that this is working as intended.
Here’s a toy example with a single one-dimensional parameter. See the comment before the grad() function for details:
import torch
from functools import partial

Var = torch.autograd.Variable
Param = torch.nn.Parameter

torch.manual_seed(3)

def myprint(prefix, orig, var):
    # Registered as a backward hook: 'var' is the gradient flowing into 'orig'.
    print prefix, orig.data.view(1, -1).numpy().tolist(), \
        ", gradient:", var.contiguous().data.view(1, -1).numpy().tolist()

""" Loss = torch.sum(w*x) + torch.sum((w*x)^2)
We compute this in two ways, which give identical losses but different gradients.

A) f = w*x                           # We cache f and reuse it in the two losses
   L1 = torch.sum(f)
   L2 = torch.sum(torch.pow(f, 2))
   Loss = L1 + L2

B) f1 = w*x
   L1 = torch.sum(f1)
   f2 = w*x                          # Compute again
   L2 = torch.sum(torch.pow(f2, 2))
   Loss = L1 + L2
"""
def grad():
    w = Param(torch.Tensor([0.8032760620117188]))
    x = Var(torch.Tensor([[0.17483338713645935]]))
    print "W", w.data.view(1, -1).numpy().tolist()
    print "X", x.data.view(1, -1).numpy().tolist()
    print "\n"

    # First way of computing the loss: cache f = w*x and reuse it.
    f = w * x
    f.register_hook(partial(myprint, "Method A: w*x", f))
    l1 = torch.sum(f)
    l1.register_hook(partial(myprint, "Method A: l1", l1))
    l2 = torch.sum(torch.pow(f, 2))
    l2.register_hook(partial(myprint, "Method A: l2", l2))
    loss = l1 + l2
    loss.backward()
    g1 = w.grad.data.clone()
    print "\n"

    # Second way of computing the loss: recompute w*x for each term.
    w.grad.data.zero_()
    f1 = w * x
    f1.register_hook(partial(myprint, "Method B: w*x (1)", f1))
    l1 = torch.sum(f1)
    l1.register_hook(partial(myprint, "Method B: l1", l1))
    f2 = w * x
    f2.register_hook(partial(myprint, "Method B: w*x (2)", f2))
    l2 = torch.sum(torch.pow(f2, 2))
    l2.register_hook(partial(myprint, "Method B: l2", l2))
    loss = l1 + l2
    loss.backward()
    g2 = w.grad.data.clone()
    print "\n"

    print "Method A: gradient w.r.t. W=", g1.numpy().tolist()
    print "Method B: gradient w.r.t. W=", g2.numpy().tolist()

grad()
This produces:
W [[0.8032760620117188]]
X [[0.17483338713645935]]

Method A: l1 [[0.1404394805431366]] , gradient: [[1.0]]
Method A: l2 [[0.01972324773669243]] , gradient: [[1.0]]
Method A: w*x [[0.1404394805431366]] , gradient: [[1.280879020690918]]

Method B: l1 [[0.1404394805431366]] , gradient: [[1.0]]
Method B: l2 [[0.01972324773669243]] , gradient: [[1.0]]
Method B: w*x (1) [[0.1404394805431366]] , gradient: [[1.0]]
Method B: w*x (2) [[0.1404394805431366]] , gradient: [[0.2808789610862732]]

Method A: gradient w.r.t. W= [0.2239404171705246]
Method B: gradient w.r.t. W= [0.2239404022693634]
Note that the two gradients at the end are slightly different. As I said, the problem is worse in higher dimensions.