Problem of grad becoming None

Hello,
Here is where I encounter the problem: basically, I tried to use torch.autograd.grad to compute gradients, but it returned None.

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss


class BatchNorm2dMul(nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super(BatchNorm2dMul, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=False, track_running_stats=track_running_stats)
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.affine = affine

    def forward(self, x):
        bn_out = self.bn(x)
        if self.affine:
            out = self.gamma[None, :, None, None] * bn_out + self.beta[None, :, None, None]
        else:
            out = bn_out
        return out, bn_out


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = BatchNorm2dMul(num_features=out_channels)
        self.lReLu = nn.LeakyReLU()
        self.dp = nn.Dropout(dropout_p)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = BatchNorm2dMul(num_features=out_channels)

    def forward(self, x):
        bn_outputs = []
        out, bn_out = self.bn1(self.conv1(x))
        bn_outputs.append(bn_out)
        out = self.lReLu(out)
        out = self.dp(out)
        out, bn_out = self.bn2(self.conv2(out))
        bn_outputs.append(bn_out)
        out = self.lReLu(out)    
        return out, bn_outputs


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.convBlock = ConvBlock(in_channels, out_channels, dropout_p)

    def forward(self, x):
        x = self.maxpool(x)
        bn_outputs = []
        out, bn_output = self.convBlock(x)
        bn_outputs.extend(bn_output)
        return out, bn_outputs

class UpBlock(nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
                 bilinear=True):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear:
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        bn_outputs = []
        out, bn_output = self.conv(x)
        bn_outputs.extend(bn_output)
        return out, bn_outputs


class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5)
        self.in_conv = ConvBlock(
            self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(
            self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(
            self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(
            self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(
            self.ft_chns[3], self.ft_chns[4], self.dropout[4])

    def forward(self, x):
        all_bn_outputs = []
        x0, bn_outputs0 = self.in_conv(x)
        all_bn_outputs.extend(bn_outputs0)
        x1, bn_outputs1 = self.down1(x0)
        all_bn_outputs.extend(bn_outputs1)
        x2, bn_outputs2 = self.down2(x1)
        all_bn_outputs.extend(bn_outputs2)
        x3, bn_outputs3 = self.down3(x2)
        all_bn_outputs.extend(bn_outputs3)
        x4, bn_outputs4 = self.down4(x3)
        all_bn_outputs.extend(bn_outputs4)
        return [x0, x1, x2, x3, x4], all_bn_outputs

class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

    def forward(self, feature, all_bn_outputs):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]
        feature_map = [x4]
        x, bn_output = self.up1(x4, x3)
        all_bn_outputs.extend(bn_output)
        feature_map.append(x)
        x, bn_output = self.up2(x, x2)
        all_bn_outputs.extend(bn_output)
        feature_map.append(x)
        x, bn_output = self.up3(x, x1)
        all_bn_outputs.extend(bn_output)
        feature_map.append(x)
        x, bn_output = self.up4(x, x0)
        all_bn_outputs.extend(bn_output)
        feature_map.append(x)
        output = self.out_conv(x)
        return output, feature_map, all_bn_outputs


class UNet(nn.Module):
    def __init__(self, in_chns, class_num, train_encoder=True, train_decoder=True, unfreeze_seg=True):
        super(UNet, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': False,
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        self.train_encoder = train_encoder
        self.train_decoder = train_decoder

        if not train_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        if not train_decoder:
            for param in self.decoder.parameters():
                param.requires_grad = False

        if not unfreeze_seg:
            # freeze everything except the final segmentation head (out_conv)
            for param in self.encoder.parameters():
                param.requires_grad = False
            out_conv_param_ids = {id(p) for p in self.decoder.out_conv.parameters()}
            for param in self.decoder.parameters():
                if id(param) not in out_conv_param_ids:
                    param.requires_grad = False

    def forward(self, x):
        feature, all_bn_outputs = self.encoder(x)
        output, feature_map, all_bn_outputs = self.decoder(feature, all_bn_outputs)
        return output, feature[-1], feature_map, all_bn_outputs

Basically, I am adding a BatchNorm2dMul class to substitute for the normal BatchNorm2d layer. When I use this class to build a UNet model and try to compute the gradients with respect to all_bn_outputs, as shown below:

model = UNet(in_chns=1, class_num=4,
             train_encoder=True, train_decoder=True, unfreeze_seg=True).cuda()
pred_l, _, _, all_bn_outputs = model(train_l_data)
loss_ce = CrossEntropyLoss()(pred_l, train_l_label.long())
loss_grads = torch.autograd.grad(outputs=loss_ce, inputs=all_bn_outputs,
                                 create_graph=True, allow_unused=True)
for grd in loss_grads:
    print(grd)

the result is 18 Nones. But as far as I can tell, all_bn_outputs is part of the graph that produces loss_ce, so I am confused here. Please take some time to look, thank you for your help!
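For comparison, here is a minimal standalone toy example (not my UNet, just a plain tensor) where torch.autograd.grad with respect to an intermediate tensor behaves as I would expect:

# Standalone toy example: gradients w.r.t. a non-leaf intermediate tensor
# that feeds the loss normally come back as tensors, not None.
import torch

x = torch.randn(4, 3, requires_grad=True)
intermediate = x * 2                      # non-leaf tensor kept in a list
loss = (intermediate ** 2).sum()
grads = torch.autograd.grad(loss, [intermediate],
                            create_graph=True, allow_unused=True)
print(grads[0])                           # equals 2 * intermediate, not None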

By the way, my torch version is 1.10.2.

Can you check the answer in the post below and try using list.append() instead of list.extend()?
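To illustrate the suggestion, here is a small standalone snippet (independent of your model) showing the mechanical difference between the two calls on a list of tensors:

# extend() flattens the incoming list into the accumulator,
# while append() stores the incoming list as one nested element.
import torch

a, b = torch.zeros(1), torch.ones(1)
acc_extend, acc_append = [], []
acc_extend.extend([a, b])    # acc_extend == [a, b]   -> 2 elements
acc_append.append([a, b])    # acc_append == [[a, b]] -> 1 nested element
print(len(acc_extend), len(acc_append))   # prints: 2 1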

Hello Arul,
Yes, it works. Besides, I have another question: it seems that when I use a DataParallel model it doesn't work, but when I train the model on a single GPU it works fine. Could you please offer some suggestions? Thank you.
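For reference, this is roughly the multi-GPU setup I mean (an assumed sketch, since I did not paste the exact wrapping code above):

# Assumed sketch of the DataParallel case: same forward and autograd.grad call
# as before, but with the model wrapped in nn.DataParallel.
import torch
from torch.nn import CrossEntropyLoss

model = torch.nn.DataParallel(UNet(in_chns=1, class_num=4).cuda())
pred_l, _, _, all_bn_outputs = model(train_l_data)
loss_ce = CrossEntropyLoss()(pred_l, train_l_label.long())
loss_grads = torch.autograd.grad(outputs=loss_ce, inputs=all_bn_outputs,
                                 create_graph=True, allow_unused=True)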

I do not have a concrete solution, though I am able to reproduce the behavior you are observing.
Possibly, you are hitting an issue similar to the following caveat of DistributedDataParallel:

This module doesn't work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).

This could be the case with DataParallel too.
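As a possible single-process alternative, here is a minimal sketch of the .grad-accumulation path the documentation refers to: call retain_grad() on the intermediate tensor and read its .grad after a normal backward(). It is shown on a plain tensor only; I have not verified this under DataParallel:

# Ask autograd to keep the gradient of a non-leaf tensor and read it from .grad
# after backward(), instead of using torch.autograd.grad().
import torch

x = torch.randn(2, 2, requires_grad=True)
y = x * 3            # non-leaf intermediate tensor
y.retain_grad()      # populate y.grad during backward()
loss = y.sum()
loss.backward()
print(y.grad)        # tensor of ones, same shape as y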

Yes, I think I will just keep using a single GPU for now, since that works fine.
Thank you for your help!