I am encountering some very strange behavior with my model. I am taking the feature-extraction layers of a pretrained ResNet34 and adding a final classifier layer on the end. When I change the final nn.Linear layer from a binary output (one unit, trained with BCEWithLogitsLoss) to a multiclass output (three units, trained with CrossEntropyLoss), the gradient on the classifier disappears and the model fails to update its weights.
PyTorch version 1.0.1
Let me know if any additional information would be useful; I've never made a post like this before.
Example code:
General setup
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet34_Model(nn.Module):
    def __init__(self, original_model, num_classes):
        super(ResNet34_Model, self).__init__()
        linear_size = 512  # feature dimension after ResNet34's global average pool
        # Everything except the original fc layer (the avgpool is kept)
        self.features = nn.Sequential(*list(original_model.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Linear(linear_size, num_classes)
        )
        # Freeze the pretrained feature extractor
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        # Flatten the pooled (N, 512, 1, 1) features to (N, 512)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y
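As a quick sanity check (a separate REPL snippet, not part of the training script below), I confirmed that freezing leaves only the classifier parameters trainable:

original_model = models.resnet34(pretrained=True)
model = ResNet34_Model(original_model, 1)
# Expect only classifier.0.weight and classifier.0.bias to print
for n, p in model.named_parameters():
    if p.requires_grad:
        print(n)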
Binary code (gradient is fine)
original_model = models.resnet34(pretrained=True)
model = ResNet34_Model(original_model, 1)
model = model.cuda()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = 3e-4, weight_decay=0)
criterion = nn.BCEWithLogitsLoss().cuda()
X_batch = X_batch.cuda() # X_batch and y_batch are pulled from the first iteration of a data loader
y_batch = y_batch.cuda()
# train for one step on a single batch
model.train()
output = model(X_batch).squeeze()
loss = criterion(output.float(), y_batch.float())
output_var = torch.sigmoid(output).detach()  # probabilities, for monitoring only
output_class = [0 if a <= 0.5 else 1 for a in output_var]
optimizer.zero_grad()
loss.backward()
optimizer.step()
# check that gradients exist
for n, p in model.named_parameters():
    print(n)
    if p.requires_grad and "bias" not in n:
        print(p.grad)
Output
features.0.weight
features.1.weight
features.1.bias
... (remaining frozen features.* parameter names omitted for brevity) ...
classifier.0.weight
tensor([[-4.0103e-02, -3.0650e-03, -1.7690e-02,  ...,  4.8247e-02,
         -1.3725e-02, -2.7866e-02]], device='cuda:0')
(gradient tensor truncated for brevity; all 512 entries are populated)
classifier.0.bias
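If it helps with diagnosing, here is the quick check I use to see whether a loss is attached to the autograd graph; in the binary case above I would expect True and a non-None grad_fn:

print(loss.requires_grad, loss.grad_fn)  # expect: True, non-None backward node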
Multiclass code (gradient is None)
original_model = models.resnet34(pretrained=True)
model = ResNet34_Model(original_model, 3)
model = model.cuda()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = 3e-4, weight_decay=0)
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()
X_batch = X_batch.cuda() # X_batch and y_batch are pulled from the first iteration of a data loader
y_batch = y_batch.cuda()
# train for one step on a single batch
model.train()
output = model(X_batch).squeeze()
output_var = torch.softmax(output, dim=1)
_, output_class = torch.max(output_var, 1)
loss = torch.tensor(criterion(output_var, y_batch), requires_grad=True)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# check that gradients exist
for n, p in model.named_parameters():
    print(n)
    if p.requires_grad and "bias" not in n:
        print(p.grad)
Output
features.0.weight
features.1.weight
features.1.bias
... (same list of frozen features.* parameter names as above) ...
classifier.0.weight
None
classifier.0.bias
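For comparison, here is how I would have expected the multiclass step to look, based on my reading of the nn.CrossEntropyLoss docs (it takes raw logits and applies log-softmax itself, and the loss tensor is used directly rather than re-wrapped). This is only a sketch of my understanding; I have not confirmed it explains the missing gradient:

model.train()
output = model(X_batch)                  # raw logits, shape (N, 3)
loss = criterion(output, y_batch)        # CrossEntropyLoss applies log-softmax internally
_, output_class = torch.max(output, 1)   # predicted classes, for monitoring only
optimizer.zero_grad()
loss.backward()
optimizer.step()

Is it the torch.tensor(...) wrapper around the loss, the softmax applied before the criterion, or something else entirely that breaks the graph in my version?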