How to implement permanent weight dropout in a layer?

Hi! I am trying to implement a weight dropout mechanism that drops weights from a feed-forward layer and keeps them at zero during training. What is the best way to do this?

import torch


class SparseLinear(torch.nn.Module):
    """
    Linear layer with dense weight tensors whose pruned entries are kept at zero.
    """

    def __init__(self, in_features, out_features, sparsity):
        super().__init__()

        # Sparse initialization zeroes a fraction of the entries.
        w = torch.empty(out_features, in_features)
        torch.nn.init.sparse_(w, sparsity=sparsity)
        self.w = torch.nn.Parameter(w)

        b = torch.empty(1, out_features)
        torch.nn.init.sparse_(b, sparsity=sparsity)
        self.b = torch.nn.Parameter(b)

        # Remember which entries were zeroed; buffers follow the module
        # when it is moved (e.g. .to(device)) and are saved in the state_dict.
        self.register_buffer("w_mask", self.w == 0)
        self.register_buffer("b_mask", self.b == 0)

    def forward(self, inputs):
        # Re-apply the masks on every call so the pruned entries stay zero.
        return torch.nn.functional.linear(
            inputs,
            torch.masked_fill(self.w, self.w_mask, 0.0),
            bias=torch.masked_fill(self.b, self.b_mask, 0.0))

This is my current solution. However, I don't like that it performs masking on every forward call; I would prefer a solution that prevents the optimizer from modifying the masked weights in the first place.

I think you would have to mask these parameters at some point, either the values directly or their gradients, to prevent an update.
The optimizer cannot update only parts of a tensor, so I'm not aware of any other clean approach. :confused:
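
If you want to avoid the masking in forward, here is a minimal sketch of the gradient-masking variant (the class name MaskedLinear and the hook-based approach are just one way to do it; I also left the bias dense for simplicity): register a hook on the weight so its gradient is zeroed at the pruned positions before the optimizer sees it.

import torch


class MaskedLinear(torch.nn.Module):
    """
    Linear layer that zeroes the gradients of pruned weights via a hook.
    """

    def __init__(self, in_features, out_features, sparsity):
        super().__init__()

        w = torch.empty(out_features, in_features)
        torch.nn.init.sparse_(w, sparsity=sparsity)
        self.w = torch.nn.Parameter(w)
        self.b = torch.nn.Parameter(torch.zeros(out_features))

        # Keep the mask of the entries that should stay zero.
        self.register_buffer("w_mask", self.w == 0)

        # The hook zeroes the incoming gradient at the pruned positions,
        # so those entries never receive an update from .backward().
        self.w.register_hook(lambda grad: grad.masked_fill(self.w_mask, 0.0))

    def forward(self, inputs):
        # No masking needed here: the pruned weights are zero and stay zero.
        return torch.nn.functional.linear(inputs, self.w, self.b)

Since the pruned entries start at exactly zero and their gradient is always zero, the usual optimizers should not move them. Alternatively, you could keep your current module and just zero the relevant entries of .grad between loss.backward() and optimizer.step(), e.g. with w.grad.masked_fill_(w_mask, 0.0).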