Hi pytorch friends!
I’m trying to implement a fast gradient penalty using forward and backward hooks, but I found that for gradients of gradients, hooks show slightly aberrant behavior. Hopefully you can help me find where I’m going wrong.
I tried to construct a minimal example that shows the behavior. Let me start by hooking in a simple linear model with one parameter.
import torch
from torch import nn
from torch import autograd

torch.manual_seed(42)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1, bias=False)
        self.lin.weight = nn.Parameter(torch.tensor([[2.0]]))

    def forward(self, x):
        # Switch the comments in the following two lines for the aberrant behavior.
        return self.lin(x)
        # return torch.log(torch.exp(self.lin(x)))

model = Model()
# Add hooks to construct per-example gradients
gradients = {}
activations = {}

def forward_hook(layer, inputs, outputs):
    # Store each example's input to the layer for use in the backward hook.
    activations[layer] = inputs[0].detach()

def backward_hook(layer, grad_input, grad_output):
    # Per-example weight gradient of nn.Linear: the outer product of the
    # backpropagated output gradient and the stored input, per example.
    A = activations[layer]
    B = grad_output[0].detach()
    gradients[layer] = torch.einsum("ni,nj->nij", B, A)

model.lin.register_forward_hook(forward_hook)
model.lin.register_backward_hook(backward_hook)
The backward hook computes the per-example parameter gradients specifically for a layer of type nn.Linear.
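(To spell out the trick: for y = x @ W.T with the loss summed over the batch, the full weight gradient is the sum over examples n of the outer product of B_n, the gradient with respect to the layer output of example n, and A_n, its input; the einsum simply keeps the batch dimension instead of summing it out. A tiny standalone illustration of that identity, with made-up values matching the example below:)

# Standalone illustration of the outer-product identity (my own addition):
# with loss = y.sum(), the output gradients B are all ones and the stored
# inputs A are x itself; summing the per-example outer products over the
# batch recovers the usual (summed) weight gradient.
B = torch.ones(2, 1)               # d loss / d y for each example
A = torch.tensor([[2.0], [1.0]])   # the stored layer inputs
per_example = torch.einsum("ni,nj->nij", B, A)   # shape (2, 1, 1)
assert torch.allclose(per_example.sum(dim=0), torch.tensor([[3.0]]))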
Let’s feed the model a batch of 2 input examples and compute the parameter gradients. We also check that the per-example gradients are computed correctly, but we’ll comment that check out later.
x = torch.FloatTensor([2, 1])[..., None]
x.requires_grad_(True)
loss = model(x).sum()

# We'll comment this out later ...
loss.backward()
for layer, g in gradients.items():
    assert torch.allclose(layer.weight.grad, g.sum(dim=0))
# ... until here.
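(An extra sanity check, not in the notebook, that I'd also comment out later: compare against per-example gradients computed the slow way, with one backward pass per example.)

# Extra cross-check (my own sketch, not in the notebook): one backward
# pass per example; each single-example weight gradient should match the
# corresponding slice of the hooked per-example gradients.
per_example = gradients[model.lin].clone()
for n in range(len(x)):
    model.zero_grad()
    model(x[n:n + 1]).sum().backward()
    assert torch.allclose(model.lin.weight.grad, per_example[n])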
So the hooks seem to work, and I was pretty happy with myself for having successfully stolen the trick of @Yaroslav_Bulatov and his colleagues.
However, I actually need to compute the gradient of loss with respect to the input data x, and construct a cost function from that.
grads = autograd.grad(
    outputs=loss, inputs=x,
    retain_graph=True, create_graph=True, only_inputs=True
)[0]
grad_loss = (grads ** 2).sum()
model.zero_grad()
grad_loss.backward()
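(For reference, the expected values can be worked out by hand, my own addition: with w = 2 and inputs (2, 1), loss = w * (x1 + x2), so d loss / d x_n = w for every example, grad_loss = 2 * w**2 = 8, and d grad_loss / d w = 4 * w = 8, which is what the test below expects for the weight gradient.)

# Hand check (my own addition, not in the notebook): every entry of grads
# should equal w = 2, and grad_loss should equal 2 * w**2 = 8.
assert torch.allclose(grads, torch.full((2, 1), 2.0))
assert abs(grad_loss.item() - 8.0) < 1e-6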
Unfortunately, the test fails now:

for layer, g in gradients.items():
    assert layer.weight.grad == 8, layer.weight.grad
    assert g.sum(dim=0) == 3, g
The situation gets even more confusing when you change the comments in Model.forward. Note that we take log(exp(x)), which shouldn't change anything, but subsequently the hooked gradients are all zero.
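(In case it helps to diagnose this, a small debug hook, my own addition and not in the notebook, could be registered alongside the others to show when the backward hook fires during the double backward and with what values:)

# Debug aid (my own addition): print what the backward hook receives,
# to see when it fires during the double backward and with what values.
def debug_hook(layer, grad_input, grad_output):
    print("backward hook fired; grad_output[0] =", grad_output[0])

model.lin.register_backward_hook(debug_hook)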
Here’s a notebook.
Please help