Thank you for checking this topic.
I am facing a weird error with the gradient of `model.weight`. The details of my code and its results for each loop index are below.
The computation below runs without any errors on the first loop iteration, but after 2–6 iterations the weights of the parameters become NaN once the backward
computation is done.
I think there is nothing wrong with the backward operation itself, given the results of the first iterations of the for loop. I have also attached the results of gw
for each index.
Even though the random seeds for torch and numpy are fixed, the index at which the error happens can differ each time I run the code.
Please let me know if you need more info to give me your advice on how to solve this. Thank you in advance.
for batch, (inputs_label, batch_img) in enumerate(loader):
    # Targets are the in-batch indices: the GW coupling is compared against
    # the identity matching, so sample i should be transported to slot i.
    targets = torch.arange(len(inputs_label)).to(self.device)
    batch_img = batch_img.to(self.device)

    # Extract features from the two frozen backbones; no gradients needed.
    with torch.no_grad():
        batch_feat2 = self.m2.features(batch_img).detach().clone()
        batch_feat2 = batch_feat2.reshape(len(batch_feat2), -1)
        batch_feat1 = self.m1.features(batch_img).detach().clone()
        batch_feat1 = batch_feat1.reshape(len(batch_feat1), -1)
    torch.cuda.empty_cache()
    gc.collect()

    pred_feat2 = model(batch_feat1)  # model is a single nn.Linear

    # Uniform marginals for the Gromov-Wasserstein problem.
    p = torch.from_numpy(ot.unif(pred_feat2.shape[0])).float().to(self.device)
    q = torch.from_numpy(ot.unif(batch_feat2.shape[0])).float().to(self.device)

    pred_dist_raw = torch.cdist(pred_feat2, pred_feat2).float()
    target_dist_raw = torch.cdist(batch_feat2, batch_feat2).float()

    # NOTE(review): torch.cdist has a NaN backward gradient for any pair at
    # zero distance (d|x-y|/dx = (x-y)/|x-y| is 0/0). The diagonal is always
    # zero, and if any two predicted features collapse to (nearly) the same
    # point — which the underflowing gw entries in later batches suggest —
    # backward() through pred_dist_raw yields NaN. That matches NaN grads
    # appearing only after a few batches. Clamping the normalizer below also
    # guards the division when all pairwise distances are zero.
    eps = 1e-12
    pred_dist = pred_dist_raw / pred_dist_raw.detach().max().clamp_min(eps)
    target_dist = target_dist_raw / target_dist_raw.max().clamp_min(eps)

    gw = ot.gromov.gromov_wasserstein(
        pred_dist, target_dist, p, q, loss_fun="square_loss", max_iter=100
    )
    _gwd_loss = criterion(gw, targets)  # criterion = nn.CrossEntropyLoss()
    gwd_loss += _gwd_loss.item()

    if model.training:
        optimizer.zero_grad()
        _gwd_loss.backward()
        # Guard: one non-finite gradient would turn the weights NaN forever
        # after optimizer.step(). Skip the update for such a batch instead.
        grads_finite = all(
            torch.isfinite(prm.grad).all()
            for prm in model.parameters()
            if prm.grad is not None
        )
        if grads_finite:
            optimizer.step()
        else:
            optimizer.zero_grad()
