Train NN by freezing last n layers


(Matthias Büchi) #1

Hi all

I’m currently trying to train a NN where a second network is used as a kind of cost function. To do that, I chained the two networks and back-propagate the errors of the second network to the first one.

[Image: sc_combine]

The second net should not update its weights, so I have set requires_grad to False and pass only the parameters of the first one to the optimizer. I also set the second network to evaluation mode so its batchnorm layers use their stored running statistics instead of computing batch statistics.

There are only linear layers with ReLU, Batchnorm and BCE loss.

Now I get a strange loss curve over a few epochs: it increases immediately and then converges to some “high” value. Did I miss something, or what could be the problem?

Thanks for any help and best regards
Matthias


(Simon Wang) #2

What you described should just work. Do you mind posting the script?
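
For reference, a minimal sketch of the setup described above, using tiny stand-in networks (`first` and `second` are hypothetical names and sizes, not the poster's models). It assumes a standard PyTorch install and only checks that gradients reach the first network while the frozen second network's weights stay untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the two networks (hypothetical sizes, not the original models)
first = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.BatchNorm1d(8))
second = nn.Sequential(nn.Linear(8, 4), nn.Sigmoid())

# Freeze the second network: no gradients for its weights, eval mode for its layers
for p in second.parameters():
    p.requires_grad = False
second.eval()

# Only the first network's parameters go to the optimizer
optimizer = torch.optim.Adam(first.parameters(), lr=1e-3)
loss_func = nn.BCELoss()

x = torch.randn(16, 8)
target = torch.rand(16, 4)  # BCELoss expects targets in [0, 1]

frozen_before = [p.clone() for p in second.parameters()]
trained_before = [p.clone() for p in first.parameters()]

optimizer.zero_grad()
loss = loss_func(second(first(x)), target)
loss.backward()
optimizer.step()

# Gradients flow through the frozen network into the first one
first_has_grads = all(p.grad is not None for p in first.parameters())
second_has_grads = any(p.grad is not None for p in second.parameters())
second_unchanged = all(
    torch.equal(a, b) for a, b in zip(frozen_before, second.parameters())
)
first_changed = any(
    not torch.equal(a, b) for a, b in zip(trained_before, first.parameters())
)
```

If all four flags come out as expected on the real models too, the freezing itself is probably not the cause of the loss behaviour.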


(Matthias Büchi) #3

I cannot post the full code, but I tried to extract the most important parts.

First net:

import torch.nn as nn


class Transformer(nn.Module):

    def __init__(self):
        super(Transformer, self).__init__()

        self.net = nn.Sequential()

        self.net.add_module('linear_0', nn.Linear(1152, 460))
        self.net.add_module('activation_0', nn.ReLU())
        self.net.add_module('bnorm_0', nn.BatchNorm1d(460))

        self.net.add_module('linear_1', nn.Linear(460, 115))
        self.net.add_module('activation_1', nn.ReLU())
        self.net.add_module('bnorm_1', nn.BatchNorm1d(115))

        self.net.add_module('linear_2', nn.Linear(115, 460))
        self.net.add_module('activation_2', nn.ReLU())
        self.net.add_module('bnorm_2', nn.BatchNorm1d(460))

        self.net.add_module('linear_3', nn.Linear(460, 1152))
        self.net.add_module('activation_3', nn.ReLU())
        self.net.add_module('bnorm_3', nn.BatchNorm1d(1152))

    def forward(self, x):
        return self.net(x)

Second net:

class Classifier(nn.Module):

    def __init__(self):
        super(Classifier, self).__init__()

        self.net = nn.Sequential()

        self.net.add_module('linear_0', nn.Linear(1152, 2048))
        self.net.add_module('activation_0', nn.ReLU())
        self.net.add_module('bnorm_0', nn.BatchNorm1d(2048))

        self.net.add_module('linear_1', nn.Linear(2048, 2048))
        self.net.add_module('activation_1', nn.ReLU())
        self.net.add_module('bnorm_1', nn.BatchNorm1d(2048))

        self.net.add_module('linear_2', nn.Linear(2048, 2048))
        self.net.add_module('activation_2', nn.ReLU())
        self.net.add_module('bnorm_2', nn.BatchNorm1d(2048))

        self.net.add_module('linear_3', nn.Linear(2048, 2048))
        self.net.add_module('activation_3', nn.ReLU())
        self.net.add_module('bnorm_3', nn.BatchNorm1d(2048))

        self.net.add_module('linear_4', nn.Linear(2048, 2048))
        self.net.add_module('activation_4', nn.ReLU())
        self.net.add_module('bnorm_4', nn.BatchNorm1d(2048))

        self.net.add_module('linear_5', nn.Linear(2048, 2048))
        self.net.add_module('activation_5', nn.ReLU())
        self.net.add_module('bnorm_5', nn.BatchNorm1d(2048))

        self.net.add_module('linear_6', nn.Linear(2048, 2040))
        # pass dim explicitly; the implicit choice is deprecated and emits a warning
        self.net.add_module('softmax', nn.Softmax(dim=1))

    def forward(self, x):
        return self.net(x)

Chaining both:

class Chain(nn.Module):
    def __init__(self, transformer_model, classification_model):
        super(Chain, self).__init__()

        self.transformer = transformer_model
        self.classifier = classification_model

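    def forward(self, data):
        # compute transformer output (call the module, not .forward(), so hooks run)
        output = self.transformer(data)

        # compute classifier output
        output = self.classifier(output)

        return output

    def train(self, mode=True):
        # only the transformer follows the requested mode;
        # the frozen classifier always stays in evaluation mode
        self.training = mode

        self.transformer.train(mode=mode)
        self.classifier.train(mode=False)

        # nn.Module.train() returns the module, so keep that contract
        return self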
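
The effect of overriding train() this way can be checked on a small example. The sketch below uses a hypothetical `FrozenTailChain` class with tiny stand-in networks (not the original models) and verifies that a forward pass in training mode updates the BatchNorm running statistics of the first part only.

```python
import torch
import torch.nn as nn


class FrozenTailChain(nn.Module):
    """Hypothetical stand-in for the Chain class: `tail` always stays in eval mode."""

    def __init__(self, head, tail):
        super().__init__()
        self.head = head
        self.tail = tail

    def forward(self, x):
        return self.tail(self.head(x))

    def train(self, mode=True):
        self.training = mode
        self.head.train(mode)
        self.tail.train(False)  # frozen part ignores the requested mode
        return self


head = nn.Sequential(nn.Linear(6, 6), nn.BatchNorm1d(6))
tail = nn.Sequential(nn.Linear(6, 3), nn.BatchNorm1d(3))
chain = FrozenTailChain(head, tail)

chain.train()  # what a training loop typically calls at the start of each epoch
tail_stats_before = tail[1].running_mean.clone()
head_stats_before = head[1].running_mean.clone()

chain(torch.randn(32, 6) + 5.0)  # one forward pass in "training" mode

# The frozen BatchNorm keeps its statistics; the trainable one updates them
tail_stats_unchanged = torch.equal(tail_stats_before, tail[1].running_mean)
head_stats_changed = not torch.equal(head_stats_before, head[1].running_mean)
```

Because eval() is implemented as train(False), the same override also covers calls to chain.eval().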

Now I create the models as follows; the second one is already trained and works as expected (standalone):

model = tr_fix.Transformer()
sc_model = sc_fix.Classifier()
sc_model.load_state_dict(torch.load(sc_model_state_path))
full_model = state_back_fix.Chain(model, sc_model)

Set requires_grad to false:

for param in full_model.classifier.parameters():
    param.requires_grad = False

Optim and loss:

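import torch
from torch import optim

loss_func = torch.nn.BCELoss()
# only the transformer's parameters are handed to the optimizer
optimizer = optim.Adam(full_model.transformer.parameters(), lr=0.001)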
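
A quick way to double-check this step is to inspect what the optimizer actually holds. The following is only a sketch with small hypothetical stand-ins (`transformer` and `classifier` here are plain linear layers, not the models above); it confirms that no frozen parameter ended up in the optimizer's parameter groups.

```python
import torch.nn as nn
from torch import optim

# Hypothetical small stand-ins for the transformer and the frozen classifier
transformer = nn.Linear(4, 4)
classifier = nn.Linear(4, 2)
for param in classifier.parameters():
    param.requires_grad = False

optimizer = optim.Adam(transformer.parameters(), lr=0.001)

# Collect every tensor the optimizer will update
optimized_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}

frozen_in_optimizer = any(id(p) in optimized_ids for p in classifier.parameters())
all_trainable_covered = all(id(p) in optimized_ids for p in transformer.parameters())
num_trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
```

The same three checks can be run against full_model.transformer and full_model.classifier.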

Then I train that thing using a library I wrote for convenience:

trainer = candle.Trainer(full_model, optimizer,
                         targets=[candle.Target(loss_func, loss_func)],
                         num_epochs=10,
                         use_cuda=True)

train_log = trainer.train(train_loader, dev_loader)
eval_log = trainer.evaluate(test_loader)

train_loader, dev_loader and test_loader are custom PyTorch dataloaders. The important part from the Trainer you can find here:

Thanks already for confirming that I am not totally wrong.

Furthermore, the loss I get looks like this:
[Image: bce]


(Aradhya Mathur) #4

Hi, I encountered a similar problem too. Did you find a workaround for this?