All param.grad Tensors are 0

I am implementing [2010.02803] A Transformer-based Framework for Multivariate Time Series Representation Learning, a BERT-like model. I wrote my own model and training boilerplate, and I'm using a masked MSE loss for a reconstruction task (fully reconstruct a partially masked input):

import torch
import torch.nn as nn

class MaskedMSE(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, pred, trgt, mask):
        # mask is 1 where the loss should be considered, 0 where it should be ignored
        return torch.sum(((pred - trgt) * mask) ** 2) / torch.sum(mask)
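
(For reference, the loss itself does seem to propagate gradients when used in isolation; a minimal standalone check along these lines, with made-up shapes rather than my real data:)

pred = torch.randn(4, 3, 7, requires_grad=True)
trgt = torch.randn(4, 3, 7)
mask = (torch.rand(4, 3, 7) > 0.15).float()  # 1 = position whose loss is counted
loss = MaskedMSE()(pred, trgt, mask)
loss.backward()
print(pred.grad.abs().sum())  # nonzero, so the loss itself backpropagates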

I won’t post my full encoder code here unless requested, because it’s a bit long, but it’s basically a standard BERT encoder, much like the code from The Annotated Transformer, for example.
My problem is this: When I do

criterion = MaskedMSE()
...
# mask is 1 where the input value is kept, 0 where it is masked out
X_masked = X * mask
pred = encoder(X_masked)
loss = criterion(pred, X, (1 - mask))  # compute the loss only on the masked-out positions
loss.backward()
optimizer.zero_grad()

The loss is always nonzero; it starts at ~2 and quickly goes down to below 0.05.
However, none of the model parameters are actually updated, as I checked with code that amounts to the following:

def backward_debug(model, grad_in, grad_out):
    # module backward hook: inspect every parameter's accumulated .grad
    for name, param in model.named_parameters():
        assert param.requires_grad
        if param.grad is not None:
            assert (param.grad == 0).all()
            print(name)
            print(param.grad)

model.register_backward_hook(backward_debug)

The assertions above never fire, and I get output like the following:

encoder.embedding.bias
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
encoder.pe.W
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
encoder.layers.0.feed_forward.pwff_layer.3.weight
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')

And so on, with output for every model parameter I expect output for, but every param.grad tensor is all zeros. Printing pred and X between the backward() and zero_grad() calls gives output like this:

         [ 0.1115,  0.0279,  0.0978,  0.0056,  0.0397,  0.0876,  0.0712],
         [ 0.1026,  0.0656,  0.1129,  0.0509,  0.0648,  0.0761,  0.0474]],

        [[ 0.1898, -0.0031,  0.0609,  0.0998,  0.0651,  0.0619,  0.0335],
         [ 0.1938,  0.0143,  0.0640,  0.0557,  0.0802,  0.0757,  0.0318],
         [ 0.2150,  0.0122,  0.0867,  0.0741,  0.0503,  0.0815,  0.0508]],

        [[ 0.0960,  0.0362,  0.0922,  0.0348,  0.0468,  0.0695,  0.0698],
         [ 0.1095,  0.0387,  0.1283,  0.0274,  0.0552,  0.0903,  0.0641],
         [ 0.0869,  0.0406,  0.0850,  0.0371,  0.0317,  0.0653,  0.0617]],

        [[ 0.0991,  0.0295,  0.1165,  0.0415,  0.0758,  0.0876,  0.0706],
         [ 0.1119,  0.0651,  0.1074,  0.0335,  0.0549,  0.0648,  0.0618],
         [ 0.1339,  0.0716,  0.1297,  0.0626,  0.0726,  0.1032,  0.0689]]],
       device='cuda:0', grad_fn=<SliceBackward>)
prediction ^
tensor([[[0.0553, 0.0246, 0.0600, 0.0302, 0.0360, 0.0604, 0.0372],
         [0.0467, 0.0232, 0.0661, 0.0344, 0.0356, 0.0457, 0.0358],
         [0.0673, 0.0506, 0.0729, 0.0439, 0.0541, 0.0466, 0.0300]],

        [[0.1160, 0.0262, 0.0553, 0.0529, 0.0380, 0.0583, 0.0191],
         [0.1212, 0.0249, 0.0523, 0.0389, 0.0312, 0.0450, 0.0206],
         [0.1347, 0.0441, 0.0904, 0.0564, 0.0331, 0.0580, 0.0336]],

        [[0.0502, 0.0265, 0.0547, 0.0262, 0.0466, 0.0671, 0.0368],
         [0.0558, 0.0262, 0.0525, 0.0298, 0.0598, 0.0469, 0.0346],
         [0.0624, 0.0611, 0.0829, 0.0289, 0.0638, 0.0433, 0.0358]],

        [[0.0488, 0.0383, 0.0676, 0.0357, 0.0358, 0.0680, 0.0493],
         [0.0540, 0.0374, 0.0652, 0.0331, 0.0357, 0.0613, 0.0473],
         [0.0664, 0.0497, 0.0839, 0.0452, 0.0477, 0.0551, 0.0455]]],
       device='cuda:0')
X ^ 

I know there is no way to be sure what the problem is without my full model code and training boilerplate, but perhaps there's some obvious thing I am overlooking, or a typical failure mode such as NaNs in the input (there are none), a wrong device, or something like that.
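
(Concretely, the checks I mean are along these lines, using the variable names from the snippets above:)

assert not torch.isnan(X_masked).any()                 # no NaNs in the (masked) input
assert X.device == next(encoder.parameters()).device   # input and model on the same device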

I'm not sure if a backward hook registered on the module is triggered before the gradients are accumulated into the parameters' .grad attributes, so the zero gradients might be expected.
Could you check the gradients of all parameters directly, e.g. by registering a hook on each parameter:

for param in model.parameters():
    # fires as soon as the gradient for this parameter is computed
    param.register_hook(lambda grad: print(grad))

instead?
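
Alternatively, just as a sketch, you could inspect the accumulated .grad attributes right after loss.backward() has returned (and before any zero_grad() call):

loss.backward()
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.abs().sum().item())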

You are right, this prints nonzero gradients. I misunderstood when module backward hooks are called, then. Thank you!