Gradient tensor exists for binary output, but is None for multiclass output

I am encountering some very strange behavior with my model. I am taking the feature-extraction layers of ResNet34 and adding a final classifier layer onto the end. When I change the final nn.Linear layer from a single binary output to a multiclass output, the gradient of the classifier weights is None and the model fails to update its weights.

PyTorch version 1.0.1
Let me know if any additional information is useful, I’ve never made a post like this before.

Example code:

General setup

import torch
import torch.nn as nn
import torchvision.models as models

class ResNet34_Model(nn.Module):
    def __init__(self, original_model, num_classes):
        super(ResNet34_Model, self).__init__()
        self.modelName = 'resnet'  # forward() branches on this attribute
        linear_size = 512  # ResNet34 yields 512 features after the final average pool
        self.features = nn.Sequential(*list(original_model.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Linear(linear_size, num_classes)
        )
        # Freeze the feature-extractor weights
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        if self.modelName == 'alexnet' :
            f = f.view(f.size(0), 256 * 6 * 6)
        elif self.modelName == 'vgg16':
            f = f.view(f.size(0), -1)
        elif self.modelName == 'resnet' :
            f = f.view(f.size(0), -1)
        elif self.modelName == "densenet":
            # f = f.relu(f, inplace=True)
            # f = f.avg_pool2d(f, kernel_size=7).view(f.size(0), -1)
            f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y

Binary code (gradient is fine)

original_model = models.__dict__["resnet34"](pretrained=True)
model = ResNet34_Model(original_model, 1)
model = model.cuda()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = 3e-4, weight_decay=0)
criterion = nn.BCEWithLogitsLoss().cuda()

X_batch = X_batch.cuda() # X_batch and y_batch are pulled from the first iteration of a data loader
y_batch = y_batch.cuda()

# train for 1 epoch
model.train()
output = model(X_batch).squeeze()
loss = criterion(output.float(), y_batch.float())
output_var = torch.sigmoid(output).data
output_class = [ 0 if a<=0.5 else 1  for a in output_var]
optimizer.zero_grad()
loss.backward()
optimizer.step()

# check that gradients exist
for n, p in model.named_parameters():
    print(n)
    if(p.requires_grad) and ("bias" not in n):
        print(p.grad)

Output

features.0.weight
features.1.weight
features.1.bias
features.4.0.conv1.weight
features.4.0.bn1.weight
features.4.0.bn1.bias
features.4.0.conv2.weight
features.4.0.bn2.weight
features.4.0.bn2.bias
features.4.1.conv1.weight
features.4.1.bn1.weight
features.4.1.bn1.bias
features.4.1.conv2.weight
features.4.1.bn2.weight
features.4.1.bn2.bias
features.4.2.conv1.weight
features.4.2.bn1.weight
features.4.2.bn1.bias
features.4.2.conv2.weight
features.4.2.bn2.weight
features.4.2.bn2.bias
features.5.0.conv1.weight
features.5.0.bn1.weight
features.5.0.bn1.bias
features.5.0.conv2.weight
features.5.0.bn2.weight
features.5.0.bn2.bias
features.5.0.downsample.0.weight
features.5.0.downsample.1.weight
features.5.0.downsample.1.bias
features.5.1.conv1.weight
features.5.1.bn1.weight
features.5.1.bn1.bias
features.5.1.conv2.weight
features.5.1.bn2.weight
features.5.1.bn2.bias
features.5.2.conv1.weight
features.5.2.bn1.weight
features.5.2.bn1.bias
features.5.2.conv2.weight
features.5.2.bn2.weight
features.5.2.bn2.bias
features.5.3.conv1.weight
features.5.3.bn1.weight
features.5.3.bn1.bias
features.5.3.conv2.weight
features.5.3.bn2.weight
features.5.3.bn2.bias
features.6.0.conv1.weight
features.6.0.bn1.weight
features.6.0.bn1.bias
features.6.0.conv2.weight
features.6.0.bn2.weight
features.6.0.bn2.bias
features.6.0.downsample.0.weight
features.6.0.downsample.1.weight
features.6.0.downsample.1.bias
features.6.1.conv1.weight
features.6.1.bn1.weight
features.6.1.bn1.bias
features.6.1.conv2.weight
features.6.1.bn2.weight
features.6.1.bn2.bias
features.6.2.conv1.weight
features.6.2.bn1.weight
features.6.2.bn1.bias
features.6.2.conv2.weight
features.6.2.bn2.weight
features.6.2.bn2.bias
features.6.3.conv1.weight
features.6.3.bn1.weight
features.6.3.bn1.bias
features.6.3.conv2.weight
features.6.3.bn2.weight
features.6.3.bn2.bias
features.6.4.conv1.weight
features.6.4.bn1.weight
features.6.4.bn1.bias
features.6.4.conv2.weight
features.6.4.bn2.weight
features.6.4.bn2.bias
features.6.5.conv1.weight
features.6.5.bn1.weight
features.6.5.bn1.bias
features.6.5.conv2.weight
features.6.5.bn2.weight
features.6.5.bn2.bias
features.7.0.conv1.weight
features.7.0.bn1.weight
features.7.0.bn1.bias
features.7.0.conv2.weight
features.7.0.bn2.weight
features.7.0.bn2.bias
features.7.0.downsample.0.weight
features.7.0.downsample.1.weight
features.7.0.downsample.1.bias
features.7.1.conv1.weight
features.7.1.bn1.weight
features.7.1.bn1.bias
features.7.1.conv2.weight
features.7.1.bn2.weight
features.7.1.bn2.bias
features.7.2.conv1.weight
features.7.2.bn1.weight
features.7.2.bn1.bias
features.7.2.conv2.weight
features.7.2.bn2.weight
features.7.2.bn2.bias
classifier.0.weight
tensor([[-4.0103e-02, -3.0650e-03, -1.7690e-02,  1.8196e-02,  1.8940e-02,
         -5.2432e-02,  4.0119e-02,  1.6359e-02, -1.6112e-02,  3.6596e-02,
          1.1273e-02,  1.9971e-02,  2.4494e-02, -4.9495e-02, -2.4485e-02,
         -2.7787e-02, -4.4333e-02,  3.5293e-02, -1.1992e-02, -4.7242e-02,
          2.5577e-02, -3.0916e-02, -8.1142e-03, -2.4182e-03, -1.7207e-02,
          1.7690e-02,  3.6298e-02, -3.3379e-03,  1.5704e-02,  2.7359e-02,
         -1.6443e-04,  2.8785e-02, -4.1732e-03,  1.2785e-02,  6.6502e-02,
          2.3740e-02, -1.9601e-02,  2.0654e-02, -1.5341e-02,  5.1119e-02,
         -9.3292e-03, -4.6272e-02,  4.0541e-03, -1.6780e-02,  2.7696e-02,
         -2.8260e-02, -2.7696e-02, -2.6938e-02, -5.7942e-02, -1.7714e-02,
          1.1813e-02, -3.6892e-02,  3.2378e-02, -7.0157e-02,  2.9828e-03,
          4.5477e-02,  9.4537e-03,  1.4507e-02,  2.3831e-02,  1.3459e-02,
          5.2202e-02, -3.2785e-02, -1.7972e-02,  3.4436e-02, -1.7274e-02,
         -5.2432e-02,  3.5817e-02, -3.7003e-02,  3.9805e-02, -3.2142e-02,
          4.8118e-02,  3.0372e-02,  3.1019e-02,  2.0520e-02, -2.4969e-02,
          1.9197e-02, -5.7083e-02,  1.9971e-02,  4.1865e-03, -1.6170e-02,
         -2.7157e-02,  7.6236e-02,  7.7632e-04,  1.2858e-02,  6.0151e-02,
          4.4184e-02, -2.4115e-02,  3.8711e-03,  3.1806e-02, -7.6332e-03,
         -2.3273e-02, -3.6383e-02, -3.8611e-02,  4.8423e-02,  1.4127e-03,
          2.1603e-02,  5.7787e-03,  1.6463e-02,  2.6809e-02, -1.7455e-02,
         -3.9449e-03, -2.3069e-03,  6.2234e-02, -2.9278e-02, -9.0541e-03,
          4.0190e-02,  3.4797e-02, -3.7103e-03, -2.7790e-02, -4.0176e-02,
          4.5264e-03,  1.3271e-02, -3.0030e-03,  2.4259e-02, -2.2459e-02,
         -8.1706e-03,  2.9898e-02, -4.2108e-02, -3.8830e-03, -1.6584e-03,
         -2.5688e-03,  5.5268e-02,  2.5715e-03, -8.0469e-03, -5.3614e-03,
          3.2183e-02, -1.6894e-02,  4.4844e-02, -5.0618e-03,  3.0681e-03,
          3.4562e-02,  2.0505e-02,  5.2180e-02, -4.3831e-02,  4.5276e-02,
         -8.6777e-03,  4.2830e-02, -2.6109e-02, -4.2990e-03, -7.8125e-02,
          2.6773e-03, -4.4095e-03, -6.4931e-02, -1.2023e-02,  3.0119e-02,
          6.6624e-02,  3.2416e-02,  1.4566e-02, -6.7606e-02, -4.9717e-02,
         -2.2283e-02, -1.9339e-02, -8.6594e-02, -9.5567e-03, -3.6641e-02,
         -2.4269e-02,  6.0891e-03,  3.5978e-02, -1.3062e-02, -2.4939e-02,
         -2.6925e-02, -2.6563e-03, -1.3851e-02, -5.0669e-03,  6.0252e-02,
         -4.9383e-02,  4.5053e-02,  1.6863e-02,  2.6193e-02, -3.0816e-02,
          3.9407e-02, -5.0734e-02, -2.5574e-02,  1.5975e-02,  3.2121e-02,
         -2.1032e-02, -1.1800e-02,  2.0957e-02,  1.7678e-02,  3.2595e-02,
          2.7446e-02,  2.3684e-02,  4.6127e-03, -1.3901e-02, -1.4728e-02,
          5.2495e-02, -3.1069e-02,  1.1383e-02,  5.4902e-02, -9.8808e-03,
          5.5881e-02,  3.7348e-02,  4.9416e-03, -5.1515e-02, -6.1780e-02,
          2.5301e-02, -2.6315e-02, -1.8703e-02,  2.9316e-02, -2.3085e-02,
          4.1748e-03,  5.2638e-02, -2.1772e-02, -1.4431e-03,  9.2481e-03,
         -4.8283e-02,  3.1777e-03,  8.0173e-02,  4.5029e-03, -1.4431e-02,
         -3.3560e-02,  4.6829e-02, -9.9793e-03, -3.2405e-03, -1.1975e-02,
          5.5930e-02,  2.2951e-02, -3.9283e-02,  2.4410e-03,  3.8763e-03,
         -2.1583e-02, -4.4375e-03, -3.1224e-02,  2.5172e-02,  5.3837e-03,
         -4.2036e-02,  3.9306e-02, -1.9016e-02,  8.3942e-02,  1.7185e-02,
         -3.3008e-02,  1.2389e-02, -6.3560e-02,  2.6117e-02, -1.6864e-02,
          5.5821e-02, -1.2640e-03, -4.0874e-02, -1.3435e-02,  2.2319e-02,
         -3.9246e-02,  2.1261e-02, -2.4980e-02, -5.5888e-02, -5.4517e-03,
          2.5954e-02,  4.3477e-02,  5.3093e-02, -1.7187e-03,  5.0169e-02,
         -3.7164e-02,  2.7800e-02,  1.2954e-02,  3.6849e-03, -1.0594e-02,
         -1.0012e-02, -1.7303e-02, -2.1361e-02,  2.9390e-02, -2.0674e-02,
          3.5865e-02,  7.5633e-03,  2.3617e-02, -4.1382e-02,  2.9991e-02,
          3.0470e-04, -2.7890e-02, -1.8794e-02, -1.6211e-02,  3.3249e-02,
          1.9253e-02,  1.9588e-02,  2.2323e-02,  1.7697e-02, -7.8543e-03,
         -7.8463e-03,  2.2645e-02, -3.9645e-02, -3.3896e-02,  2.1476e-02,
          3.9840e-02, -7.6785e-03, -1.7353e-02, -5.0593e-02,  2.9839e-02,
          9.3777e-03, -5.1932e-05,  4.8221e-02,  1.5305e-02,  1.1562e-02,
          1.7175e-02,  1.9763e-02,  2.1498e-02, -6.4569e-03,  4.4490e-02,
         -2.1948e-02,  3.3371e-02, -2.6169e-02,  2.1575e-02,  5.0593e-02,
         -7.6582e-03,  7.3652e-02, -9.0516e-02,  3.0139e-03,  3.1726e-02,
         -5.5917e-02,  2.1198e-02, -2.0204e-02, -3.1504e-02,  8.3176e-03,
          6.8362e-02, -1.7631e-02,  3.2006e-02, -3.2219e-02, -2.3685e-02,
          2.5310e-02,  2.4241e-02,  1.3161e-02,  4.5711e-02,  6.0400e-02,
          3.2823e-02,  3.7996e-02, -6.6305e-03,  6.0794e-03,  3.6651e-02,
         -7.5762e-03,  2.3375e-03,  1.8069e-02, -3.9109e-02,  3.8477e-02,
         -2.9035e-02, -1.5453e-02, -6.7045e-03,  3.4121e-02, -9.8876e-03,
         -8.4192e-06,  5.2210e-03,  1.6493e-02, -6.2000e-02, -6.0531e-04,
          7.3830e-03, -4.0898e-05, -6.0627e-03, -2.8498e-02, -7.6559e-03,
          1.7166e-02, -1.6621e-02,  3.0013e-02, -2.9750e-02,  5.7559e-02,
         -1.9507e-02,  2.1800e-02, -6.5081e-02,  2.9435e-02, -9.1154e-03,
          1.4963e-02,  2.7812e-02,  3.7519e-02,  1.0546e-02, -5.9419e-02,
         -2.2688e-02,  2.9650e-02,  3.1546e-02,  2.7190e-03,  4.2752e-02,
         -9.8702e-03,  4.0616e-02, -8.7285e-03,  4.1171e-02, -2.8747e-03,
         -4.7521e-02,  1.4819e-02,  3.4308e-02, -2.3178e-02,  6.8566e-04,
          6.5807e-02, -1.5936e-02, -5.6867e-02, -2.0194e-02,  1.0089e-02,
          3.2515e-02,  2.7668e-02, -8.3925e-02, -4.6546e-02,  3.0311e-02,
         -4.4808e-02, -6.0378e-02,  2.9398e-02, -4.5278e-03, -3.3444e-02,
          1.7838e-02,  1.8011e-02,  1.1978e-02,  6.9284e-03, -4.7839e-02,
          6.3141e-03, -2.7778e-03, -6.6707e-02,  3.5588e-02,  3.3485e-02,
          4.3899e-02, -4.5265e-02,  6.8920e-03, -1.4840e-02,  2.2699e-02,
         -5.0180e-02, -5.5397e-02,  1.2932e-02, -1.6373e-02, -1.4470e-03,
         -3.1263e-02,  1.2203e-02,  6.5744e-02,  5.2748e-02,  4.9446e-02,
          1.1454e-02, -1.2506e-02,  1.6448e-02,  6.4812e-03,  9.1295e-03,
          4.7307e-02, -1.0583e-02,  2.6514e-02, -2.7043e-02,  2.1754e-02,
         -7.4832e-03,  8.8365e-03,  3.6107e-02,  1.0179e-02,  3.0531e-02,
          2.9151e-02,  2.7932e-02,  2.9156e-02,  2.0762e-02, -2.9590e-02,
         -8.0620e-04, -5.6288e-02,  4.2960e-02,  2.9789e-02,  1.9852e-02,
         -3.9940e-02, -1.6577e-03, -9.1024e-03,  1.2403e-03,  3.8063e-02,
          5.4522e-02,  3.3541e-02,  1.8009e-02, -2.9072e-03, -2.2482e-04,
         -6.2176e-03, -7.1625e-02, -2.9141e-02,  2.6403e-02, -3.7902e-02,
          7.6599e-03,  9.9294e-03,  4.3861e-02,  2.1696e-02,  3.6313e-03,
          5.7046e-02,  8.5943e-02, -1.0694e-02, -6.1576e-02,  3.4494e-02,
         -1.4768e-02, -1.5311e-03,  5.9259e-02,  2.4781e-02,  2.1675e-02,
          7.6399e-02, -9.0391e-03, -1.2334e-02, -2.4607e-02,  5.1892e-02,
         -2.1849e-02, -8.1642e-02,  2.2554e-02, -1.6767e-02,  2.3100e-02,
          1.6989e-02,  9.1284e-04,  1.9720e-02, -3.2603e-02,  3.9477e-03,
         -2.7493e-02,  2.6490e-02,  1.4810e-02, -9.9203e-02, -4.7352e-03,
          8.6435e-03,  7.8281e-02, -3.9165e-02,  3.1929e-02,  2.9405e-02,
          4.3515e-02,  4.6316e-02,  1.0432e-03,  3.8957e-02, -2.6859e-02,
         -4.3090e-03,  1.5592e-02,  5.6056e-02, -3.0031e-02, -7.8511e-03,
         -3.4687e-02,  1.6892e-02, -2.0491e-03, -1.4476e-02,  4.8247e-02,
         -1.3725e-02, -2.7866e-02]], device='cuda:0')
classifier.0.bias

Multiclass code (gradient is None)

original_model = models.__dict__["resnet34"](pretrained=True)
model = ResNet34_Model(original_model, 3)
model = model.cuda()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = 3e-4, weight_decay=0)
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()

X_batch = X_batch.cuda() # X_batch and y_batch are pulled from the first iteration of a data loader
y_batch = y_batch.cuda()

# train for 1 epoch
model.train()
output = model(X_batch).squeeze()
output_var = torch.softmax(output, dim=1)    
_, output_class = torch.max(output_var, 1)
loss = torch.tensor(criterion(output_var, y_batch), requires_grad=True)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# check that gradients exist
for n, p in model.named_parameters():
    print(n)
    if(p.requires_grad) and ("bias" not in n):
        print(p.grad)

Output

features.0.weight
features.1.weight
features.1.bias
features.4.0.conv1.weight
features.4.0.bn1.weight
features.4.0.bn1.bias
features.4.0.conv2.weight
features.4.0.bn2.weight
features.4.0.bn2.bias
features.4.1.conv1.weight
features.4.1.bn1.weight
features.4.1.bn1.bias
features.4.1.conv2.weight
features.4.1.bn2.weight
features.4.1.bn2.bias
features.4.2.conv1.weight
features.4.2.bn1.weight
features.4.2.bn1.bias
features.4.2.conv2.weight
features.4.2.bn2.weight
features.4.2.bn2.bias
features.5.0.conv1.weight
features.5.0.bn1.weight
features.5.0.bn1.bias
features.5.0.conv2.weight
features.5.0.bn2.weight
features.5.0.bn2.bias
features.5.0.downsample.0.weight
features.5.0.downsample.1.weight
features.5.0.downsample.1.bias
features.5.1.conv1.weight
features.5.1.bn1.weight
features.5.1.bn1.bias
features.5.1.conv2.weight
features.5.1.bn2.weight
features.5.1.bn2.bias
features.5.2.conv1.weight
features.5.2.bn1.weight
features.5.2.bn1.bias
features.5.2.conv2.weight
features.5.2.bn2.weight
features.5.2.bn2.bias
features.5.3.conv1.weight
features.5.3.bn1.weight
features.5.3.bn1.bias
features.5.3.conv2.weight
features.5.3.bn2.weight
features.5.3.bn2.bias
features.6.0.conv1.weight
features.6.0.bn1.weight
features.6.0.bn1.bias
features.6.0.conv2.weight
features.6.0.bn2.weight
features.6.0.bn2.bias
features.6.0.downsample.0.weight
features.6.0.downsample.1.weight
features.6.0.downsample.1.bias
features.6.1.conv1.weight
features.6.1.bn1.weight
features.6.1.bn1.bias
features.6.1.conv2.weight
features.6.1.bn2.weight
features.6.1.bn2.bias
features.6.2.conv1.weight
features.6.2.bn1.weight
features.6.2.bn1.bias
features.6.2.conv2.weight
features.6.2.bn2.weight
features.6.2.bn2.bias
features.6.3.conv1.weight
features.6.3.bn1.weight
features.6.3.bn1.bias
features.6.3.conv2.weight
features.6.3.bn2.weight
features.6.3.bn2.bias
features.6.4.conv1.weight
features.6.4.bn1.weight
features.6.4.bn1.bias
features.6.4.conv2.weight
features.6.4.bn2.weight
features.6.4.bn2.bias
features.6.5.conv1.weight
features.6.5.bn1.weight
features.6.5.bn1.bias
features.6.5.conv2.weight
features.6.5.bn2.weight
features.6.5.bn2.bias
features.7.0.conv1.weight
features.7.0.bn1.weight
features.7.0.bn1.bias
features.7.0.conv2.weight
features.7.0.bn2.weight
features.7.0.bn2.bias
features.7.0.downsample.0.weight
features.7.0.downsample.1.weight
features.7.0.downsample.1.bias
features.7.1.conv1.weight
features.7.1.bn1.weight
features.7.1.bn1.bias
features.7.1.conv2.weight
features.7.1.bn2.weight
features.7.1.bn2.bias
features.7.2.conv1.weight
features.7.2.bn1.weight
features.7.2.bn1.bias
features.7.2.conv2.weight
features.7.2.bn2.weight
features.7.2.bn2.bias
classifier.0.weight
None
classifier.0.bias

You are creating a new tensor as your loss, thus detaching the original output of your criterion from the computation graph:

loss = torch.tensor(criterion(output_var, y_batch), requires_grad=True)
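You can see the detaching behavior in isolation with a tiny example (just a sketch, unrelated to your model): torch.tensor() builds a brand-new leaf tensor, so backward never reaches anything upstream of it.

x = torch.ones(1, requires_grad=True)
y = x * 2                                # y is connected to x in the graph
z = torch.tensor(y, requires_grad=True)  # z is a new leaf tensor; the graph is cut here
z.backward()
print(x.grad)                            # None -- the backward pass stopped at z
y.backward()
print(x.grad)                            # tensor([2.]) -- backward through y reaches x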

Instead just call:

loss = criterion(output, y_batch)
loss.backward()

Note that nn.CrossEntropyLoss expects raw logits as the model output, so don't apply softmax to your output tensor.
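Putting both points together, the multiclass training step could look roughly like this (a sketch based on your snippet, not tested on your data; y_batch must hold class indices of dtype torch.long):

model.train()
output = model(X_batch)                        # raw logits, shape (batch_size, num_classes)
loss = criterion(output, y_batch)              # CrossEntropyLoss applies log-softmax internally
optimizer.zero_grad()
loss.backward()
optimizer.step()

# predicted classes for monitoring -- argmax of the logits equals argmax of the softmax
output_class = torch.argmax(output.detach(), dim=1)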

Awesome, thanks for your help with this!