How to efficiently multiply a mask to hidden layers

Hi, wondering if this is possible. I’ve set up an MLP model with a variable number of hidden layers using ModuleList. I’d like to know how to efficiently apply a mask of the same shape to the weights of these hidden layers. Here’s what I have, but I’d like to avoid iteration and stick to a single elementwise multiplication if possible.

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, num_hidden_layers, hidden_size, output_size):
        super().__init__()
        assert num_hidden_layers > 0
        self.input = nn.Linear(input_size, hidden_size)
        # Variable number of hidden_size x hidden_size hidden layers
        self.layers = nn.ModuleList()
        for _ in range(num_hidden_layers):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
        self.output = nn.Linear(hidden_size, output_size)

def apply_mask(model, mask):
    # Apply mask[i] (same shape as layers[i].weight) to the hidden layer weights only.
    # Multiply the weight tensor, not the nn.Linear module itself, in place and
    # outside autograd tracking.
    with torch.no_grad():
        for layer, layer_mask in zip(model.layers, mask):
            layer.weight.mul_(layer_mask)
    return model

Are you looking for something like this? Module — PyTorch 2.1 documentation
But iteration seems hard to avoid if you’d like to use a different mask for each weight, since each hidden layer’s weight is a separate tensor.
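If avoiding the per-layer loop matters, one option (a hypothetical restructuring, not something proposed in this thread) is to store all hidden weights in a single Parameter of shape (num_hidden_layers, hidden_size, hidden_size), so that a stacked mask of the same shape is applied with one elementwise multiply. The names StackedMLP, hidden_weight, hidden_bias and apply_stacked_mask are made up for this sketch, and the ReLU activation is an assumption because the original model has no forward method. The forward pass still loops over the layers, since they run one after another.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StackedMLP(nn.Module):
    """Hypothetical variant: every hidden weight matrix lives in one (L, H, H) Parameter."""

    def __init__(self, input_size, num_hidden_layers, hidden_size, output_size):
        super().__init__()
        assert num_hidden_layers > 0
        self.input = nn.Linear(input_size, hidden_size)
        # Same (out_features, in_features) layout per layer as nn.Linear.weight
        self.hidden_weight = nn.Parameter(
            torch.empty(num_hidden_layers, hidden_size, hidden_size))
        self.hidden_bias = nn.Parameter(torch.empty(num_hidden_layers, hidden_size))
        # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), the range nn.Linear uses by default
        bound = hidden_size ** -0.5
        nn.init.uniform_(self.hidden_weight, -bound, bound)
        nn.init.uniform_(self.hidden_bias, -bound, bound)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.input(x))  # ReLU is an assumption
        # Layers are sequential, so the forward pass still iterates over them
        for w, b in zip(self.hidden_weight, self.hidden_bias):
            x = torch.relu(F.linear(x, w, b))
        return self.output(x)

def apply_stacked_mask(model, mask):
    # mask has shape (num_hidden_layers, hidden_size, hidden_size):
    # one in-place elementwise multiply covers every hidden layer at once.
    with torch.no_grad():
        model.hidden_weight.mul_(mask)
    return model
```

Usage: `apply_stacked_mask(model, torch.stack(masks))` if you already have a list of per-layer masks. The trade-off is that you give up the nn.Linear modules in the ModuleList, so anything that relies on them (e.g. per-layer hooks) would need adapting.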

Also btw, you shouldn’t be using .data to modify weights. If you’d like to update your weights in place without autograd tracking, use with torch.no_grad(): instead.
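A minimal illustration of the torch.no_grad() approach on plain tensors (the values are arbitrary, and the helper name backward_after_masking is hypothetical). The masked tensor stays a leaf that requires grad. Because in-place changes made under no_grad() still bump the tensor's version counter, autograd raises an error if you overwrite a tensor that an earlier forward pass saved for backward, instead of silently computing wrong gradients.

```python
import torch

w = torch.ones(3, requires_grad=True)
mask = torch.tensor([1.0, 0.0, 1.0])

# In-place masking outside the autograd graph: w remains a trainable leaf.
with torch.no_grad():
    w.mul_(mask)

def backward_after_masking():
    v = torch.ones(3, requires_grad=True)
    y = (v * v).sum()   # v is saved for the backward pass
    with torch.no_grad():
        v.mul_(mask)    # in-place change bumps v's version counter
    try:
        y.backward()    # autograd detects that a saved tensor was modified
    except RuntimeError:
        return True
    return False
```

Changes made through .data are not tracked by autograd at all, which is why this kind of mistake can go unnoticed there.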