Gradients of the model's parameters become NaN after 2-3 backward passes

Thank you for taking a look at this topic.

I am running into a strange problem with the gradient of model.weight. The details of my code and its results for each loop index are below.

The computation below runs without any error on the first iteration of the loop, but somewhere between the 2nd and 6th iterations the gradient of model.weight becomes NaN as soon as the backward pass is done.

I think the backward pass itself is set up correctly, because the first iteration of the for loop produces finite gradients. I also attached the values of gw for each index.

Even though the random seeds for torch and numpy are fixed, the iteration at which the error happens differs from run to run.
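
For reference, here is roughly how the seeds are fixed at the top of my script (a sketch; the seed value itself is a placeholder):

    import numpy as np
    import torch

    seed = 0  # placeholder; any fixed value
    np.random.seed(seed)
    torch.manual_seed(seed)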

Please let me know if you need more information to advise me on how to solve this. Thank you in advance.

        for batch, (inputs_label, batch_img) in enumerate(loader):
            targets = torch.arange(len(inputs_label)).to(self.device)
            batch_img = batch_img.to(self.device)
            
            with torch.no_grad():
                batch_feat2 = self.m2.features(batch_img).detach().clone()
                batch_feat2 = batch_feat2.reshape(len(batch_feat2), -1)
                
                batch_feat1 = self.m1.features(batch_img).detach().clone()
                batch_feat1 = batch_feat1.reshape(len(batch_feat1), -1)
                
                torch.cuda.empty_cache()
                gc.collect()
            
            pred_feat2 = model(batch_feat1)  # model is just a simple one, only an `nn.Linear()` layer
            
            p = torch.from_numpy(ot.unif(pred_feat2.shape[0])).float().to(self.device)
            q = torch.from_numpy(ot.unif(batch_feat2.shape[0])).float().to(self.device)
            
            pred_dist_raw = torch.cdist(pred_feat2, pred_feat2).float()
            target_dist_raw = torch.cdist(batch_feat2, batch_feat2).float()
            
            pred_dist = pred_dist_raw / pred_dist_raw.detach().clone().max()
            target_dist = target_dist_raw / target_dist_raw.max()
        
            gw = ot.gromov.gromov_wasserstein(pred_dist, target_dist, p, q, loss_fun = "square_loss", max_iter = 100)
            
            _gwd_loss = criterion(gw, targets) # criterion = nn.CrossEntropyLoss()
            
            gwd_loss += _gwd_loss.item()
            
            if model.training:
                optimizer.zero_grad()
                _gwd_loss.backward()  # After this call, the gradients (and then the weights) of model become NaN, not on the first iteration (batch = 0) but on one of batches 1-5.
                optimizer.step()
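
For context, model really is just a single linear layer mapping the flattened m1 features to the m2 feature dimension. A minimal sketch of its setup (the feature sizes, the optimizer, and the learning rate below are placeholders, not my exact values):

    import torch
    import torch.nn as nn

    # Placeholder sizes; in my code these are the flattened feature
    # dimensions of m1 and m2 respectively.
    feat1_dim = 512
    feat2_dim = 512

    device = "cuda:3"
    model = nn.Linear(feat1_dim, feat2_dim).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer and lr are illustrative only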

The results for each index are below. All values are copied from the VS Code debugger.

# when batch = 0, after the backward pass...

gw = tensor([[4.6500e-16, 1.5237e-14, 1.4299e-14, 5.4031e-18, 3.0530e-14, 2.8312e-16,
         1.7200e-18, 1.0537e-17, 6.2500e-02, 5.1752e-18, 4.2278e-15, 6.4822e-18,
         2.7462e-16, 9.0776e-15, 2.4375e-17, 1.0768e-13],
        [2.6599e-12, 8.6926e-12, 8.0830e-11, 1.9248e-14, 5.7583e-11, 1.7127e-12,
         2.9256e-14, 1.8716e-13, 1.9202e-11, 3.0959e-13, 9.4362e-11, 5.0177e-14,
         5.6550e-13, 6.2500e-02, 2.5803e-16, 6.7969e-11],
        [2.8719e-08, 1.0451e-11, 4.0981e-08, 6.6208e-09, 7.7863e-08, 5.4568e-08,
         1.3479e-09, 3.9395e-09, 1.1923e-08, 9.2321e-10, 4.1378e-08, 3.0921e-09,
         6.2500e-02, 6.6024e-08, 1.5084e-16, 1.7985e-08],
        [2.6779e-15, 2.9902e-14, 8.4489e-14, 1.1851e-17, 1.3987e-13, 2.3564e-15,
         8.5227e-18, 6.1952e-17, 2.2882e-13, 2.1039e-17, 2.0044e-14, 1.4600e-17,
         5.4359e-16, 4.8190e-14, 3.7579e-17, 6.2500e-02],
        [4.5586e-06, 2.3048e-12, 2.7527e-06, 6.2458e-02, 4.2469e-06, 1.2654e-05,
         2.0622e-06, 9.4375e-07, 1.8982e-08, 4.0541e-07, 1.4474e-06, 1.1610e-06,
         9.7305e-06, 1.9129e-06, 3.1229e-19, 9.6479e-08],
        [1.0890e-05, 3.0172e-12, 6.2472e-06, 3.9174e-06, 1.8425e-06, 7.0800e-06,
         6.2451e-02, 3.4473e-06, 2.5641e-09, 1.9578e-06, 3.1740e-06, 2.0372e-06,
         3.1974e-06, 4.8489e-06, 2.4355e-20, 4.7416e-08],
        [3.3370e-09, 1.0569e-11, 1.8203e-08, 2.2132e-10, 3.5271e-08, 6.2500e-02,
         8.7654e-11, 3.2754e-10, 1.8930e-09, 1.2222e-10, 7.9472e-09, 1.0506e-10,
         2.0727e-09, 1.1231e-08, 4.6445e-16, 1.3265e-08],
        [6.8968e-07, 3.4354e-11, 9.2043e-07, 8.4043e-08, 4.3504e-07, 1.0305e-06,
         1.5981e-07, 6.2494e-02, 7.3798e-09, 3.4569e-08, 5.6355e-07, 9.3670e-08,
         4.2972e-07, 1.7115e-06, 3.0612e-17, 6.5541e-08],
        [1.7850e-12, 1.6286e-13, 9.4707e-11, 1.9144e-14, 6.2500e-02, 1.2449e-12,
         9.7794e-15, 2.2355e-14, 1.5326e-11, 2.3290e-14, 6.1077e-12, 2.0150e-14,
         2.0814e-13, 1.2944e-11, 1.7001e-16, 4.1699e-11],
        [2.6567e-21, 6.2497e-18, 3.7539e-19, 4.3755e-24, 2.1076e-19, 1.2207e-21,
         2.7717e-24, 8.6102e-23, 3.3591e-18, 1.5586e-23, 5.0496e-20, 9.0956e-24,
         4.7474e-22, 1.0035e-19, 6.2500e-02, 2.6176e-18],
        [7.5129e-11, 6.0201e-12, 5.8384e-10, 4.7330e-13, 4.9621e-10, 3.2638e-11,
         6.1223e-13, 1.7558e-12, 7.7076e-11, 1.5217e-12, 6.2500e-02, 1.0489e-12,
         9.2264e-12, 1.8962e-09, 2.8629e-16, 2.9138e-10],
        [3.5846e-06, 1.0650e-11, 2.2819e-06, 7.4963e-07, 2.8249e-06, 2.5170e-06,
         7.0861e-07, 6.2803e-07, 2.4074e-08, 2.1584e-07, 2.0004e-06, 6.2479e-02,
         2.5208e-06, 2.9826e-06, 2.6444e-18, 6.8370e-08],
        [6.2500e-02, 5.1804e-11, 2.3652e-08, 6.1516e-11, 2.8466e-08, 2.2380e-09,
         7.7262e-11, 1.4687e-10, 1.8918e-09, 1.6056e-10, 1.4188e-08, 1.0687e-10,
         7.0452e-10, 1.3644e-08, 9.8135e-16, 1.0878e-08],
        [8.0361e-13, 5.6072e-13, 6.2500e-02, 6.4473e-15, 5.8121e-11, 4.1949e-13,
         1.4850e-14, 2.4717e-14, 3.8580e-12, 1.6064e-14, 4.1846e-12, 9.2392e-15,
         6.6577e-14, 1.1092e-11, 2.9234e-16, 1.2776e-11],
        [1.3304e-18, 6.2500e-02, 1.0605e-16, 2.1406e-21, 2.9494e-17, 3.5467e-19,
         3.0932e-21, 3.3001e-20, 8.8549e-16, 3.6755e-20, 3.5730e-17, 5.0949e-21,
         1.5006e-19, 1.8937e-16, 6.1640e-18, 8.0487e-16],
        [5.0469e-06, 3.8415e-11, 2.0893e-06, 4.3682e-07, 1.5385e-06, 2.4231e-06,
         1.2101e-06, 3.0819e-07, 4.5271e-09, 6.2470e-02, 2.3130e-06, 3.1246e-07,
         7.5726e-07, 1.3592e-05, 5.7431e-19, 4.0942e-08]], device='cuda:3',
       grad_fn=<MulBackward0>)

model.weight.grad = tensor([[-1.2106e-10,  4.5094e-12,  3.5786e-11,  ..., -9.6593e-11,
         -3.0910e-10, -1.9421e-10],
        [-1.2350e-10,  3.8251e-12, -3.2768e-11,  ...,  3.2826e-09,
          2.3034e-09,  9.3843e-10],
        [-8.1601e-11,  6.7221e-11,  1.4069e-10,  ...,  7.4612e-10,
          4.6227e-10,  3.2016e-10],
        ...,
        [-1.1537e-10, -6.8744e-11, -1.4742e-10,  ..., -1.5278e-09,
         -1.2627e-09, -1.9038e-10],
        [ 1.8451e-11,  6.8723e-12,  7.8628e-11,  ...,  5.4209e-10,
          2.2966e-11, -4.1130e-10],
        [ 7.2494e-11, -2.1637e-11, -2.3505e-11,  ..., -1.2546e-09,
         -1.0029e-09, -2.7185e-10]], device='cuda:3')

# when batch = 1 ~ 5, after the backward pass (where the error happened)...

gw = tensor([[1.4175e-30, 1.2399e-37, 2.9771e-34, 1.9264e-29, 5.3040e-31, 4.6421e-32,
         3.9189e-31, 4.9028e-30, 6.2500e-02, 1.2059e-34, 3.1239e-28, 2.4408e-30,
         4.2906e-40, 1.1941e-29, 8.8812e-33, 3.7004e-30],
        [5.0646e-35, 9.5055e-37, 2.7237e-37, 6.8329e-34, 4.5170e-28, 1.3068e-36,
         7.8042e-35, 1.0985e-32, 1.0164e-32, 5.2903e-35, 5.5275e-33, 7.7107e-34,
         1.5538e-40, 2.9563e-34, 7.7576e-38, 6.2500e-02],
        [1.5845e-39, 1.0916e-38, 2.0858e-37, 1.7494e-34, 1.3349e-35, 4.6349e-41,
         3.8512e-39, 1.5587e-35, 1.4943e-35, 6.2500e-02, 4.0319e-35, 1.7400e-38,
         5.3612e-41, 1.8339e-36, 3.1585e-42, 1.7873e-34],
        [2.9675e-27, 2.0752e-36, 2.0050e-28, 6.2500e-02, 3.7063e-29, 1.1508e-28,
         4.1524e-27, 8.2859e-27, 5.9175e-29, 1.1651e-33, 1.7496e-27, 3.6725e-28,
         5.5585e-41, 8.2999e-26, 8.8201e-29, 1.5973e-30],
        [1.1002e-37, 1.7462e-37, 6.2500e-02, 9.5501e-29, 1.8937e-34, 4.1189e-40,
         1.0580e-37, 2.8095e-35, 7.5286e-35, 1.5087e-37, 9.6100e-36, 2.3896e-37,
         2.8306e-42, 3.0763e-35, 7.2390e-41, 4.9780e-36],
        [2.1772e-14, 1.4180e-37, 1.7654e-35, 4.6700e-24, 3.3251e-24, 6.2500e-02,
         5.6623e-15, 6.6706e-19, 1.2377e-25, 2.9854e-36, 5.6229e-18, 5.3399e-14,
         7.8473e-43, 3.9212e-16, 1.3262e-12, 1.1872e-25],
        [8.7254e-16, 1.2736e-36, 3.9274e-34, 2.5869e-23, 9.0424e-25, 2.5419e-15,
         6.2500e-02, 9.0055e-20, 3.2155e-26, 1.6230e-35, 4.9749e-19, 3.3837e-16,
         2.6821e-42, 1.8886e-16, 1.5022e-14, 2.3025e-25],
        [1.2188e-17, 2.4896e-37, 5.8553e-35, 1.6519e-25, 9.2195e-25, 2.2789e-15,
         6.4765e-17, 1.4194e-20, 4.7585e-26, 6.9260e-36, 4.5196e-20, 6.2500e-02,
         6.3097e-40, 1.1947e-18, 2.4959e-16, 1.8890e-25],
        [4.6892e-19, 6.2071e-36, 7.9653e-32, 2.5155e-23, 5.4780e-26, 9.2475e-20,
         2.3784e-19, 2.0399e-21, 7.5880e-26, 1.0587e-32, 3.6781e-21, 1.6517e-20,
         7.0332e-40, 6.2500e-02, 6.4970e-20, 1.1853e-26],
        [0.0000e+00, 7.0065e-45, 2.6527e-42, 3.0801e-42, 5.7221e-41, 0.0000e+00,
         0.0000e+00, 3.6686e-42, 8.6852e-42, 1.3944e-40, 1.8427e-42, 1.4013e-44,
         6.2500e-02, 0.0000e+00, 0.0000e+00, 1.1349e-41],
        [1.1070e-43, 6.2500e-02, 2.4309e-38, 2.2074e-39, 6.8522e-39, 0.0000e+00,
         7.0065e-45, 3.8495e-39, 1.4038e-40, 2.2271e-39, 8.0879e-41, 3.6434e-44,
         4.2039e-45, 4.4561e-43, 0.0000e+00, 8.3306e-38],
        [5.0508e-13, 6.3545e-40, 7.6625e-37, 1.9200e-23, 2.5130e-25, 7.2813e-11,
         1.4380e-12, 1.1018e-18, 1.8889e-26, 8.0530e-38, 5.3365e-18, 8.3840e-14,
         0.0000e+00, 8.1971e-15, 6.2500e-02, 2.3892e-26],
        [1.1774e-24, 1.8025e-33, 4.1897e-33, 1.9095e-26, 8.2338e-28, 2.3473e-25,
         3.3104e-25, 6.2500e-02, 1.0450e-28, 2.5269e-33, 2.3697e-25, 5.1198e-25,
         1.2642e-37, 1.8217e-24, 2.3924e-26, 6.0960e-28],
        [8.5879e-23, 4.1779e-35, 6.0920e-34, 2.1292e-26, 8.6585e-27, 2.0547e-22,
         1.2690e-22, 7.6655e-24, 7.3242e-26, 1.9534e-32, 6.2500e-02, 9.8456e-23,
         2.2891e-38, 2.3708e-22, 6.5423e-23, 2.6051e-27],
        [6.2500e-02, 4.0095e-35, 3.3414e-34, 9.3385e-24, 7.8530e-24, 3.3230e-15,
         3.3943e-16, 2.2180e-19, 1.0645e-25, 5.3034e-36, 1.1850e-19, 2.7041e-17,
         5.0447e-43, 1.2713e-16, 2.4376e-15, 1.1389e-25],
        [3.0767e-27, 1.0605e-36, 4.6640e-35, 1.9661e-29, 6.2500e-02, 1.4273e-28,
         7.3341e-28, 4.7904e-28, 5.9474e-31, 9.7950e-36, 9.5693e-28, 1.8681e-27,
         8.2699e-41, 6.0447e-28, 1.6877e-29, 2.4151e-25]], device='cuda:3',
       grad_fn=<MulBackward0>)


model.weight.grad = tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:3')


Have you tried reducing the learning rate?


Thank you for your comment.

I have already confirmed that a smaller learning rate (anywhere from 1e-6 down to 1e-12) does not change anything about this problem.
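
For reference, a sketch of what I plan to try next to localize where the NaN first appears (the commented calls at the end refer to the variables in the loop posted above):

    import torch

    # Enabled once at the start of the run: autograd will then raise an
    # error pointing at the op that first produces NaN during backward
    # (this slows training down, so it is only for debugging).
    torch.autograd.set_detect_anomaly(True)

    def report_nonfinite(name, t, batch):
        """Print a message if a tensor already contains NaN or Inf."""
        if torch.isnan(t).any() or torch.isinf(t).any():
            print(f"{name} contains NaN/Inf at batch {batch}")

    # Inside the loop, right before _gwd_loss.backward():
    # report_nonfinite("pred_dist", pred_dist, batch)
    # report_nonfinite("target_dist", target_dist, batch)
    # report_nonfinite("gw", gw, batch)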