A Problem with Freezing the Neural Net Layers (Transfer Learning)

Hi,

It’s my first post, so please excuse me if I mess something up.

I am working on a problem that involves training a simple PyTorch neural network on one dataset, then freezing the network, adding an extra layer, and retraining it slightly on a second dataset. Training on the initial dataset works fine, but after freezing the network and adding the extra layer, it does not seem to train on the second dataset (the cost function does not change and the weights are not updated). Training works fine if the layers are not frozen, though, so I assume there is no major problem with the architecture of the network (?).
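Roughly, the pattern I have in mind is this (just a minimal sketch with made-up layer sizes and names, not the actual code below):

## Minimal sketch of the intended workflow (made-up sizes/names):
from torch import nn, optim

base = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 50))
## ... base would be trained on the first dataset here ...
for p in base.parameters():
    p.requires_grad = False          ## freeze the pretrained part
head = nn.Linear(50, 5)              ## the new trainable layer
model = nn.Sequential(base, head)    ## model used for the second dataset
optimizer = optim.SGD(head.parameters(), lr=0.002)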

Here’s a simplified version of the code.

Defining the network and the first stage of training:

## Importing the packages:
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.autograd import Variable

## Defining the network: 

class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network
    """
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_features = 100
        n_out = 50
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 128),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(128, 256),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(256, 512),
        )
        
        self.out = nn.Sequential(
            nn.Linear(512, n_out),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

net = GeneratorNet()
print(net)

## Some fake data to fit (the extra [ ] adds a leading dimension of size 1):

X = np.random.rand(10, 1, 100)
X = Variable(torch.FloatTensor([X]), requires_grad=True)   ## shape (1, 10, 1, 100)

Y = np.random.rand(10, 1, 50)
Y = Variable(torch.FloatTensor([Y]), requires_grad=True)   ## shape (1, 10, 1, 50)

## The second dataset: 

Y2 = np.random.rand(1, 5)
Y2 = Variable(torch.FloatTensor([Y2]), requires_grad=True)   ## shape (1, 1, 5)

## Cost function: 
cost = torch.nn.MSELoss(size_average=False)

## Optimizer: 
optimizer = optim.SGD(net.parameters(), lr=0.002)

Here is the first stage of training, where the network is fit to the “Y” dataset. This usually works fine, unless the learning rate is too high (in which case training becomes unstable).

## The first stage of training: 

## The last parameter tensor before training:
a = list(net.parameters())[-1].clone()

for epoch in range(100):
    optimizer.zero_grad()
    outputs = net(X)
    ## Fitting the network to the first dataset Y
    loss = cost(outputs, Y)
    print(epoch, loss)
    loss.backward(retain_graph=True)
    optimizer.step()
        
## Checking whether the last parameter tensor was updated:
b = list(net.parameters())[-1].clone()
torch.equal(a.data, b.data)   ## False means the weights changed
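For what it is worth, I also compared every parameter tensor (not just the last one) before and after training, with a check along these lines (just a sketch; snapshot is my own helper name, taken before running the loop above):

## Snapshot all parameters before the loop, then compare after it:
snapshot = [p.clone() for p in net.parameters()]
## ... run the training loop above ...
for i, (p_old, p_new) in enumerate(zip(snapshot, net.parameters())):
    print(i, "unchanged" if torch.equal(p_old, p_new) else "updated")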

Then I freeze the trained layers and add a new one, so that only the new layer should be trained in this stage. Something goes wrong here, because if I don’t freeze the layers, the training works fine.

## Freezing the weights and adding an extra layer: 
for param in net.parameters():
    param.requires_grad = False

## Adding a new layer: 

net.fc = nn.Sequential(
    nn.Linear(50, 50),
    nn.ReLU()
)

## Add the new (unfrozen) fc parameters to the existing optimizer:
optimizer.add_param_group({'params': net.fc.parameters()})
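To double-check the freezing, I also print which parameters are still trainable and what the optimizer ended up with (a quick sanity check using the objects defined above):

## Sanity check: which parameters are trainable, and what the optimizer sees
for name, param in net.named_parameters():
    print(name, 'requires_grad =', param.requires_grad)
print('optimizer param groups:', [len(g['params']) for g in optimizer.param_groups])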

Second stage of training: the cost does not decrease and the parameters are not updated:

## Training: stage 2
## The last parameter tensor (now the bias of the new fc layer) before training:
a = list(net.parameters())[-1].clone()

for epoch in range(100):
    optimizer.zero_grad()
    outputs = net(X)
    ## Sum the outputs over the 10 samples and normalize:
    outputs_sum = torch.sum(outputs, 1)/10                    ## shape (1, 1, 50)
    outputs_sum2 = torch.sum(outputs_sum.view((10, 5)), 0)    ## shape (5,)
    ## Now fit to the 2nd dataset
    loss = cost(outputs_sum2, Y2)
    print(epoch, loss)
    loss.backward(retain_graph=True)
    optimizer.step()
## The parameters do not update and the cost does not decrease:
b = list(net.parameters())[-1].clone()
torch.equal(a.data, b.data)   ## True means nothing was updated
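One more diagnostic I ran after the loop, in case it helps: checking which parameters actually received a gradient at all (the .grad attributes stay populated after the loop, so this can be run right here):

## Which parameters ended up with a gradient after backward()?
for name, param in net.named_parameters():
    print(name, 'grad is None:', param.grad is None)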

Any tips on what is going wrong would be highly appreciated.

Cheers!