Quick way to freeze weights?

EDIT: This only seems to be happening on CPU.

Hi,

Is there a quick way to freeze and unfreeze the weights of a network?

Currently I have these two functions to freeze and unfreeze the weights:

def freeze_model(model):
    # switch to eval mode and stop autograd from tracking the parameters
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    # switch back to train mode and track the parameters again
    model.train()
    for param in model.parameters():
        param.requires_grad = True
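
(The same thing can be written more compactly with nn.Module's requires_grad_ method, though as far as I can tell it just runs the same parameter loop internally, so I wouldn't expect it to be any faster:)

def freeze_model(model):
    model.eval()
    model.requires_grad_(False)   # same effect as the loop above

def unfreeze_model(model):
    model.train()
    model.requires_grad_(True)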

I use them in my training code below:

# fix adversary, take gradient step with classifier and encoder
unfreeze_model(encoder)
unfreeze_model(classifier)
z = encoder(x_batch)
y_hat = classifier(z)

freeze_model(adversary)
a_fixed = adversary(z)

opt_cls.zero_grad()
opt_en.zero_grad()

cls_loss = cls_criterion(y_hat, y_cls_batch)
adv_loss_fixed = adv_criterion(a_fixed, y_adv_batch, y_cls_batch)
cls_en_combinedLoss = cls_loss + adv_loss_fixed
cls_en_combinedLoss.backward()
opt_cls.step()
opt_en.step()        

# fix encoder and classifier, take gradient step with adversary
freeze_model(encoder)
freeze_model(classifier)
z_fixed = encoder(x_batch)
y_hat_fixed = classifier(z_fixed)

adversary.train()
a_hat = adversary(z_fixed)

opt_adv.zero_grad()

cls_loss_fixed = cls_criterion(y_hat_fixed, y_cls_batch)
adv_loss = adv_criterion(a_hat, y_adv_batch, y_cls_batch)

adv_combinedLoss = -(cls_loss_fixed + adv_loss)
adv_combinedLoss.backward()

opt_adv.step()

However, all this freezing and unfreezing on every mini-batch slows training down. Any thoughts? Thank you!

Do you see an unusually long duration for freezing / unfreezing the models?
If freezing and unfreezing take more time than just calculating the gradients, you could remove them entirely and instead clear the gradients as appropriate for your use case.
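
Roughly like this (a minimal sketch reusing the names from your snippet, no new API involved). The key points: each optimizer only ever updates its own parameters, stray gradients on the other models are wiped by the zero_grad() calls, and running the fixed models under torch.no_grad() skips building their graph entirely. If any of your models use dropout or batchnorm, keep the train()/eval() switching, since that changes behavior rather than speed:

# classifier/encoder step: the adversary's parameters pick up gradients
# here, but opt_adv never steps on them and they are cleared below
z = encoder(x_batch)
y_hat = classifier(z)
a_fixed = adversary(z)

opt_cls.zero_grad()
opt_en.zero_grad()
cls_loss = cls_criterion(y_hat, y_cls_batch)
adv_loss_fixed = adv_criterion(a_fixed, y_adv_batch, y_cls_batch)
(cls_loss + adv_loss_fixed).backward()
opt_cls.step()
opt_en.step()

# adversary step: run the fixed models under no_grad so no graph is
# built for them at all, which is effectively "frozen", and cheaper
with torch.no_grad():
    z_fixed = encoder(x_batch)
    y_hat_fixed = classifier(z_fixed)
    cls_loss_fixed = cls_criterion(y_hat_fixed, y_cls_batch)

a_hat = adversary(z_fixed)

opt_adv.zero_grad()
adv_loss = adv_criterion(a_hat, y_adv_batch, y_cls_batch)
(-(cls_loss_fixed + adv_loss)).backward()
opt_adv.step()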

As a side note: Did you forget to unfreeze the adversary at the end of your code?
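If you stick with the freeze/unfreeze functions, the adversary step presumably needs something like this before its forward pass, since adversary.train() only switches the mode and leaves requires_grad=False on the parameters, so opt_adv.step() would have nothing to update:

unfreeze_model(adversary)  # restores requires_grad=True *and* train mode
a_hat = adversary(z_fixed)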