Gradient update with partial ground truth

Hi all, I noticed something perplexing while experimenting with PyTorch. I work with time series data where I don't always have ground truth for every time step of a sequence, and I'm having trouble doing gradient updates when some targets are missing. Here's the code:

import torch
rand_in = torch.rand((5, 10))
loss_fn = torch.nn.MSELoss(reduction='none')
model = torch.nn.Linear(in_features=10, out_features=5).requires_grad_(True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
temp_weight = model.weight.data.clone()
print("before optim", temp_weight)
outp = model(rand_in)
optimizer.zero_grad()
gt = torch.zeros_like(outp)
gt[0, 1] = torch.nan  # artificially ablating ground truth
gt[1, 2] = torch.nan  # artificially ablating ground truth
loss = loss_fn(outp, gt)   # elementwise loss, still contains NaN here
torch.nan_to_num_(loss)    # replace the NaN losses with 0 in place
print('LOSS', loss)
backward_loss = loss.nanmean()
print('BACK LOSS', backward_loss.item())
backward_loss.backward()
optimizer.step()
print("after optim diff", model.weight)

Here's the output:

before optim tensor([[ 0.2478,  0.0912,  0.1117,  0.1141,  0.0293, -0.1833, -0.0429,  0.1653,
         -0.2960,  0.2571],
        [ 0.0113, -0.0377, -0.1227, -0.3150,  0.0443, -0.2880,  0.2965, -0.0512,
          0.0681,  0.0826],
        [ 0.0151,  0.1313, -0.2521,  0.0642, -0.0430,  0.1047, -0.1089,  0.2787,
          0.1329, -0.0091],
        [ 0.0209, -0.2558,  0.1979,  0.1501, -0.2211, -0.0593, -0.0961, -0.0503,
         -0.2552, -0.1482],
        [ 0.2662, -0.1144,  0.2465,  0.0744,  0.2571,  0.2164, -0.2478,  0.0052,
          0.0712, -0.1984]])
LOSS tensor([[3.2995e-01, 0.0000e+00, 4.7323e-02, 3.0494e-01, 6.6112e-02],
        [3.6434e-01, 9.5087e-04, 0.0000e+00, 1.4386e-01, 7.2434e-03],
        [2.5837e-01, 8.8023e-03, 5.0039e-04, 7.2771e-01, 6.5973e-02],
        [3.5031e-02, 1.7342e-02, 3.6359e-04, 3.2857e-01, 1.0124e-02],
        [3.4690e-02, 6.7077e-04, 2.9298e-02, 1.7946e-01, 1.2714e-01]],
       grad_fn=<NanToNumBackward0>)
BACK LOSS 0.12355112284421921
after optim diff Parameter containing:
tensor([[ 0.1478, -0.0088,  0.0117,  0.0141, -0.0707, -0.2833, -0.1429,  0.0653,
         -0.3960,  0.1571],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan],
        [ 0.1209, -0.1558,  0.2979,  0.2501, -0.1211,  0.0407,  0.0039,  0.0497,
         -0.1552, -0.0482],
        [ 0.1662, -0.2144,  0.1465, -0.0256,  0.1571,  0.1164, -0.3478, -0.0948,
         -0.0288, -0.2984]], requires_grad=True)

This behaviour seems perplexing; I'd appreciate help understanding it better. Thanks!

I guess you expected to see valid gradients, hoping that nan_to_num would keep the NaNs out of the backward pass.
If so, your observation is expected, because the loss is calculated from the already invalid targets.
You are replacing the invalid values afterwards, but the computation graph already contains the loss_fn call with the NaN-containing gt tensor. The backward of the elementwise MSE term involves (outp - gt), which is NaN at those positions, and even if the gradient flowing into them is zeroed out, 0 * NaN is still NaN, so the invalid values reach the weight gradients regardless of what you do to the forward loss value afterwards.
Make sure to remove or mask these invalid values before the loss is calculated.
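For example, here is a minimal sketch of that approach reusing your snippet's setup, with boolean masking applied before the loss so the NaN targets never enter the graph:

import torch

rand_in = torch.rand((5, 10))
loss_fn = torch.nn.MSELoss(reduction='none')
model = torch.nn.Linear(in_features=10, out_features=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

outp = model(rand_in)
gt = torch.zeros_like(outp)
gt[0, 1] = torch.nan  # missing ground truth
gt[1, 2] = torch.nan  # missing ground truth

optimizer.zero_grad()
mask = ~torch.isnan(gt)                      # True wherever ground truth exists
loss = loss_fn(outp[mask], gt[mask]).mean()  # average over the valid entries only
loss.backward()
optimizer.step()
print(torch.isnan(model.weight).any())       # tensor(False) -> weights stay finite

Since the invalid positions never reach loss_fn, their contribution to the gradient is exactly zero rather than NaN, and the affected weight rows stay finite.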
