Weights and biases not updated

Hi, I’m quite new to deep learning and I’m trying to do transfer learning on a pretrained 3D convolutional network that classifies brain images into multiple labels (i.e., the behavioural task a participant performed while they were scanned). I loaded the pretrained model and froze the parameters in every layer except the last convolutional layer and the fully connected layers, but the weights and biases were not updated during training.

The pretrained model structure was:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, n_in, n_out, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm3d(n_out)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(n_out)
        self.relu2 = nn.ReLU(inplace=True)

        if stride != 1 or n_out != n_in:
            self.shortcut = nn.Sequential(
                nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride),
                nn.BatchNorm3d(n_out))
        else:
            self.shortcut = None

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += residual
        out = self.relu2(out)
        return out
    
class DeepBrain(nn.Module):
    def __init__(self, block=BasicBlock, inplanes=27, planes=3, drop_out=True):
        super(DeepBrain, self).__init__()
        self.n_classes = 7
        
        self.preBlock = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, padding=0),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
            
            nn.Conv3d(planes, 24, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(24),
            nn.ReLU(inplace=True))
        
        self.layer_1 = self._make_layer(block, 24, 32, 2)
        self.layer_2 = self._make_layer(block, 32, 64, 2, pooling=True)
        self.layer_3 = self._make_layer(block, 64, 64, 2, pooling=True)
        self.layer_4 = self._make_layer(block, 64, 128, 2, pooling=True)
        
        self.post_conv = nn.Conv3d(128, 64, kernel_size=(5, 6, 6))
        
        if drop_out:
            self.classifier = nn.Sequential(
                nn.Linear(64, 64),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(64, self.n_classes),
                nn.LogSoftmax(dim=1))
        else:
            self.classifier = nn.Sequential(
                nn.Linear(64, 64),
                nn.ReLU(inplace=True),
                nn.Linear(64, self.n_classes),
                nn.LogSoftmax(dim=1))
        
        self._initialize_weights()
        
    def _make_layer(self, block, planes_in, planes_out, num_blocks, pooling=False, drop_out=False):
        layers = []
        if pooling:
#             layers.append(nn.MaxPool3d(kernel_size=2, stride=2))
            layers.append(block(planes_in, planes_out, stride=2))
        else:
            layers.append(block(planes_in, planes_out))
        for i in range(num_blocks - 1):
            layers.append(block(planes_out, planes_out))
            
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
    def forward(self, x):
        
        x = self.preBlock(x)

        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.post_conv(x)
        x = x.view(-1, 64)
        
        x = self.classifier(x)
        
        return x
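For reference, a quick way to sanity-check the forward pass. The 75 × 93 × 81 input size here is an assumption inferred from the (5, 6, 6) post_conv kernel (it makes the feature map entering post_conv exactly 5 × 6 × 6); substitute the actual volume size:

# Hypothetical smoke test; the 75 x 93 x 81 spatial size is inferred from
# the (5, 6, 6) post_conv kernel, not taken from the original data.
model = DeepBrain()
x = torch.randn(2, 27, 75, 93, 81)  # (batch, channels, depth, height, width)
out = model(x)
print(out.shape)  # expected: torch.Size([2, 7]), i.e. log-probabilities per class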

The code I used to load the model state was:

def get_pretrained_model(n_classes, drop_out, train_on_gpu, multi_gpu):

    import os

    import torch
    from torch import nn
    from model_o import DeepBrain

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model = DeepBrain()
    if train_on_gpu:
        model.load_state_dict(torch.load('./checkpoint/checkpoint_o.pth.tar', map_location="cuda")['state_dict'])
    else:
        model.load_state_dict(torch.load('./checkpoint/checkpoint_o.pth.tar', map_location="cpu")['state_dict'])

    # freeze the parameters
    
    for param in model.parameters():
        param.requires_grad = False
    
    
    # reinitialize the weights of the fully connected layers and the last conv layer

    n_inputs = 64

    model.n_classes = n_classes
    model.post_conv = nn.Conv3d(128, n_inputs, kernel_size=(5, 6, 5))  # changed kernel size to 5 x 6 x 5 from the original 5 x 6 x 6 to accommodate the change in input size
    
    if drop_out:
        model.classifier[0] = nn.Linear(n_inputs, n_inputs)
        model.classifier[3] = nn.Linear(n_inputs, n_classes)
    else:
        model.classifier[0] = nn.Linear(n_inputs, n_inputs)
        model.classifier[2] = nn.Linear(n_inputs, n_classes)

    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} trainable parameters.')


    # Move to gpu and parallelize
    
    if train_on_gpu:
        model = model.to('cuda')
    else:
        model = model.to('cpu')


    if multi_gpu:
        model = nn.DataParallel(model)


    return model
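The optimizer construction isn’t shown above, but one thing worth double-checking in a setup like this is that only the trainable parameters are handed to the optimizer. A minimal sketch, assuming Adam and an arbitrary learning rate:

# Pass only the parameters that still require gradients.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3)

Passing the frozen parameters too is harmless (they simply receive no gradients), but filtering makes the intent explicit. Note also that frozen BatchNorm3d layers keep updating their running statistics while the model is in train() mode; to freeze them completely, they would also have to be put in eval() mode.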

I ran this code before training to make sure the last few layers indeed required gradients:

for name, param in model.named_parameters():
    print(name, param.requires_grad)

The code above outputs:

preBlock.0.weight False
preBlock.0.bias False
preBlock.1.weight False
preBlock.1.bias False
preBlock.3.weight False
preBlock.3.bias False
preBlock.4.weight False
preBlock.4.bias False
layer_1.0.conv1.weight False
layer_1.0.conv1.bias False
layer_1.0.bn1.weight False
layer_1.0.bn1.bias False
layer_1.0.conv2.weight False
layer_1.0.conv2.bias False
layer_1.0.bn2.weight False
layer_1.0.bn2.bias False
layer_1.0.shortcut.0.weight False
layer_1.0.shortcut.0.bias False
layer_1.0.shortcut.1.weight False
layer_1.0.shortcut.1.bias False
layer_1.1.conv1.weight False
layer_1.1.conv1.bias False
layer_1.1.bn1.weight False
layer_1.1.bn1.bias False
layer_1.1.conv2.weight False
layer_1.1.conv2.bias False
layer_1.1.bn2.weight False
layer_1.1.bn2.bias False
layer_2.0.conv1.weight False
layer_2.0.conv1.bias False
layer_2.0.bn1.weight False
layer_2.0.bn1.bias False
layer_2.0.conv2.weight False
layer_2.0.conv2.bias False
layer_2.0.bn2.weight False
layer_2.0.bn2.bias False
layer_2.0.shortcut.0.weight False
layer_2.0.shortcut.0.bias False
layer_2.0.shortcut.1.weight False
layer_2.0.shortcut.1.bias False
layer_2.1.conv1.weight False
layer_2.1.conv1.bias False
layer_2.1.bn1.weight False
layer_2.1.bn1.bias False
layer_2.1.conv2.weight False
layer_2.1.conv2.bias False
layer_2.1.bn2.weight False
layer_2.1.bn2.bias False
layer_3.0.conv1.weight False
layer_3.0.conv1.bias False
layer_3.0.bn1.weight False
layer_3.0.bn1.bias False
layer_3.0.conv2.weight False
layer_3.0.conv2.bias False
layer_3.0.bn2.weight False
layer_3.0.bn2.bias False
layer_3.0.shortcut.0.weight False
layer_3.0.shortcut.0.bias False
layer_3.0.shortcut.1.weight False
layer_3.0.shortcut.1.bias False
layer_3.1.conv1.weight False
layer_3.1.conv1.bias False
layer_3.1.bn1.weight False
layer_3.1.bn1.bias False
layer_3.1.conv2.weight False
layer_3.1.conv2.bias False
layer_3.1.bn2.weight False
layer_3.1.bn2.bias False
layer_4.0.conv1.weight False
layer_4.0.conv1.bias False
layer_4.0.bn1.weight False
layer_4.0.bn1.bias False
layer_4.0.conv2.weight False
layer_4.0.conv2.bias False
layer_4.0.bn2.weight False
layer_4.0.bn2.bias False
layer_4.0.shortcut.0.weight False
layer_4.0.shortcut.0.bias False
layer_4.0.shortcut.1.weight False
layer_4.0.shortcut.1.bias False
layer_4.1.conv1.weight False
layer_4.1.conv1.bias False
layer_4.1.bn1.weight False
layer_4.1.bn1.bias False
layer_4.1.conv2.weight False
layer_4.1.conv2.bias False
layer_4.1.bn2.weight False
layer_4.1.bn2.bias False
post_conv.weight True
post_conv.bias True
classifier.0.weight True
classifier.0.bias True
classifier.3.weight True
classifier.3.bias True

However, the weights and biases were not updated at all during training, and of course the training and validation accuracy were both at chance level (25% for predicting one out of four labels).

Could anyone have a look at what might be the issue? This is the code I used for training:

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from timeit import default_timer as timer

def train(model,
          train_on_gpu,
          criterion,
          optimizer,
          train_loader,
          valid_loader,
          save_file_name,
          max_epochs_stop=3,
          n_epochs=100,
          print_every=2):
    """Train a PyTorch Model

    Params
    --------
        model (PyTorch model): cnn to train
        criterion (PyTorch loss): objective to minimize
        optimizer (PyTorch optimizer): optimizer to compute gradients of model parameters
        train_loader (PyTorch dataloader): training dataloader to iterate through
        valid_loader (PyTorch dataloader): validation dataloader used for early stopping
        save_file_name (str ending in '.pt'): file path to save the model state dict
        max_epochs_stop (int): maximum number of epochs with no improvement in validation loss for early stopping
        n_epochs (int): maximum number of training epochs
        print_every (int): frequency of epochs to print training stats

    Returns
    --------
        model (PyTorch model): trained cnn with best weights
        history (DataFrame): history of train and validation loss and accuracy
    """

    # Early stopping initialization
    epochs_no_improve = 0
    valid_loss_min = np.inf

    valid_max_acc = 0
    history = []

    # Number of epochs already trained (if using loaded-in model weights)
    try:
        print(f'Model has been trained for: {model.epochs} epochs.\n')
    except AttributeError:
        model.epochs = 0
        print('Starting Training from Scratch.\n')

    overall_start = timer()

    # Main loop
    for epoch in range(n_epochs):

        # keep track of training and validation loss each epoch
        train_loss = 0.0
        valid_loss = 0.0

        train_acc = 0
        valid_acc = 0

        # Set to training
        model.train()
        start = timer()

        w_previous = 1
        w = 1

        b_previous = 1
        b = 1

        # Training loop
        for ii, (data, target) in enumerate(train_loader):
            # Tensors to gpu
            if train_on_gpu:
                data, target = data.to(device = 'cuda'), target.to(device = 'cuda')
            else:
                data, target = data.to(device = 'cpu'), target.to(device = 'cpu')
            
            data = data.float()

            # Clear gradients
            optimizer.zero_grad()
            # Predicted outputs are log probabilities
            output = model(data)

            # Loss and backpropagation of gradients
            loss = criterion(output, target)
            loss.backward()

            #grad = model.preBlock[0].weight.grad
            #print(grad)

            # Update the parameters
            optimizer.step()

            w = model.post_conv._parameters['weight'].detach()
            #w = model.preBlock[0]._parameters['weight'].detach()
            print(w-w_previous)
            w_previous = w

            b = model.post_conv._parameters["bias"].detach()
            print(b-b_previous)
            b_previous = b


            # Track train loss by multiplying average loss by number of examples in batch
            train_loss += loss.item() * data.size(0)

            # Calculate accuracy by finding max log probability
            _, pred = torch.max(output, dim=1)
            correct_tensor = pred.eq(target.data.view_as(pred))
            # Need to convert correct tensor from int to float to average
            accuracy = torch.mean(correct_tensor.type(torch.FloatTensor))
            # Multiply average accuracy times the number of examples in batch
            train_acc += accuracy.item() * data.size(0)

            # Track training progress
            print(
                f'Epoch: {epoch}\t{100 * (ii + 1) / len(train_loader):.2f}% complete. {timer() - start:.2f} seconds elapsed in epoch.',
                end='\r')

        # After the training loop finishes without a break (for-else), run validation
        else:
            model.epochs += 1

            # Don't need to keep track of gradients
            with torch.no_grad():
                # Set to evaluation mode
                model.eval()

                # Validation loop
                for data, target in valid_loader:
                    # Tensors to gpu
                    if train_on_gpu:
                        data, target = data.to(device = 'cuda'), target.to(device = 'cuda')
                    else:
                        data, target = data.to(device = 'cpu'), target.to(device = 'cpu')
                    
                    data = data.float()

                    # Forward pass
                    output = model(data)

                    # Validation loss
                    loss = criterion(output, target)
                    # Multiply average loss times the number of examples in batch
                    valid_loss += loss.item() * data.size(0)

                    # Calculate validation accuracy
                    _, pred = torch.max(output, dim=1)
                    correct_tensor = pred.eq(target.data.view_as(pred))
                    accuracy = torch.mean(
                        correct_tensor.type(torch.FloatTensor))
                    # Multiply average accuracy times the number of examples
                    valid_acc += accuracy.item() * data.size(0)

                # Calculate average losses
                train_loss = train_loss / len(train_loader.dataset)
                valid_loss = valid_loss / len(valid_loader.dataset)

                # Calculate average accuracy
                train_acc = train_acc / len(train_loader.dataset)
                valid_acc = valid_acc / len(valid_loader.dataset)

                history.append([train_loss, valid_loss, train_acc, valid_acc])

                # Print training and validation results
                if (epoch + 1) % print_every == 0:
                    print(
                        f'\nEpoch: {epoch} \tTraining Loss: {train_loss:.4f} \tValidation Loss: {valid_loss:.4f}'
                    )
                    print(
                        f'\t\tTraining Accuracy: {100 * train_acc:.2f}%\t Validation Accuracy: {100 * valid_acc:.2f}%'
                    )

                # Save the model if validation loss decreases
                if valid_loss < valid_loss_min:
                    # Save model
                    torch.save(model.state_dict(), save_file_name)
                    # Track improvement
                    epochs_no_improve = 0
                    valid_loss_min = valid_loss
                    valid_best_acc = valid_acc
                    best_epoch = epoch

                # Otherwise increment count of epochs with no improvement
                else:
                    epochs_no_improve += 1
                    # Trigger early stopping
                    if epochs_no_improve >= max_epochs_stop:
                        print(
                            f'\nEarly Stopping! Total epochs: {epoch}. Best epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
                        )
                        total_time = timer() - overall_start
                        print(
                            f'{total_time:.2f} total seconds elapsed. {total_time / (epoch+1):.2f} seconds per epoch.'
                        )

                        # Load the best state dict
                        model.load_state_dict(torch.load(save_file_name))
                        # Attach the optimizer
                        model.optimizer = optimizer

                        # Format history
                        history = pd.DataFrame(
                            history,
                            columns=[
                                'train_loss', 'valid_loss', 'train_acc',
                                'valid_acc'
                            ])
                        return model, history

    # Attach the optimizer
    model.optimizer = optimizer
    # Record overall time and print out stats
    total_time = timer() - overall_start
    print(
        f'\nBest epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
    )
    print(
        f'{total_time:.2f} total seconds elapsed. {total_time / (epoch + 1):.2f} seconds per epoch.'
    )
    # Format history
    history = pd.DataFrame(
        history,
        columns=['train_loss', 'valid_loss', 'train_acc', 'valid_acc'])
    return model, history

In the code above, I used the following to track the difference in the weights and biases after each batch compared to the previous one, and it printed tensors of all 0’s (for example, for the last conv layer):

w = model.post_conv._parameters['weight'].detach()
print(w-w_previous)
w_previous = w

I also tried unfreezing all layers and training, but the weights and biases were not updated even in the first conv layer, model.preBlock[0].

Thank you so much for your help!

Your model works fine for me using:

model = DeepBrain()
criterion = nn.CrossEntropyLoss()

for param in model.parameters():
    param.requires_grad = False

n_inputs = 64
model.post_conv = nn.Conv3d(128, n_inputs, kernel_size=(5, 6, 5)) 

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

h = 81
data = torch.randn(1, 27, h, h, h)

# Clear gradients
optimizer.zero_grad()
# Predicted outputs are log probabilities
output = model(data)

target = torch.randint(0, 7, (1,))
# Loss and backpropagation of gradients
loss = criterion(output, target)
loss.backward()

print(sum(p.grad.abs().sum() for p in model.parameters() if p.requires_grad))
# tensor(1184.2426)

w0 = model.post_conv.weight.clone()

# Update the parameters
optimizer.step()

w1 = model.post_conv.weight.clone()

print((w1 - w0).abs().sum())
# tensor(1190.6450, grad_fn=<SumBackward0>)

and the trainable parameters are updated.
I needed to change the in_features of the first linear layer in self.classifier as you didn’t post the input shape.
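
As a side note, the all-zero differences in your tracking code are most likely an artifact of the check itself rather than of training: .detach() returns a tensor that shares storage with the parameter, so after optimizer.step() both w and w_previous already reflect the updated values, and their difference is always zero. Taking a real snapshot requires .clone(). A minimal sketch:

# .detach() alone shares memory with the parameter, so the "old" copy
# silently changes together with the parameter after optimizer.step().
w_previous = model.post_conv.weight.detach().clone()
optimizer.step()
print((model.post_conv.weight.detach() - w_previous).abs().sum())  # non-zero now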

Indeed, it seems the weights were updated when checking with .clone(), thanks a lot!!