Defining modules in __init__ but not using them in the forward method

I have a neural net in which I define multiple modules (nn.Linear, nn.LSTM, nn.Sequential) in __init__(). However, I use only one of them in the forward method, depending on a condition. Still, the weights and biases of all these modules are initialized and have requires_grad=True, which is expected. I don't actually need gradients for the modules I don't use. Example code:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, head_type, model_dim, classifier_dim):
        super().__init__()
        self.head_type = head_type
        self.linear_head = nn.Linear(in_features=model_dim, out_features=classifier_dim)
        self.lstm = nn.LSTM(model_dim, model_dim, batch_first=True)
        self.multilinear_head = nn.Sequential(
            nn.Linear(in_features=model_dim, out_features=256),
            nn.Linear(in_features=256, out_features=classifier_dim),
        )

    def forward(self, x):
        # Only one head is used per instance; the other modules are never called.
        if self.head_type == 'linear':
            out = self.linear_head(x)
        else:
            out = self.multilinear_head(x)
        return out

Is this alright with respect to the memory these additional parameters take up? I guess that since no forward pass is performed through the unused modules, they would take only minimal space. Any workaround for this would be helpful. Thank you.

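Since the unused modules are never called in the forward pass, they create no intermediate tensors (activations) and their parameters never receive gradients. However, their parameters are still registered with the model, so they will still be moved to the device (e.g. by model.to('cuda')) and occupy memory there.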
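A minimal sketch to check this, assuming a stripped-down version of the Net class above (LSTM omitted) with hypothetical sizes model_dim=16 and classifier_dim=4. After a backward pass through the linear head only, the unused head's .grad attributes stay None, while its parameters are still part of the model:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical sizes; the original post does not give concrete values.
    def __init__(self, head_type, model_dim=16, classifier_dim=4):
        super().__init__()
        self.head_type = head_type
        self.linear_head = nn.Linear(model_dim, classifier_dim)
        self.multilinear_head = nn.Sequential(
            nn.Linear(model_dim, 256),
            nn.Linear(256, classifier_dim),
        )

    def forward(self, x):
        if self.head_type == 'linear':
            return self.linear_head(x)
        return self.multilinear_head(x)

model = Net('linear')
out = model(torch.randn(8, 16))
out.sum().backward()

# The head that was used received gradients...
linear_has_grad = all(p.grad is not None for p in model.linear_head.parameters())
# ...while the unused head's gradients were never populated.
multi_grad_is_none = all(p.grad is None for p in model.multilinear_head.parameters())
# The unused parameters still exist and count toward the model's size.
n_unused = sum(p.numel() for p in model.multilinear_head.parameters())
print(linear_has_grad, multi_grad_is_none, n_unused)  # True True 5380
```

The built-in torch.optim optimizers skip parameters whose .grad is None, so the unused head is not updated either; the cost is only the memory its parameters occupy.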
If you don’t want to use these modules, you could either remove them from the __init__ or delete them afterwards:

model.multilinear_head = None

The cleaner solution would be the first approach, as the second one might have side effects (e.g. I'm not sure whether registered forward hooks would break).
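A small sketch of the second approach, assuming the same hypothetical sizes as before (model_dim=16, classifier_dim=4). Assigning None to a registered submodule drops its parameters from model.parameters() and from model.state_dict():

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, model_dim=16, classifier_dim=4):  # hypothetical sizes
        super().__init__()
        self.linear_head = nn.Linear(model_dim, classifier_dim)
        self.multilinear_head = nn.Sequential(
            nn.Linear(model_dim, 256),
            nn.Linear(256, classifier_dim),
        )

def count_params(module):
    return sum(p.numel() for p in module.parameters())

model = Net()
before = count_params(model)

# Replace the registered submodule with None; its parameters are no longer
# reported by model.parameters() or saved in model.state_dict().
model.multilinear_head = None
after = count_params(model)

print(before, after)  # 5448 68
```

If you go this way, delete the head before creating the optimizer: an optimizer constructed earlier still holds references to the old parameters, so their memory would not be freed.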

Alternatively, you could also manually move the unused modules back to the CPU, which would at least save some GPU memory.

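A minimal sketch of moving only the unused head back to host memory, assuming the same hypothetical sizes as before. It uses a CUDA device when one is available and falls back to the CPU otherwise, so it runs anywhere:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, model_dim=16, classifier_dim=4):  # hypothetical sizes
        super().__init__()
        self.linear_head = nn.Linear(model_dim, classifier_dim)
        self.multilinear_head = nn.Sequential(
            nn.Linear(model_dim, 256),
            nn.Linear(256, classifier_dim),
        )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net().to(device)      # moves every parameter, used or not
model.multilinear_head.cpu()  # moves only the unused head back, in place

print(next(model.linear_head.parameters()).device.type,
      next(model.multilinear_head.parameters()).device.type)
```

Note that any later call to model.to(device) or model.cuda() moves the head back to the GPU, and memory freed this way goes back to PyTorch's caching allocator, so it may still show as used in nvidia-smi.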
If I remove them from __init__, where do I define them? Which head should be used depends on the head_type argument passed to __init__. Perhaps I could write a separate method in the class for each head and put the condition in __init__ itself. Is there a cleaner solution?

I think conditionally creating only the desired “head” module in __init__ is a proper way to do it.
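A sketch of that conditional creation, assuming the hypothetical attribute name self.head, the head_type strings 'linear' and 'multilinear', and the same hypothetical sizes as before (the LSTM is left out because it returns a tuple and would need extra handling). Only the requested head is ever built, and forward needs no branching:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, head_type, model_dim=16, classifier_dim=4):  # hypothetical sizes
        super().__init__()
        self.head_type = head_type
        # Build only the head that was requested; the others are never created.
        if head_type == 'linear':
            self.head = nn.Linear(model_dim, classifier_dim)
        elif head_type == 'multilinear':
            self.head = nn.Sequential(
                nn.Linear(model_dim, 256),
                nn.Linear(256, classifier_dim),
            )
        else:
            raise ValueError(f"unknown head_type: {head_type!r}")

    def forward(self, x):
        # self.head is whichever module was built in __init__.
        return self.head(x)

x = torch.randn(8, 16)
linear_model = Net('linear')
multi_model = Net('multilinear')
print(linear_model(x).shape, multi_model(x).shape)
```

The per-head construction could equally live in separate helper methods called from __init__, as suggested in the question. One consequence of sharing the name self.head is that the state_dict keys differ by head type (head.weight vs. head.0.weight), so a checkpoint only loads into a model built with the same head_type.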