Randomised freezing of layers while training

Suppose my model has three convolutional layers, and on each forward pass during training I want to randomly train either all of them or only one of them.

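import torch
import torch.nn as nn

class myConv(nn.Module):
    # Constructor arguments are placeholders; the original post left them out.
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        z = torch.randint(0, 4, (1,)).item()  # random branch index in {0, 1, 2, 3}
        if z == 0:
            # Use (and train) all three layers.
            out = self.conv1(x) * self.conv2(x) * self.conv3(x)
        elif z == 1:
            out = self.conv1(x)
        elif z == 2:
            out = self.conv2(x)
        else:
            out = self.conv3(x)
        return out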

How do I correctly freeze the weights of the conv layers so that I can train this kind of layer?
I would prefer an approach that can be coded inside the forward pass rather than in the training loop, since ‘myConv’ is part of a much larger model.
Note that at test time the value of ‘z’ will be fixed, so I am only concerned with the training phase.

Any other suggestions are also welcome.
Thanks in advance.

In the branches where a layer is not used, you might not need to freeze its parameters, since Autograd won’t compute gradients for them.
Because these parameters were not used in the forward pass, they were never added to the computation graph and are thus irrelevant in the backward pass.
If you explicitly want to freeze these parameters nevertheless, setting their .requires_grad attribute to False is the right approach.
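
Here is a small check of that behaviour, as a sketch rather than anything specific to your model. The two layers, their sizes, and the input shape are made up for illustration. Only the layer that took part in the forward pass receives a gradient, while the other layer's .grad stays None.

```python
import torch
import torch.nn as nn

# Two hypothetical layers standing in for the branches of the module above.
used = nn.Conv2d(3, 4, kernel_size=3)
unused = nn.Conv2d(3, 4, kernel_size=3)

x = torch.randn(2, 3, 8, 8)

# Only `used` takes part in the forward pass, so only it enters the graph.
loss = used(x).sum()
loss.backward()

used_has_grad = used.weight.grad is not None
unused_has_grad = unused.weight.grad is not None
print(used_has_grad, unused_has_grad)  # True False
```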

Also, note that optimizers with running internal state (e.g. momentum) might still update frozen parameters even when their gradient is zero. To avoid this, delete their .grad attribute or call .zero_grad(set_to_none=True).
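
To make the difference concrete, here is a minimal sketch using SGD with momentum on a single made-up parameter (the parameter, learning rate, and momentum value are assumptions chosen for illustration). With a zeroed gradient the momentum buffer still moves the parameter; with a gradient of None the optimizer skips it.

```python
import torch

# A hypothetical single parameter, used only for illustration.
p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

# One real step so the optimizer builds up a momentum buffer.
p.sum().backward()
opt.step()

# Case 1: gradient zeroed in place. The momentum buffer still moves p.
opt.zero_grad(set_to_none=False)
before = p.detach().clone()
opt.step()
moved_with_zero_grad = not torch.equal(before, p.detach())

# Case 2: gradient set to None. The optimizer skips p entirely.
opt.zero_grad(set_to_none=True)
before = p.detach().clone()
opt.step()
moved_with_none_grad = not torch.equal(before, p.detach())

print(moved_with_zero_grad, moved_with_none_grad)  # True False
```

As far as I know, set_to_none=True has been the default since PyTorch 2.0, so on recent versions a plain optimizer.zero_grad() already behaves like case 2.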


Thanks for the confirmation. The model was indeed training fine, but I wanted to validate that I am doing it right and check whether there is a better, more efficient way. I found that .requires_grad needs to be set explicitly to True or False inside each if-else branch; otherwise the state from the previous batch carries over to the next batch and the model will not learn.

Your pointer about .zero_grad(set_to_none=True) is something new to me and a nice way to save compute.
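
For anyone landing here later, this is roughly what setting .requires_grad explicitly on every call can look like. It is only a sketch: the class name RandomBranchConv, the use of nn.ModuleList, and the layer sizes are my own choices for illustration, not the code from my actual model.

```python
import torch
import torch.nn as nn

class RandomBranchConv(nn.Module):  # hypothetical name
    def __init__(self, in_channels=3, out_channels=4, kernel_size=3):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size) for _ in range(3)]
        )

    def forward(self, x):
        z = torch.randint(0, 4, (1,)).item()
        # z == 0 uses all three layers; z in {1, 2, 3} uses layer z - 1 only.
        active = {0, 1, 2} if z == 0 else {z - 1}
        # Reset every layer's flag on every call so nothing carries over
        # from the previous batch.
        for i, conv in enumerate(self.convs):
            conv.requires_grad_(i in active)
        if z == 0:
            return self.convs[0](x) * self.convs[1](x) * self.convs[2](x)
        return self.convs[z - 1](x)

model = RandomBranchConv()
out = model(torch.randn(1, 3, 8, 8))
flags = [conv.weight.requires_grad for conv in model.convs]
print(out.shape, flags)
```

Resetting the flags for all three layers in one loop avoids repeating the True/False assignments in every branch. A layer frozen on this batch can still hold a .grad from an earlier batch, which is where .zero_grad(set_to_none=True) comes in.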