Train NN by freezing last n layers

(Matthias Büchi) #1

Hi all

I’m currently trying to train a NN where a second network is used as a kind of cost function. So I chained the two networks and back-propagate the error of the second network into the first one.


The second net should not update its weights, so I have set requires_grad to False for its parameters and only pass the parameters of the first one to the optimizer. I also set the second network to evaluation mode so that its BatchNorm layers use their stored running statistics instead of computing new ones from each batch.

Both networks consist only of linear layers with ReLU and BatchNorm, and I use BCE loss.
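A minimal, self-contained sketch of the setup described above, assuming small stand-in networks (the names `first`/`second`, the layer sizes and the one-hot target are illustrative, not taken from the post). It freezes the second network with `requires_grad = False` plus `eval()`, gives only the first network's parameters to Adam, and checks that one optimizer step changes the first network while the second (weights and BatchNorm buffers) stays untouched.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks (much smaller than the real ones).
first = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.BatchNorm1d(8))
second = nn.Sequential(
    nn.Linear(8, 6), nn.ReLU(), nn.BatchNorm1d(6),
    nn.Linear(6, 3), nn.Softmax(dim=1),
)

# Freeze the second network: no gradients for its parameters, and eval mode
# so its BatchNorm layer uses (and does not update) its running statistics.
for p in second.parameters():
    p.requires_grad = False
second.eval()

# Only the first network's parameters are optimized.
optimizer = optim.Adam(first.parameters(), lr=1e-3)
loss_func = nn.BCELoss()

x = torch.randn(16, 8)
target = torch.zeros(16, 3)
target[:, 0] = 1.0  # illustrative one-hot targets

# Snapshot the frozen net (parameters and buffers) and the first net's parameters.
second_before = {k: v.clone() for k, v in second.state_dict().items()}
first_before = [p.detach().clone() for p in first.parameters()]

first.train()
optimizer.zero_grad()
loss = loss_func(second(first(x)), target)
loss.backward()  # the error flows through the frozen second net into the first
optimizer.step()

second_unchanged = all(torch.equal(v, second.state_dict()[k]) for k, v in second_before.items())
first_changed = any(not torch.equal(b, p) for b, p in zip(first_before, first.parameters()))
print(second_unchanged, first_changed)  # True True
```

If this pattern holds on a toy example, the freezing itself is working, and an increasing loss is more likely to come from elsewhere (data, targets, or module modes during training).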

Now the loss behaves strangely over a few epochs: it increases right away and then converges to a fairly “high” value. Did I miss something, or what could be the problem?

Thanks for any help and best regards

(Simon Wang) #2

What you described should just work. Do you mind posting the script?

(Matthias Büchi) #3

I cannot post the full code, but I tried to extract the most important parts.

First net:

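import torch.nn as nn

class Transformer(nn.Module):

    def __init__(self):
        super(Transformer, self).__init__()
        # attribute name `model` assumed
        self.model = nn.Sequential()
        self.model.add_module('linear_0', nn.Linear(1152, 460))
        self.model.add_module('activation_0', nn.ReLU())
        self.model.add_module('bnorm_0', nn.BatchNorm1d(460))
        self.model.add_module('linear_1', nn.Linear(460, 115))
        self.model.add_module('activation_1', nn.ReLU())
        self.model.add_module('bnorm_1', nn.BatchNorm1d(115))
        self.model.add_module('linear_2', nn.Linear(115, 460))
        self.model.add_module('activation_2', nn.ReLU())
        self.model.add_module('bnorm_2', nn.BatchNorm1d(460))
        self.model.add_module('linear_3', nn.Linear(460, 1152))
        self.model.add_module('activation_3', nn.ReLU())
        self.model.add_module('bnorm_3', nn.BatchNorm1d(1152))

    def forward(self, x):
        return self.model(x)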

Second net:

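import torch.nn as nn

class Classifier(nn.Module):

    def __init__(self):
        super(Classifier, self).__init__()
        # attribute name `model` assumed
        self.model = nn.Sequential()
        self.model.add_module('linear_0', nn.Linear(1152, 2048))
        self.model.add_module('activation_0', nn.ReLU())
        self.model.add_module('bnorm_0', nn.BatchNorm1d(2048))
        self.model.add_module('linear_1', nn.Linear(2048, 2048))
        self.model.add_module('activation_1', nn.ReLU())
        self.model.add_module('bnorm_1', nn.BatchNorm1d(2048))
        self.model.add_module('linear_2', nn.Linear(2048, 2048))
        self.model.add_module('activation_2', nn.ReLU())
        self.model.add_module('bnorm_2', nn.BatchNorm1d(2048))
        self.model.add_module('linear_3', nn.Linear(2048, 2048))
        self.model.add_module('activation_3', nn.ReLU())
        self.model.add_module('bnorm_3', nn.BatchNorm1d(2048))
        self.model.add_module('linear_4', nn.Linear(2048, 2048))
        self.model.add_module('activation_4', nn.ReLU())
        self.model.add_module('bnorm_4', nn.BatchNorm1d(2048))
        self.model.add_module('linear_5', nn.Linear(2048, 2048))
        self.model.add_module('activation_5', nn.ReLU())
        self.model.add_module('bnorm_5', nn.BatchNorm1d(2048))
        self.model.add_module('linear_6', nn.Linear(2048, 2040))
        # explicit dim: softmax over the class dimension of (batch, classes)
        self.model.add_module('softmax', nn.Softmax(dim=1))

    def forward(self, x):
        return self.model(x)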

Chaining both:

class Chain(nn.Module):
    def __init__(self, transformer_model, classification_model):
        super(Chain, self).__init__()

        self.transformer = transformer_model
        self.classifier = classification_model

    def forward(self, data):
        # compute transformer output
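        output = self.transformer(data)

        # compute classifier output
        output = self.classifier(output)

        return output

    def train(self, mode=True):
        # only sets this module's own flag; unlike the default nn.Module.train(),
        # it does not propagate the mode to the transformer or the classifier
        self.training = mode
        return self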
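A small sketch of why the module mode matters even with `requires_grad = False`: freezing only affects parameters, while BatchNorm's `running_mean`/`running_var` are buffers that get updated whenever the layer runs in train mode. It also shows that the default `nn.Module.train()` recurses into child modules, which is presumably what the `train` override in `Chain` is meant to avoid. The tensor sizes and names here are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

frozen_bn = nn.BatchNorm1d(4)
for p in frozen_bn.parameters():
    p.requires_grad = False  # freezes weight and bias, not the running-stat buffers

x = torch.randn(32, 4) + 5.0  # batch mean far from the initial running_mean of 0

# Train mode: running statistics are still updated from the batch.
frozen_bn.train()
before = frozen_bn.running_mean.clone()
frozen_bn(x)
changed_in_train = not torch.equal(before, frozen_bn.running_mean)

# Eval mode: the stored statistics are used and left untouched.
frozen_bn.eval()
before = frozen_bn.running_mean.clone()
frozen_bn(x)
changed_in_eval = not torch.equal(before, frozen_bn.running_mean)

# The default nn.Module.train() recurses into children, so calling .train() on
# a parent switches a previously eval()'d child back to train mode.
parent = nn.Sequential(nn.Linear(4, 4), frozen_bn)
parent.train()
child_back_in_train = frozen_bn.training

print(changed_in_train, changed_in_eval, child_back_in_train)  # True False True
```

So a frozen classifier that is accidentally put back into train mode (for example by a training loop calling `.train()` on the top-level model) keeps drifting its BatchNorm statistics even though none of its weights receive gradients.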


Now I create the models as follows; the second one is already trained and works as expected on its own:

model = tr_fix.Transformer()
sc_model = sc_fix.Classifier()
full_model = state_back_fix.Chain(model, sc_model)

Set requires_grad to False:

for param in full_model.classifier.parameters():
    param.requires_grad = False
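A quick sanity check for this step, as a sketch: `count_params` is a hypothetical helper, and the small `nn.Sequential` stands in for the classifier. `requires_grad` only exists on parameters; BatchNorm's running statistics are buffers and are listed separately, so this loop does not freeze them.

```python
import torch.nn as nn

def count_params(module):
    """Hypothetical helper: return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen

# Stand-in for the classifier: Linear(10, 5) has 55 parameters, BatchNorm1d(5) has 10.
classifier = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5))
for param in classifier.parameters():
    param.requires_grad = False

print(count_params(classifier))  # (0, 65)
# Buffers are not parameters, so requires_grad does not apply to them.
print([name for name, _ in classifier.named_buffers()])
# ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```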

Optimizer and loss:

loss_func = torch.nn.BCELoss()
optimizer = optim.Adam(full_model.transformer.parameters(), lr=0.001)
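As a sketch of what `nn.BCELoss` computes here (sizes and targets are illustrative): it treats every output as an independent probability and averages -[y·log(p) + (1-y)·log(1-p)] over all elements, so its inputs must already lie in [0, 1], which the classifier's final Softmax guarantees. The check compares it against the formula written out by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the classifier output: each row is a Softmax distribution.
probs = torch.softmax(torch.randn(4, 5), dim=1)

# Illustrative one-hot float targets with the same shape as the output.
target = torch.zeros(4, 5)
target[torch.arange(4), torch.tensor([0, 2, 1, 4])] = 1.0

loss_func = nn.BCELoss()
loss = loss_func(probs, target)

# The same quantity by hand: mean over all elements of the binary cross-entropy.
manual = -(target * probs.log() + (1 - target) * (1 - probs).log()).mean()
print(torch.allclose(loss, manual))  # True
```

Because BCE also penalizes every non-target output separately, it behaves differently from `nn.CrossEntropyLoss` on raw logits, the usual pairing for single-label classification; which one fits depends on the targets, which the post does not show.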

Then I train it using a library I wrote for convenience:

trainer = candle.Trainer(full_model, optimizer,
                                 targets=[candle.Target(loss_func, loss_func)],

train_log = trainer.train(train_loader, dev_loader)
eval_log = trainer.evaluate(test_loader)

train_loader, dev_loader and test_loader are custom PyTorch DataLoaders. The important part of the Trainer can be found here:

Thanks for confirming that I am not totally off track.

Furthermore, the loss I get looks like this:

(Aradhya Mathur) #4

Hi, I encountered a similar problem too. Did you find a workaround for this?