The results for each index are below. All values are copied from the VS Code debugger.
# when batch = 0 after backward operation...
gw = tensor([[4.6500e-16, 1.5237e-14, 1.4299e-14, 5.4031e-18, 3.0530e-14, 2.8312e-16,
1.7200e-18, 1.0537e-17, 6.2500e-02, 5.1752e-18, 4.2278e-15, 6.4822e-18,
2.7462e-16, 9.0776e-15, 2.4375e-17, 1.0768e-13],
[2.6599e-12, 8.6926e-12, 8.0830e-11, 1.9248e-14, 5.7583e-11, 1.7127e-12,
2.9256e-14, 1.8716e-13, 1.9202e-11, 3.0959e-13, 9.4362e-11, 5.0177e-14,
5.6550e-13, 6.2500e-02, 2.5803e-16, 6.7969e-11],
[2.8719e-08, 1.0451e-11, 4.0981e-08, 6.6208e-09, 7.7863e-08, 5.4568e-08,
1.3479e-09, 3.9395e-09, 1.1923e-08, 9.2321e-10, 4.1378e-08, 3.0921e-09,
6.2500e-02, 6.6024e-08, 1.5084e-16, 1.7985e-08],
[2.6779e-15, 2.9902e-14, 8.4489e-14, 1.1851e-17, 1.3987e-13, 2.3564e-15,
8.5227e-18, 6.1952e-17, 2.2882e-13, 2.1039e-17, 2.0044e-14, 1.4600e-17,
5.4359e-16, 4.8190e-14, 3.7579e-17, 6.2500e-02],
[4.5586e-06, 2.3048e-12, 2.7527e-06, 6.2458e-02, 4.2469e-06, 1.2654e-05,
2.0622e-06, 9.4375e-07, 1.8982e-08, 4.0541e-07, 1.4474e-06, 1.1610e-06,
9.7305e-06, 1.9129e-06, 3.1229e-19, 9.6479e-08],
[1.0890e-05, 3.0172e-12, 6.2472e-06, 3.9174e-06, 1.8425e-06, 7.0800e-06,
6.2451e-02, 3.4473e-06, 2.5641e-09, 1.9578e-06, 3.1740e-06, 2.0372e-06,
3.1974e-06, 4.8489e-06, 2.4355e-20, 4.7416e-08],
[3.3370e-09, 1.0569e-11, 1.8203e-08, 2.2132e-10, 3.5271e-08, 6.2500e-02,
8.7654e-11, 3.2754e-10, 1.8930e-09, 1.2222e-10, 7.9472e-09, 1.0506e-10,
2.0727e-09, 1.1231e-08, 4.6445e-16, 1.3265e-08],
[6.8968e-07, 3.4354e-11, 9.2043e-07, 8.4043e-08, 4.3504e-07, 1.0305e-06,
1.5981e-07, 6.2494e-02, 7.3798e-09, 3.4569e-08, 5.6355e-07, 9.3670e-08,
4.2972e-07, 1.7115e-06, 3.0612e-17, 6.5541e-08],
[1.7850e-12, 1.6286e-13, 9.4707e-11, 1.9144e-14, 6.2500e-02, 1.2449e-12,
9.7794e-15, 2.2355e-14, 1.5326e-11, 2.3290e-14, 6.1077e-12, 2.0150e-14,
2.0814e-13, 1.2944e-11, 1.7001e-16, 4.1699e-11],
[2.6567e-21, 6.2497e-18, 3.7539e-19, 4.3755e-24, 2.1076e-19, 1.2207e-21,
2.7717e-24, 8.6102e-23, 3.3591e-18, 1.5586e-23, 5.0496e-20, 9.0956e-24,
4.7474e-22, 1.0035e-19, 6.2500e-02, 2.6176e-18],
[7.5129e-11, 6.0201e-12, 5.8384e-10, 4.7330e-13, 4.9621e-10, 3.2638e-11,
6.1223e-13, 1.7558e-12, 7.7076e-11, 1.5217e-12, 6.2500e-02, 1.0489e-12,
9.2264e-12, 1.8962e-09, 2.8629e-16, 2.9138e-10],
[3.5846e-06, 1.0650e-11, 2.2819e-06, 7.4963e-07, 2.8249e-06, 2.5170e-06,
7.0861e-07, 6.2803e-07, 2.4074e-08, 2.1584e-07, 2.0004e-06, 6.2479e-02,
2.5208e-06, 2.9826e-06, 2.6444e-18, 6.8370e-08],
[6.2500e-02, 5.1804e-11, 2.3652e-08, 6.1516e-11, 2.8466e-08, 2.2380e-09,
7.7262e-11, 1.4687e-10, 1.8918e-09, 1.6056e-10, 1.4188e-08, 1.0687e-10,
7.0452e-10, 1.3644e-08, 9.8135e-16, 1.0878e-08],
[8.0361e-13, 5.6072e-13, 6.2500e-02, 6.4473e-15, 5.8121e-11, 4.1949e-13,
1.4850e-14, 2.4717e-14, 3.8580e-12, 1.6064e-14, 4.1846e-12, 9.2392e-15,
6.6577e-14, 1.1092e-11, 2.9234e-16, 1.2776e-11],
[1.3304e-18, 6.2500e-02, 1.0605e-16, 2.1406e-21, 2.9494e-17, 3.5467e-19,
3.0932e-21, 3.3001e-20, 8.8549e-16, 3.6755e-20, 3.5730e-17, 5.0949e-21,
1.5006e-19, 1.8937e-16, 6.1640e-18, 8.0487e-16],
[5.0469e-06, 3.8415e-11, 2.0893e-06, 4.3682e-07, 1.5385e-06, 2.4231e-06,
1.2101e-06, 3.0819e-07, 4.5271e-09, 6.2470e-02, 2.3130e-06, 3.1246e-07,
7.5726e-07, 1.3592e-05, 5.7431e-19, 4.0942e-08]], device='cuda:3',
grad_fn=<MulBackward0>)
model.weight.grad = tensor([[-1.2106e-10, 4.5094e-12, 3.5786e-11, ..., -9.6593e-11,
-3.0910e-10, -1.9421e-10],
[-1.2350e-10, 3.8251e-12, -3.2768e-11, ..., 3.2826e-09,
2.3034e-09, 9.3843e-10],
[-8.1601e-11, 6.7221e-11, 1.4069e-10, ..., 7.4612e-10,
4.6227e-10, 3.2016e-10],
...,
[-1.1537e-10, -6.8744e-11, -1.4742e-10, ..., -1.5278e-09,
-1.2627e-09, -1.9038e-10],
[ 1.8451e-11, 6.8723e-12, 7.8628e-11, ..., 5.4209e-10,
2.2966e-11, -4.1130e-10],
[ 7.2494e-11, -2.1637e-11, -2.3505e-11, ..., -1.2546e-09,
-1.0029e-09, -2.7185e-10]], device='cuda:3')
# when batch = 1 ~ 5 after backward operation (the error happened)...
gw = tensor([[1.4175e-30, 1.2399e-37, 2.9771e-34, 1.9264e-29, 5.3040e-31, 4.6421e-32, 3.9189e-31, 4.9028e-30, 6.2500e-02, 1.2059e-34, 3.1239e-28, 2.4408e-30, 4.2906e-40, 1.1941e-29, 8.8812e-33, 3.7004e-30], [5.0646e-35, 9.5055e-37, 2.7237e-37, 6.8329e-34, 4.5170e-28, 1.3068e-36, 7.8042e-35, 1.0985e-32, 1.0164e-32, 5.2903e-35, 5.5275e-33, 7.7107e-34, 1.5538e-40, 2.9563e-34, 7.7576e-38, 6.2500e-02], [1.5845e-39, 1.0916e-38, 2.0858e-37, 1.7494e-34, 1.3349e-35, 4.6349e-41, 3.8512e-39, 1.5587e-35, 1.4943e-35, 6.2500e-02, 4.0319e-35, 1.7400e-38, 5.3612e-41, 1.8339e-36, 3.1585e-42, 1.7873e-34], [2.9675e-27, 2.0752e-36, 2.0050e-28, 6.2500e-02, 3.7063e-29, 1.1508e-28, 4.1524e-27, 8.2859e-27, 5.9175e-29, 1.1651e-33, 1.7496e-27, 3.6725e-28, 5.5585e-41, 8.2999e-26, 8.8201e-29, 1.5973e-30], [1.1002e-37, 1.7462e-37, 6.2500e-02, 9.5501e-29, 1.8937e-34, 4.1189e-40, 1.0580e-37, 2.8095e-35, 7.5286e-35, 1.5087e-37, 9.6100e-36, 2.3896e-37, 2.8306e-42, 3.0763e-35, 7.2390e-41, 4.9780e-36], [2.1772e-14, 1.4180e-37, 1.7654e-35, 4.6700e-24, 3.3251e-24, 6.2500e-02, 5.6623e-15, 6.6706e-19, 1.2377e-25, 2.9854e-36, 5.6229e-18, 5.3399e-14, 7.8473e-43, 3.9212e-16, 1.3262e-12, 1.1872e-25], [8.7254e-16, 1.2736e-36, 3.9274e-34, 2.5869e-23, 9.0424e-25, 2.5419e-15, 6.2500e-02, 9.0055e-20, 3.2155e-26, 1.6230e-35, 4.9749e-19, 3.3837e-16, 2.6821e-42, 1.8886e-16, 1.5022e-14, 2.3025e-25], [1.2188e-17, 2.4896e-37, 5.8553e-35, 1.6519e-25, 9.2195e-25, 2.2789e-15, 6.4765e-17, 1.4194e-20, 4.7585e-26, 6.9260e-36, 4.5196e-20, 6.2500e-02, 6.3097e-40, 1.1947e-18, 2.4959e-16, 1.8890e-25], [4.6892e-19, 6.2071e-36, 7.9653e-32, 2.5155e-23, 5.4780e-26, 9.2475e-20, 2.3784e-19, 2.0399e-21, 7.5880e-26, 1.0587e-32, 3.6781e-21, 1.6517e-20, 7.0332e-40, 6.2500e-02, 6.4970e-20, 1.1853e-26], [0.0000e+00, 7.0065e-45, 2.6527e-42, 3.0801e-42, 5.7221e-41, 0.0000e+00, 0.0000e+00, 3.6686e-42, 8.6852e-42, 1.3944e-40, 1.8427e-42, 1.4013e-44, 6.2500e-02, 0.0000e+00, 0.0000e+00, 1.1349e-41], [1.1070e-43, 6.2500e-02, 2.4309e-38, 
2.2074e-39, 6.8522e-39, 0.0000e+00, 7.0065e-45, 3.8495e-39, 1.4038e-40, 2.2271e-39, 8.0879e-41, 3.6434e-44, 4.2039e-45, 4.4561e-43, 0.0000e+00, 8.3306e-38], [5.0508e-13, 6.3545e-40, 7.6625e-37, 1.9200e-23, 2.5130e-25, 7.2813e-11, 1.4380e-12, 1.1018e-18, 1.8889e-26, 8.0530e-38, 5.3365e-18, 8.3840e-14, 0.0000e+00, 8.1971e-15, 6.2500e-02, 2.3892e-26], [1.1774e-24, 1.8025e-33, 4.1897e-33, 1.9095e-26, 8.2338e-28, 2.3473e-25, 3.3104e-25, 6.2500e-02, 1.0450e-28, 2.5269e-33, 2.3697e-25, 5.1198e-25, 1.2642e-37, 1.8217e-24, 2.3924e-26, 6.0960e-28], [8.5879e-23, 4.1779e-35, 6.0920e-34, 2.1292e-26, 8.6585e-27, 2.0547e-22, 1.2690e-22, 7.6655e-24, 7.3242e-26, 1.9534e-32, 6.2500e-02, 9.8456e-23, 2.2891e-38, 2.3708e-22, 6.5423e-23, 2.6051e-27], [6.2500e-02, 4.0095e-35, 3.3414e-34, 9.3385e-24, 7.8530e-24, 3.3230e-15, 3.3943e-16, 2.2180e-19, 1.0645e-25, 5.3034e-36, 1.1850e-19, 2.7041e-17, 5.0447e-43, 1.2713e-16, 2.4376e-15, 1.1389e-25], [3.0767e-27, 1.0605e-36, 4.6640e-35, 1.9661e-29, 6.2500e-02, 1.4273e-28, 7.3341e-28, 4.7904e-28, 5.9474e-31, 9.7950e-36, 9.5693e-28, 1.8681e-27, 8.2699e-41, 6.0447e-28, 1.6877e-29, 2.4151e-25]], device='cuda:3', grad_fn=<MulBackward0>)
model.weight.grad = tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], device='cuda:3')