Model param.grad is None, how to debug?

No, please try to narrow the code down to the needed snippet, which I would guess is the model, criterion, optimizer, and the shapes of the random input tensors.
Reading through an entire project and trying to get it running would unfortunately take more time than debugging the actual issue in most cases. :confused:

Hi @ptrblck, I have cleaned up all the code into a single file. During the clean-up I also tested the DataParallel training mode, which did not cause the loss error, so I thought the problem was located in the DistributedDataParallel setting. I then kept testing until I turned off automatic mixed-precision training and the error disappeared. Maybe the current task is too hard for my model.

Anyway, thanks for your kindness!

Automatic mixed-precision training should not break your training and cause None gradients, so I would still be interested in the minimal code snippet in case you could post it.
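
For reference, a minimal sketch of a typical torch.cuda.amp training loop with a GradScaler; model, optimizer, criterion, and loader are placeholders for your own objects:

import torch

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    # run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    # scale the loss so float16 gradients do not underflow
    scaler.scale(loss).backward()
    # unscale the gradients and skip the step if they contain inf/NaN
    scaler.step(optimizer)
    scaler.update()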

OK, I will share my debugging process and how I figured out the problem.

The initial problem was that our project's training loss goes to NaN after 2 steps, so I guessed the gradients might contain incorrect values, such as NaN. I then printed every layer's gradient and found that the final LayerNorm's gradient is NoneType. This made me believe the problem was located there, but I had overlooked that the final LayerNorm is disabled (normalize=False), so its gradient is correctly None. After that, I checked the other code which might cause the NaN loss and found that the real problem was automatic mixed-precision training. Here is my network:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: sinusoidal_embedding is defined elsewhere in the original project.


class ConvBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out


class UNet(nn.Module):
    def __init__(self, n_steps=1000, time_emb_dim=100):
        super().__init__()

        # # Sinusoidal embedding
        # self.time_embed = nn.Embedding(n_steps, time_emb_dim) ##number of embedding=n_steps, embedding dim=time_emb_dim
        # self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        # self.time_embed.requires_grad_(False)
        self.time_embed = sinusoidal_embedding(n_steps, time_emb_dim).to(torch.float32)


        # Down pass
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            ConvBlock((1, 28, 28), 1, 10),
            ConvBlock((10, 28, 28), 10, 10),
            ConvBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)  #down sampling 

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            ConvBlock((10, 14, 14), 10, 20),
            ConvBlock((20, 14, 14), 20, 20),
            ConvBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            ConvBlock((20, 7, 7), 20, 40),
            ConvBlock((40, 7, 7), 40, 40),
            ConvBlock((40, 7, 7), 40, 40)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            ConvBlock((40, 3, 3), 40, 20),
            ConvBlock((20, 3, 3), 20, 20),
            ConvBlock((20, 3, 3), 20, 40)
        )

        # up pass
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            ConvBlock((80, 7, 7), 80, 40),
            ConvBlock((40, 7, 7), 40, 20),
            ConvBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            ConvBlock((40, 14, 14), 40, 20),
            ConvBlock((20, 14, 14), 20, 10),
            ConvBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            ConvBlock((20, 28, 28), 20, 10),
            ConvBlock((10, 28, 28), 10, 10),
            ConvBlock((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        # x is (N, 1, 28, 28) (image with positional embedding stacked on channel dimension)
        t = F.embedding(t.to(torch.long), self.time_embed.to(x.device))  # (N, 100); t holds randints from [0, n_steps)
        n = x.shape[0] #n = len(x)
        # t = self.time_embed(t)  #N * 100, ranint from (0, n_steps)
        # n = len(x)
        ## First stage: channels * 2, H / 2, W / 2
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)  (N, 2, 28, 28) + (N, 1, 1, 1)
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (N, 20, 14, 14)  (N, 10, 14, 14) + (N, 10, 1, 1)
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (N, 40, 7, 7)  (N, 20, 7, 7) + (N, 20, 1, 1)

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3) 
        ## End of the first stage
        ## Second stage: channels / 2, H * 2, W * 2; the intermediate results of the down pass need to be concatenated in
        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        ## In the last step of the second stage, convert to (N, 1, 28, 28)
        out = self.conv_out(out)
        ## End of the second stage

        return out  ## output the score

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )
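
To double-check per-parameter gradients as described above, a loop like the following can be run after loss.backward(); model stands for the UNet instance here:

for name, param in model.named_parameters():
    if param.grad is None:
        # e.g. unused modules such as the LayerNorm of the block created with normalize=False
        print(f"{name}: grad is None")
    elif not torch.isfinite(param.grad).all():
        print(f"{name}: grad contains NaN/Inf")
    else:
        print(f"{name}: grad norm = {param.grad.norm():.4e}")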

In conclusion, it was my mistake and the above problem description was misleading. I'm sorry about that.

Thanks for your kindness.

@wml1993 Have you solved it? This should be the code from teacher Li Hongyi's homework, and I have run into the same problem.

Hello, @ptrblck. I am training a 3D UNet model to perform lung lobe segmentation. My inputs are [256, 256, 200] volumes of diagnostic chest CT scans and the labels are the ground-truth lung lobe segmentations. I trained a model using a Dice and Focal loss function, and now I would like to use a topological loss to correct the errors. I wrote my own topological loss function: I use the Cubical Ripser package to compute the barcode diagram of my probability segmentation map. The probability segmentation map is the output of my model after applying a softmax activation and has the dimensions [1, 6, 200, 256, 256], 6 being the number of classes + background. This is my loss function code:

import torch
import numpy as np
import cripser
import math
import gudhi as gd
import matplotlib.pyplot as plt


def diag_tidy(diag, eps=4e-1):
    new_diag = []
    dimensions = [0, 1, 2]

    for dim in dimensions:
        dim_indices = np.where(diag[:, 0] == dim)[0]  # Indices for the current dimension
        dim_entries = diag[dim_indices]

        sorted_indices = np.argsort(np.abs(dim_entries[:, 2] - dim_entries[:, 1]))[::-1]  # Sort indices in descending order of lifespan
        sorted_entries = dim_entries[sorted_indices]
        for entry in sorted_entries:
            birth = entry[2]
            death = entry[1]
            # Modify values if they are greater than 1 or less than 0
            if birth > 1:
                birth = 1
            if death < 0:
                death = 0
            if np.abs(death - birth) > eps:
                new_diag.append((entry[0].astype('int'), (death, birth)))

    return new_diag



B_One_Comp = {0:1, 1:0, 2:0} # RUL(1),RML(2), RLL(3), LUL(4), LLL(5), RUL U RML, RUL U RLL, RML U RLL, RLL U LUL
B_Two_Comp = {0:2, 1:0, 2:0} # RUL U LUL, RUL U LLL, RML U LUL, RML U LLL, RLL U LUL, RLL U LLL

# Correct topology for each lobe and their pairs
B_dict = {11:B_One_Comp,
          22:B_One_Comp,
          33:B_One_Comp,
          44:B_One_Comp,
          55:B_One_Comp,
          12:B_One_Comp,
          13:B_One_Comp,
          23:B_One_Comp,
          14:B_Two_Comp,
          15:B_Two_Comp,
          24:B_Two_Comp,
          25:B_Two_Comp,
          34:B_Two_Comp,
          35:B_Two_Comp
          }

def topoloss_val(model_output, lambd, requires_grad=True):
    out_softmax = torch.softmax(model_output, dim=1)
    Z_cpu = out_softmax
    class_1 = Z_cpu[:, 1, :, :, :]
    class_2 = Z_cpu[:, 2, :, :, :]
    class_3 = Z_cpu[:, 3, :, :, :]
    class_4 = Z_cpu[:, 4, :, :, :]
    class_5 = Z_cpu[:, 5, :, :, :]
    # plt.imshow(class_3[0,90].detach().numpy(), cmap='gray')
    # plt.show()
    classes_dict = {1: class_1,
                    2: class_2,
                    3: class_3,
                    4: class_4,
                    5: class_5}

    L_topo = 0
    for i in list(B_dict.keys()):

        classes = [int(digit) for digit in str(i)]
        max_prob = torch.max(torch.stack((classes_dict[classes[0]],classes_dict[classes[1]]), dim=1), dim=1)[0]
        max_probe_final = max_prob.squeeze(0)
        # print(max_prob.shape)

        with torch.no_grad():
            max_probe_topo = max_probe_final.cpu().numpy()
        ## compute persistence of the sublevel filtration
            pd = cripser.computePH(1 - max_probe_topo, maxdim=2, location="birth")
            diag_clean = diag_tidy(pd, eps=1e-2)

        # if i==11:
        #     gd.plot_persistence_barcode(diag_clean)
        #     plt.ylim(-1, len(diag_clean))
        #     plt.xticks(ticks=np.linspace(0, 1, 6), labels=np.round(np.linspace(1, 0, 6), 2))
        #     plt.yticks([])
        #     plt.show()

        new_diag_dim0 = []
        new_diag_dim1 = []
        new_diag_dim2 = []
        for dim, intervals in diag_clean:
            if dim == 0:
                new_diag_dim0.append(intervals)
            elif dim == 1:
                new_diag_dim1.append(intervals)
            else:
                new_diag_dim2.append(intervals)

        interval_dim0 = torch.tensor([
            abs(intervals[0] - intervals[1]) if not math.isinf(intervals[1]) else abs(intervals[0] - 1) for
            intervals in new_diag_dim0], dtype=torch.float32)

        interval_dim1 = torch.tensor([
            abs(intervals[0] - intervals[1]) if not math.isinf(intervals[1]) else abs(intervals[0] - 1) for
            intervals in new_diag_dim1], dtype=torch.float32)

        interval_dim2 = torch.tensor([
            abs(intervals[0] - intervals[1]) if not math.isinf(intervals[1]) else abs(intervals[0] - 1) for
            intervals in new_diag_dim2], dtype=torch.float32)

        # print(f'dim2:{interval_dim2}')
        #     dimension 0

        bar_signs_dim0 = torch.ones(len(new_diag_dim0), dtype=torch.float32)
        bar_signs_dim0[:B_dict[i][0]] =-1
        # bar_signs_dim0 = bar_signs_dim0.detach()
        L0 = B_dict[i][0] + torch.sum(interval_dim0 * bar_signs_dim0)
        # print(f'L0:{L0}')

        bar_signs_dim1 = torch.ones(len(new_diag_dim1), dtype=torch.float32)
        bar_signs_dim1[:B_dict[i][1]] = -1
        L1 = B_dict[i][1] + torch.sum(interval_dim1 * bar_signs_dim1)
        # print(f'L1:{L1}')

        bar_signs_dim2 = torch.ones(len(new_diag_dim2), dtype=torch.float32)
        bar_signs_dim2[:B_dict[i][2]] = -1
        L2 = B_dict[i][2] + torch.sum(interval_dim2 * bar_signs_dim2)
        # print(f'L2:{L2}')

        L_total_class = L0 + L1 + L2
        L_topo = L_topo + L_total_class

    topo_loss = L_topo * lambd
    topo_loss.requires_grad_()
    return topo_loss

and I am using it like:

for i, (x, y) in batch_iter:
    input, target = x.to(self.device), y.to(self.device) 
    self.optimizer.zero_grad()  # zerograd the parameters
    out = self.model(input)

    topo_val = topoloss_val(out, 0.05, requires_grad=True)
    # topo_val = topo_val.to(self.device)
    topo_lass_value1 = topo_val.item()

    total_loss = topo_lass_value1
    tl = topo_val
    tl.backward()
    train_losses.append(total_loss)
    self.optimizer.step()

but the model is not training and param.grad is None. I removed detach() wherever I was using it. Do you think the problem could be with the CubicalRipser library, even though I'm using torch.no_grad()?

You are breaking the computation graph in a few places:

  • using torch.no_grad() will prevent Autograd from tracking these operations
  • converting a tensor to a numpy array breaks the computation graph, as Autograd cannot track operations from 3rd party libs
  • recreating a new tensor via torch.tensor(old_tensor) will create a new leaf-tensor without a gradient history
  • calling .requires_grad_() on the output tensor won’t fix these issues.
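
A minimal standalone sketch (toy tensors, not your training code) showing how each of these breaks is visible when inspecting the tensors:

import torch
import numpy as np

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    y = x * 2                      # not tracked by Autograd
print(y.requires_grad)             # False

z = torch.from_numpy(x.detach().numpy() * 2)
print(z.grad_fn)                   # None: the numpy round-trip has no gradient history

w = torch.tensor([1.0, 2.0, 3.0])  # a newly created leaf tensor
w.requires_grad_()                 # makes w require grad, but does not reconnect it to x
print(w.grad_fn)                   # still None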

Thank you for the reply, @ptrblck .

Actually, I tried the code without torch.no_grad() as well. The thing is that the computePH method of the CubicalRipser library only takes numpy arrays as input, so I don't know how to get around that. So is the computation graph only computed for the loss? Do you have any suggestion on how I can convert the output of the model to a numpy array and do the PH calculation without breaking the computation graph?

You won’t be able to directly use numpy arrays in any of your computations without breaking the graph.
You would thus need to either rewrite the operations using plain PyTorch operations or you would need to implement a custom autograd.Function including the backward method as described here.
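
As a rough illustration of the autograd.Function pattern only (the backward pass for a persistence-based loss is more involved and not shown here), wrapping a numpy computation could look like this:

import torch
import numpy as np

class NumpySquare(torch.autograd.Function):
    """Toy example: x**2 computed via numpy in forward, with the matching analytic backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        out = np.square(x.detach().cpu().numpy())         # 3rd-party / numpy computation
        return torch.from_numpy(out).to(x.device, x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x                         # d(x**2)/dx = 2x

# gradients now flow through the numpy-based op
x = torch.randn(4, requires_grad=True)
NumpySquare.apply(x).sum().backward()
print(x.grad)   # equals 2 * x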


@ptrblck please look into this issue