Hi all, I noticed something perplexing while experimenting with PyTorch. I work on time-series data where I don't always have ground truth for every time step of a sequence, so some targets are missing (NaN). When I do gradient updates with such targets, some of the model weights turn into NaN. Here's a minimal example:
```python
import torch

rand_in = torch.rand((5, 10))
loss_fn = torch.nn.MSELoss(reduction='none')  # keep per-element losses
model = torch.nn.Linear(in_features=10, out_features=5).requires_grad_(True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

temp_weight = model.weight.data.clone()
print("before optim", temp_weight)

outp = model(rand_in)
optimizer.zero_grad()

gt = torch.zeros_like(outp)
gt[0, 1] = torch.nan  # artificially ablating ground truth
gt[1, 2] = torch.nan  # artificially ablating ground truth

loss = loss_fn(outp, gt)  # NaN at the ablated entries
torch.nan_to_num_(loss)   # replace those NaNs with 0 in the forward value
print('LOSS', loss)

backward_loss = loss.nanmean()
print('BACK LOSS', backward_loss.item())
backward_loss.backward()
optimizer.step()
print("after optim diff", model.weight)
```
Here's the output:

```
before optim tensor([[ 0.2478, 0.0912, 0.1117, 0.1141, 0.0293, -0.1833, -0.0429, 0.1653,
-0.2960, 0.2571],
[ 0.0113, -0.0377, -0.1227, -0.3150, 0.0443, -0.2880, 0.2965, -0.0512,
0.0681, 0.0826],
[ 0.0151, 0.1313, -0.2521, 0.0642, -0.0430, 0.1047, -0.1089, 0.2787,
0.1329, -0.0091],
[ 0.0209, -0.2558, 0.1979, 0.1501, -0.2211, -0.0593, -0.0961, -0.0503,
-0.2552, -0.1482],
[ 0.2662, -0.1144, 0.2465, 0.0744, 0.2571, 0.2164, -0.2478, 0.0052,
0.0712, -0.1984]])
LOSS tensor([[3.2995e-01, 0.0000e+00, 4.7323e-02, 3.0494e-01, 6.6112e-02],
[3.6434e-01, 9.5087e-04, 0.0000e+00, 1.4386e-01, 7.2434e-03],
[2.5837e-01, 8.8023e-03, 5.0039e-04, 7.2771e-01, 6.5973e-02],
[3.5031e-02, 1.7342e-02, 3.6359e-04, 3.2857e-01, 1.0124e-02],
[3.4690e-02, 6.7077e-04, 2.9298e-02, 1.7946e-01, 1.2714e-01]],
grad_fn=<NanToNumBackward0>)
BACK LOSS 0.12355112284421921
after optim diff Parameter containing:
tensor([[ 0.1478, -0.0088, 0.0117, 0.0141, -0.0707, -0.2833, -0.1429, 0.0653,
-0.3960, 0.1571],
[ nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan],
[ nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan],
[ 0.1209, -0.1558, 0.2979, 0.2501, -0.1211, 0.0407, 0.0039, 0.0497,
-0.1552, -0.0482],
[ 0.1662, -0.2144, 0.1465, -0.0256, 0.1571, 0.1164, -0.3478, -0.0948,
-0.0288, -0.2984]], requires_grad=True)
```
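The NaN rows (1 and 2) correspond exactly to the output features whose targets were ablated (`gt[0,1]` and `gt[1,2]`). One way to check where the NaNs enter is to retain the gradient of `outp` and inspect it. The sketch below is a reduced version of the code above (the fixed seed is an added assumption, only for reproducibility). It suggests that zeroing the loss values with `nan_to_num` does not keep NaN out of the gradient: the MSE backward still multiplies the incoming gradient by `(outp - gt)`, which is NaN at those entries, and `0 * NaN` is still NaN in floating point.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility
rand_in = torch.rand((5, 10))
model = torch.nn.Linear(in_features=10, out_features=5)
loss_fn = torch.nn.MSELoss(reduction='none')

outp = model(rand_in)
outp.retain_grad()  # keep the gradient of this non-leaf tensor for inspection

gt = torch.zeros_like(outp)
gt[0, 1] = torch.nan  # artificially ablating ground truth
gt[1, 2] = torch.nan

loss = torch.nan_to_num(loss_fn(outp, gt))  # the forward value is finite everywhere
loss.mean().backward()

print(torch.isfinite(loss).all())                  # the loss itself has no NaN
print(torch.isnan(outp.grad).nonzero())            # NaN gradient at [0, 1] and [1, 2]
print(torch.isnan(model.weight.grad).any(dim=1))   # weight-grad rows 1 and 2 are NaN
print(torch.isnan(model.bias.grad))                # bias entries 1 and 2 as well
```

Since `weight.grad` row `j` sums `outp.grad[:, j]` times the inputs, one NaN in column `j` of `outp.grad` poisons the whole row, and Adam then writes NaN into those parameters.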
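Rows 1 and 2 of the weight matrix become NaN after a single optimizer step, even though the loss I backpropagate is finite. Why does this happen, and how should I handle missing targets? Thanks!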
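A common workaround, sketched here under the assumption that missing targets are marked as NaN in `gt`, is to remove the NaNs from the *target* before computing the loss, then mask the per-element loss and average only over the observed entries. Applying `nan_to_num` to `gt` instead of to the loss keeps `(outp - gt)` finite everywhere, so masked-out entries contribute an exact zero gradient. The seed and the names `mask`, `gt_clean` and `masked_loss` are illustrative, not part of the original code.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility
rand_in = torch.rand((5, 10))
model = torch.nn.Linear(in_features=10, out_features=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss(reduction='none')

gt = torch.zeros((5, 5))
gt[0, 1] = torch.nan  # artificially ablated ground truth
gt[1, 2] = torch.nan

mask = ~torch.isnan(gt)                  # True where ground truth is available
gt_clean = torch.nan_to_num(gt, nan=0.0) # remove NaN *before* the loss is computed

optimizer.zero_grad()
outp = model(rand_in)
loss = loss_fn(outp, gt_clean) * mask    # zero out the entries without ground truth
masked_loss = loss.sum() / mask.sum()    # mean over observed entries only
masked_loss.backward()
optimizer.step()

print(torch.isfinite(model.weight).all())  # no NaN after the update
```

Selecting the observed entries directly, e.g. `torch.nn.functional.mse_loss(outp[mask], gt[mask])`, should give the same value and gradients, since `gt[mask]` contains no NaN; the element-wise mask form is convenient if you later want per-time-step weights.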