I have a neural network in PyTorch that has been trained on a dataset, and the model is saved as saved_model.pth.tar.
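A minimal sketch of loading saved_model.pth.tar, assuming it was written with `torch.save` as either a bare `state_dict` or a dict holding one under a `"state_dict"` key. That is a common convention for `.pth.tar` checkpoints, but the question does not confirm it. `TinyNet` and `read_checkpoint_state` are hypothetical names, and the demo saves and reloads a file in a temporary directory so the block runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real (more complex) network."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


def read_checkpoint_state(path: str) -> dict:
    """Return the state_dict stored in a checkpoint file.

    Assumes the file holds either a bare state_dict or a dict with a
    "state_dict" key (a common convention, not guaranteed).
    """
    checkpoint = torch.load(path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


# Round-trip demo: save a model in the assumed format, then load it back.
model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "saved_model.pth.tar")
torch.save({"state_dict": model.state_dict()}, path)

restored = TinyNet()
restored.load_state_dict(read_checkpoint_state(path))
```

The same helper also accepts a file saved with `torch.save(model.state_dict(), path)`, because it falls back to returning the loaded object unchanged.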
The current loss of the model is a simple MSE loss: loss = torch.nn.MSELoss()(predictions, target).
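Note that `torch.nn.MSELoss` is a class, not a function: it is constructed first and the resulting module is then called as `criterion(input, target)`. A short sketch with made-up tensors, also showing the equivalent functional form `torch.nn.functional.mse_loss`:

```python
import torch
import torch.nn as nn

# nn.MSELoss is a module: construct it once, then call it on tensors.
criterion = nn.MSELoss()  # default reduction="mean"

predictions = torch.tensor([1.0, 2.0])
target = torch.tensor([0.0, 0.0])

# Argument order is (input, target); MSE itself is symmetric in the two.
loss = criterion(predictions, target)  # mean of (1^2, 2^2) = 2.5

# The functional form gives the same value without constructing a module.
loss_functional = nn.functional.mse_loss(predictions, target)
```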
Now I am training the network on a new dataset, and I want to regularise it against the saved model (saved_model.pth.tar) so that the model is penalised if its weights move far away from the saved weights.
In essence, what I want to do is:
regularisation_loss = ((existing_weights - saved_weights) ** 2).sum()
loss = torch.nn.MSELoss()(predictions, target) + regularisation_loss
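The pseudocode above can be applied to every parameter at once by iterating over `model.named_parameters()`, which works the same way for any `nn.Module` however deeply nested. A minimal sketch follows; `snapshot_parameters` and `l2_penalty_to_reference` are hypothetical helper names. Because parameter names from `named_parameters()` match the keys of `state_dict()`, the reference can also be the state_dict loaded from the checkpoint. Extra keys such as BatchNorm buffers are simply skipped. One caveat: a checkpoint saved from an `nn.DataParallel` wrapper has keys prefixed with `module.`, and those would need stripping first.

```python
import torch
import torch.nn as nn


def snapshot_parameters(model: nn.Module) -> dict:
    """Return detached copies of the model's parameters, keyed by name."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def l2_penalty_to_reference(model: nn.Module, reference: dict) -> torch.Tensor:
    """Sum over all parameters of (current - reference) ** 2.

    Parameters whose name is missing from `reference` are skipped, so a full
    state_dict (which also contains buffers) can be passed in directly.
    """
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for name, p in model.named_parameters():
        if name in reference:
            penalty = penalty + ((p - reference[name]) ** 2).sum()
    return penalty


model = nn.Linear(3, 2)
reference = snapshot_parameters(model)
penalty_before = l2_penalty_to_reference(model, reference).item()  # 0.0

with torch.no_grad():
    model.weight.add_(1.0)  # move each of the 6 weights by 1

penalty_after = l2_penalty_to_reference(model, reference)  # about 6.0
penalty_after.backward()  # gradients flow back into the model's parameters
```

The reference tensors are detached copies, so gradients flow only into the current parameters and the saved weights stay fixed.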
The neural network is quite complex, so I would like to know whether there is a general way of enforcing regularisation against a saved model, rather than writing the penalty out layer by layer.
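One general approach, sometimes called L2-SP ("starting point") in the transfer-learning literature, is to add the name-matched penalty to the data loss inside an ordinary training loop. The optimizer's built-in `weight_decay` is not a substitute, because it pulls weights toward zero rather than toward the saved values. The sketch below is a toy comparison under stated assumptions. The `nn.Sequential` model, the random "new dataset", `reg_strength`, and the learning rate are all hypothetical. The model's initial weights play the role of the checkpoint that would normally come from saved_model.pth.tar.

```python
import torch
import torch.nn as nn


def reference_penalty(model: nn.Module, reference: dict) -> torch.Tensor:
    """Sum over parameters of (theta - theta_saved) ** 2, matched by name."""
    return sum(((p - reference[name]) ** 2).sum()
               for name, p in model.named_parameters() if name in reference)


def train(reg_strength: float, steps: int = 200, seed: int = 0) -> float:
    """Fine-tune a toy model and return how far it drifted from its start."""
    torch.manual_seed(seed)
    # Hypothetical stand-in for the real network.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    # In the real setting this would be the state_dict loaded from
    # saved_model.pth.tar; here the initial weights play that role.
    reference = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Toy "new dataset" of random inputs and targets (illustration only).
    x_new, y_new = torch.randn(64, 4), torch.randn(64, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

    for _ in range(steps):
        optimizer.zero_grad()
        data_loss = criterion(model(x_new), y_new)
        loss = data_loss + reg_strength * reference_penalty(model, reference)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        return float(reference_penalty(model, reference))


drift_plain = train(reg_strength=0.0)  # no pull toward the saved weights
drift_reg = train(reg_strength=1.0)    # hypothetical strength; tune in practice
```

A larger `reg_strength` keeps the fine-tuned weights closer to the saved ones, at the cost of fitting the new data less closely. In practice it would be chosen on held-out data from the new dataset.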