Bug in backward hook: grad_input[1] (dW) is None

I have this simple class, but dW is None even though the parameter is still being updated. After calling x.requires_grad_(True), the hook prints dW. This behavior is confusing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def hook_fn(module, grad_input, grad_output):
    print(f"{module} has dW {grad_input[1]} and scaler value {module.scaler}")

class Scaler(nn.Module):
    def __init__(self, out_features):
        super().__init__()  # must run before registering any nn.Parameter
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x.requires_grad_(True)
        x = self.scaler * x
        # x = F.layer_norm(x, x.shape[1:])
        # layernorm to avoid vanishing gradient
        return x

x = torch.ones(100, 100, dtype=torch.float32)
y = torch.full((100, 100), 2, dtype=torch.float32)
model = Scaler(100)
model.register_backward_hook(hook_fn)  # legacy module backward hook
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for i in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```

Thanks in advance!

In your case you need to look at grad_input[0]. The entries of grad_input follow the inputs of the mul, (self.scaler, x), so grad_input[1] is the gradient with respect to x, which does not require grad and hence receives None.
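A minimal sketch of the explanation above, assuming a recent PyTorch in which the deprecated Module.register_backward_hook still exists (it may emit a deprecation warning). The Scaler is trimmed from the question, the 4×4 tensors are arbitrary, and `captured` is a hypothetical dict used only to inspect what the hook receives.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # The only autograd op is a mul whose inputs are (self.scaler, x)
        return self.scaler * x


captured = {}  # hypothetical: stores the hook's grad_input for inspection


def hook_fn(module, grad_input, grad_output):
    captured["grad_input"] = grad_input


model = Scaler()
model.register_backward_hook(hook_fn)  # legacy hook: attaches to the last op

x = torch.ones(4, 4)  # requires_grad is False
y = torch.full((4, 4), 2.0)
F.mse_loss(model(x), y).backward()

d_scaler, d_x = captured["grad_input"]
print(d_x)  # None, because x does not require grad
print(d_scaler.sum().item(), model.scaler.grad.item())
```

Summing d_scaler makes the comparison with model.scaler.grad independent of whether the hook sees the broadcast gradient or the one already reduced to the parameter's shape.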

Also, note that Module.register_backward_hook is deprecated; you should be using Module.register_full_backward_hook instead.

I think it’s a bug, because it returns both dW and dx for some modules. grad_input[0] has the same shape as x. Using register_full_backward_hook makes no difference.
@ptrblck could you please take a look?

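It returns both dW and dx because your module’s last operation is a mul between x and W. The legacy hook is attached to that last autograd operation, so grad_input holds the gradients with respect to that operation’s inputs. This is expected behavior.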
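One way to see which operation the legacy hook attaches to is to inspect the output’s grad_fn and its next_functions, whose order matches the order of the operation’s inputs. A small sketch, using a trimmed version of the Scaler from the question (an assumption, not the poster’s exact module):

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.scaler * x


model = Scaler()
out = model(torch.ones(4, 4))  # x does not require grad

node = out.grad_fn
print(type(node).__name__)  # the last autograd op of forward
for fn, _ in node.next_functions:
    # slot 0: AccumulateGrad for the parameter; slot 1: None for x (no grad)
    print(type(fn).__name__ if fn is not None else None)
```

If the commented-out layer_norm line were enabled, grad_fn would instead be the layer-norm node, and the legacy hook would no longer see the mul’s inputs at all.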

Sorry, the full backward hook does something different (this is also expected behavior): it only returns the gradients with respect to the inputs of your module and does not include gradients with respect to its parameters.
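A sketch of what register_full_backward_hook reports for the same trimmed Scaler, under the assumption of a recent PyTorch. `seen` is a hypothetical list that records grad_input from each call; the two backward passes cover x without and with requires_grad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.scaler * x


seen = []  # hypothetical: one grad_input tuple per hook call


def full_hook(module, grad_input, grad_output):
    # One entry per positional input of forward (here only x), never the parameter
    seen.append(grad_input)


model = Scaler()
model.register_full_backward_hook(full_hook)
y = torch.full((4, 4), 2.0)

# Case 1: x does not require grad, so grad_input is (None,)
F.mse_loss(model(torch.ones(4, 4)), y).backward()

# Case 2: x requires grad, so grad_input is (dL/dx,), still without dL/dW
x = torch.ones(4, 4, requires_grad=True)
F.mse_loss(model(x), y).backward()

for grad_input in seen:
    print([tuple(g.shape) if g is not None else None for g in grad_input])
```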

If you want the gradient with respect to a weight, you could register a hook directly on the parameter.
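A minimal sketch of a parameter hook via Tensor.register_hook, which calls the function with the gradient with respect to that tensor on each backward pass. The trimmed Scaler, the 4×4 data and the hypothetical `dW_history` list are assumptions for illustration; the loop mirrors the question’s SGD setup for three steps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.scaler = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.scaler * x


model = Scaler()
dW_history = []  # hypothetical: one dW per backward pass

# The hook receives dL/dW; returning None (list.append does) leaves it unchanged
model.scaler.register_hook(lambda grad: dW_history.append(grad.detach().clone()))

x = torch.ones(4, 4)
y = torch.full((4, 4), 2.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
    optimizer.zero_grad()
    F.mse_loss(model(x), y).backward()
    optimizer.step()

print([g.item() for g in dW_history])
```

With x all ones, dW = 2(w - 2), so the recorded values shrink in magnitude as SGD moves w toward 2, which shows the hook sees dW even though x never requires grad.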

If you would like a single hook that fires once the gradients with respect to all parameters have been computed, you can register a multi-grad hook.
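A sketch of a multi-grad hook, assuming PyTorch 2.0 or later, where torch.autograd.graph.register_multi_grad_hook is available. An nn.Linear is swapped in here only because it has two parameters; `calls` and `all_grads_ready` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 1)  # two parameters: weight (1, 4) and bias (1,)
params = list(model.parameters())
calls = []  # hypothetical: gradient shapes seen on each firing


def all_grads_ready(grads):
    # grads is ordered like `params`; an entry is None if that tensor got no grad
    calls.append([tuple(g.shape) if g is not None else None for g in grads])


torch.autograd.graph.register_multi_grad_hook(params, all_grads_ready)

x = torch.ones(8, 4)
y = torch.zeros(8, 1)
F.mse_loss(model(x), y).backward()

print(calls)  # a single call carrying both parameter gradients
```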
