After transfer learning, the model restarts from zero


(Nicolò Savioli) #1

Sorry if I am raising this question again, but I cannot tell whether this behaviour is caused by my code or by the network having to adapt.

I have two models and I have to transfer part of their weights to a new model.
For example, let’s say that from model A I want to take only the convolutional layers, while from model B I take only the FC layers.

Well, I apply the weight update like this:

model = model_in_code_from_autograd().cuda()
pretrain_model = torch.load("path/…/model.pt").cuda()
model_dict = model.state_dict()
pretrained_dict = pretrain_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

To check whether this code works, I do the following:

A. I read all the layers of the empty model that will receive the weights, and I compute the average of the per-layer means.
B. I transfer the weights from the trained model into the empty one and check the average.
C. I do the same check after updating the model.

After the transfer, the average of the new model is close to that of the pre-trained model, so the weights have been modified:

… Mean layers pretrain model is : 0.117181614

… Mean layers new-model is : 0.16264322

… Mean layer new-model after update is: 0.11407055

Here is the code:

# Check the mean of the parameters of my pre-trained model
params_model       = model_v_fcnn.named_parameters() 
list_mean_layer_a = []
for name_p, param_p in params_model:
    mean = torch.mean(model_v_fcnn.state_dict()[name_p]).data.cpu().numpy()
    list_mean_layer_a.append(mean)

print("\n ... Mean layers pretrain model is : "+str(np.mean(list_mean_layer_a)))

# Check the mean of the parameters of my new model
params_model       = model.named_parameters()
list_mean_layer_b = []
for name_p, param_p in params_model:
    mean = torch.mean(model.state_dict()[name_p]).data.cpu().numpy()
    list_mean_layer_b.append(mean)

print("\n ... Mean layers  new-model is : "+str(np.mean(list_mean_layer_b)))

# Update the first part of the model 
model_dict       = model.state_dict() 
pretrained_dict  = model_v_fcnn.state_dict()
# filter the model within a specific key
filter_model     = {}
for k, v in pretrained_dict.items():
    if k.split(".")[0][0] != "r":
        filter_model[k] = v
# 2. overwrite entries in the existing state dict  
model_dict.update(filter_model) 
# 3. load the new state dict
model.load_state_dict(model_dict)

# Re-Update the second part of the model 
model_dict       = model.state_dict() 
pretrained_dict  = model_rv_fcnn.state_dict()
# filter the model within a specific key
filter_model     = {}
for k, v in pretrained_dict.items():
    if k.split(".")[0][0] == "r":
        filter_model[k] = v
# 2. overwrite entries in the existing state dict  
model_dict.update(filter_model) 
# 3. load the new state dict
model.load_state_dict(model_dict)

# Check whether the mean of the new model has changed
params_model       = model.named_parameters() 
list_mean_layer_c  = []
for name_p, param_p in params_model:
    mean = torch.mean(model.state_dict()[name_p]).data.cpu().numpy()
    list_mean_layer_c.append(mean)

print("\n ... Mean layer new-model after update is: "+str(np.mean(list_mean_layer_c)))
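
Comparing averages only shows that something changed. A stricter check is to compare the copied tensors key by key with `torch.equal`. The following is a minimal sketch of that idea, not the code from this post: `TinyNet`, `copy_matching` and `mismatched_keys` are hypothetical names, and the tiny model only assumes (as in the filter above) that the second part's keys start with "r".

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in: 'conv' is the first part, 'r_fc' the part whose keys start with 'r'."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)
        self.r_fc = nn.Linear(2, 2)


def copy_matching(target, source, keep):
    """Copy the entries of source.state_dict() whose key satisfies keep(key) into target."""
    target_dict = target.state_dict()
    picked = {k: v for k, v in source.state_dict().items()
              if keep(k) and k in target_dict}
    target_dict.update(picked)
    target.load_state_dict(target_dict)
    return sorted(picked)


def mismatched_keys(target, source, keys):
    """Return the keys whose tensors differ between target and source."""
    t, s = target.state_dict(), source.state_dict()
    return [k for k in keys if not torch.equal(t[k], s[k])]


torch.manual_seed(0)
model, model_a, model_b = TinyNet(), TinyNet(), TinyNet()

# First part from model_a, second part (keys starting with "r") from model_b.
keys_a = copy_matching(model, model_a, lambda k: not k.startswith("r"))
keys_b = copy_matching(model, model_b, lambda k: k.startswith("r"))

bad_a = mismatched_keys(model, model_a, keys_a)
bad_b = mismatched_keys(model, model_b, keys_b)
print("copied from A:", keys_a, "mismatches:", bad_a)
print("copied from B:", keys_b, "mismatches:", bad_b)
```

Empty mismatch lists mean every selected tensor was copied exactly, which a matching average cannot guarantee.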

However, when I train the model again, it seems to start from scratch (as if the weights were random):

A. Is this normal behaviour?
B. Do the weights have to be readjusted?
C. Is there an error in the way I update the weights?

Best,

Nico


#2

I’ve checked your update code and it seems to work:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1, bias=False)
        self.fc1 = nn.Linear(6*24*24, 5, bias=False)
        
    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def set_weight(self, scalar):
        self.conv1.weight.data.fill_(scalar)
        self.fc1.weight.data.fill_(scalar)

    def print_weights(self):
        print('Conv weight: {}\nLinear weight: {}'.format(
            self.conv1.weight, self.fc1.weight))
    
modelA = MyModel()
modelA.set_weight(0.)
modelA.print_weights()

modelB = MyModel() # First "pre-trained" model
modelB.set_weight(1.)
modelB.print_weights()

modelC = MyModel() # Second "pre-trained" model
modelC.set_weight(2.)
modelC.print_weights()

# Update conv part of model
model_dictA = modelA.state_dict()
model_dictB = modelB.state_dict()

# Filter
filter_state_dict = {}
for k, v in model_dictB.items():
    if 'conv' in k:
        filter_state_dict[k] = v

# Overwrite entries in existing model_dictA
model_dictA.update(filter_state_dict)
print(model_dictA)

# Load new state dict
modelA.load_state_dict(model_dictA)
modelA.print_weights()

# Re-update the second part
model_dictA = modelA.state_dict()
model_dictC = modelC.state_dict()

# Filter
filter_state_dict = {}
for k, v in model_dictC.items():
    if 'fc' in k:
        filter_state_dict[k] = v

# Overwrite entries in existing model_dictA
model_dictA.update(filter_state_dict)

# Load new state dict
modelA.load_state_dict(model_dictA)
modelA.print_weights()

# Run forward pass
x = torch.ones(1, 3, 24, 24)
output = modelA(x)

I cannot see any differences.
Could you explain a bit what you mean by “it seems to start from scratch”?
Is the performance bad?
You are copying the weights from two different models, so I would assume the new model needs some training, because the different parts weren’t trained together.


(Nicolò Savioli) #3

Hey,

Thanks for your kind reply. I thought about this problem all day, and you kindly confirmed that it is not a code error!

My network uses 3D kernels, so I probably have many parameters to optimise, and the learning rate plays a fundamental role. I analysed the behaviour of the gradient after transfer learning on two different parts of the model, using SGD with momentum = 0.9, weight_decay = 1e-3, nesterov = True.

I see the following behaviour as the learning rate varies:

lr = 1e-3: the model does not optimise
lr = 1e-2: the accuracy begins to rise

With lr = 1e-2 after transfer learning I get the following values (where the maximum is 100):

epochs

0.0
18.02626847989203
19.78656745645463
14.658616073090036
17.290238624269023
18.621533335401455
17.22693140700638
27.07915645534009
13.420461302124785

As you can see, by the second epoch the network already reaches about 18% accuracy, but with lr = 1e-3 I get the following behaviour:

0.0
0.0
0.0
0.0
0.0

This is comparable to the network without any transferred weights.

So a question naturally arises: is there a correlation between the speed at which the parameters are updated and how well the network re-adapts?


#4

I assume the numbers represent the accuracy?
Do you have a classification use case? If so, it’s a bit strange to see 0% accuracy, because that would be worse than random.

I’ve never heard of an approach to combine two pre-trained networks together in different parts of the model, but it sounds interesting.
Have you tried to fine-tune the last part first, while freezing the first part? Then maybe after a while you could train the model end-to-end.
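
The freeze-then-fine-tune idea could look roughly like the sketch below. It is only an illustration under stated assumptions: `TinyNet` and `set_trainable` are hypothetical, the input size and the cross-entropy loss are placeholders, and the SGD settings are simply the ones mentioned earlier in this thread.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-part model: 'conv1' is the first part, 'fc1' the last part."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.fc1 = nn.Linear(6 * 8 * 8, 5)

    def forward(self, x):
        x = self.conv1(x)
        return self.fc1(x.view(x.size(0), -1))


def set_trainable(module, flag):
    """Enable or disable gradient computation for all parameters of a module."""
    for p in module.parameters():
        p.requires_grad = flag


torch.manual_seed(0)
model = TinyNet()

# Stage 1: freeze the first part and optimise only the remaining parameters.
set_trainable(model.conv1, False)
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-2, momentum=0.9, weight_decay=1e-3, nesterov=True)

conv_before = model.conv1.weight.detach().clone()
fc_before = model.fc1.weight.detach().clone()

# One training step on random placeholder data.
x = torch.randn(4, 3, 8, 8)
target = torch.randint(0, 5, (4,))
loss = nn.functional.cross_entropy(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

conv_unchanged = torch.equal(model.conv1.weight, conv_before)
fc_changed = not torch.equal(model.fc1.weight, fc_before)
print("conv unchanged:", conv_unchanged, "| fc changed:", fc_changed)

# Stage 2: unfreeze everything and build a new optimizer for end-to-end training.
set_trainable(model.conv1, True)
optimizer = torch.optim.SGD(
    model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-3, nesterov=True)
```

Building a fresh optimizer in stage 2 is one simple way to make sure the previously frozen parameters are actually updated.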


(Nicolò Savioli) #5

Yes, accuracy. My network is a semantic segmentation network rather than a classification one, but I think transfer learning is key to better generalisation.

The problem is why the lr is so sensitive… I need to investigate a bit more :slight_smile:


#6

But how come your accuracy is zero the whole time when training with lr=1e-3?
Even if your model predicts all zeros (e.g. background), there should be some background in your targets, so that the accuracy would be higher. Am I missing something in your use case?
Maybe there is another issue regarding the training / accuracy calculation?


(Nicolò Savioli) #7

I cannot explain it. If I use an adaptive method like Adam it works; with SGD I need to stay under 1e-3.
The metric that I use in this project is the Dice index https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient
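
For binary masks the Dice index is 2·|A ∩ B| / (|A| + |B|), so a prediction that is all background scores 0 even though most pixels are classified correctly. A small sketch of this, assuming binary masks and a hypothetical helper `dice_score` (the project's real metric code is not shown in this thread):

```python
import torch


def dice_score(pred, target):
    """Dice = 2 * |pred AND target| / (|pred| + |target|) for binary masks."""
    pred = pred.bool()
    target = target.bool()
    intersection = (pred & target).sum().item()
    total = pred.sum().item() + target.sum().item()
    if total == 0:
        # Assumed convention: two empty masks count as a perfect match.
        return 1.0
    return 2.0 * intersection / total


target = torch.zeros(4, 4, dtype=torch.bool)
target[1:3, 1:3] = True              # 4 foreground pixels

all_background = torch.zeros(4, 4, dtype=torch.bool)
half_overlap = torch.zeros(4, 4, dtype=torch.bool)
half_overlap[1:3, 1:2] = True        # 2 foreground pixels, both inside the target

# Plain pixel accuracy of the all-background prediction: 12 of 16 pixels match.
pixel_accuracy = (all_background == target).float().mean().item()

print(dice_score(all_background, target))  # 0.0
print(dice_score(half_overlap, target))    # 2*2 / (2+4) = 0.666...
print(dice_score(target, target))          # 1.0
print(pixel_accuracy)                      # 0.75
```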


#8

Ok, that makes sense. Thanks for the clarification :wink: