EDIT: This slowdown only seems to happen on CPU.
Hi,
Is there a quick way to freeze and unfreeze the weights of a network?
Currently I have these two functions to freeze and unfreeze the weights:
def freeze_model(model):
    model.eval()                      # also switches dropout/batchnorm to eval mode
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    model.train()                     # back to train mode
    for param in model.parameters():
        param.requires_grad = True
I use them in my training code below:
# fix the adversary; take a gradient step with the classifier and encoder
unfreeze_model(encoder)
unfreeze_model(classifier)
freeze_model(adversary)
z = encoder(x_batch)
y_hat = classifier(z)
a_fixed = adversary(z)
opt_cls.zero_grad()
opt_en.zero_grad()
cls_loss = cls_criterion(y_hat, y_cls_batch)
adv_loss_fixed = adv_criterion(a_fixed, y_adv_batch, y_cls_batch)
cls_en_combinedLoss = cls_loss + adv_loss_fixed
cls_en_combinedLoss.backward()
opt_cls.step()
opt_en.step()
# fix the encoder and classifier; take a gradient step with the adversary
freeze_model(encoder)
freeze_model(classifier)
unfreeze_model(adversary)  # adversary.train() alone would leave requires_grad False and backward() would fail
z_fixed = encoder(x_batch)
y_hat_fixed = classifier(z_fixed)
a_hat = adversary(z_fixed)
opt_adv.zero_grad()
cls_loss_fixed = cls_criterion(y_hat_fixed, y_cls_batch)
adv_loss = adv_criterion(a_hat, y_adv_batch, y_cls_batch)
adv_combinedLoss = -(cls_loss_fixed + adv_loss)
adv_combinedLoss.backward()
opt_adv.step()
However, all this freezing and unfreezing on every mini-batch makes training noticeably slower. Any thoughts? Thank you!