Hi! I am trying to implement a weight dropout mechanism that drops weights from a feed-forward layer and keeps them at zero during training. What is the best way to do this?
```python
import torch


class SparseLinear(torch.nn.Module):
    """Linear layer whose dense weight matrix is sparsely initialized
    and whose zero entries are kept at zero during training."""

    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        w = torch.empty(out_features, in_features)
        torch.nn.init.sparse_(w, sparsity=sparsity)
        self.w = torch.nn.Parameter(w)

        # sparse_ only supports 2D tensors and zeroes a fraction of *rows*
        # per column, so a (1, out_features) bias would be zeroed entirely.
        # Initialize as (out_features, 1) instead and squeeze to 1D.
        b = torch.empty(out_features, 1)
        torch.nn.init.sparse_(b, sparsity=sparsity)
        self.b = torch.nn.Parameter(b.squeeze(1))

        # Register the masks as buffers so they follow .to(device) and are
        # saved in the state dict, but are never trained.
        self.register_buffer("w_mask", self.w == 0)
        self.register_buffer("b_mask", self.b == 0)

    def forward(self, inputs):
        # Re-apply the masks on every call, in case the optimizer has
        # nudged masked entries away from zero.
        return torch.nn.functional.linear(
            inputs,
            self.w.masked_fill(self.w_mask, 0.0),
            bias=self.b.masked_fill(self.b_mask, 0.0),
        )
```
This is my current solution. However, I don't like that it performs the masking on every forward call; I would prefer a solution that prevents the optimizer from modifying the masked weights in the first place.
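For reference, the closest I have come to that is zeroing the gradients of the masked entries with `Tensor.register_hook`: with plain SGD the masked weights then start at zero and never receive an update, so the `masked_fill` in `forward` could be dropped. I am not sure this is a complete answer, though, since weight decay and momentum are applied inside the optimizer and could still move the masked entries. A minimal sketch of the idea, assuming the `SparseLinear` class above (the helper name `freeze_masked_weights` is just mine):

```python
def freeze_masked_weights(layer):
    # register_hook runs whenever the gradient w.r.t. the parameter is
    # computed and may return a modified gradient; here we zero the
    # gradient at the masked positions.
    layer.w.register_hook(lambda g: g.masked_fill(layer.w_mask, 0.0))
    layer.b.register_hook(lambda g: g.masked_fill(layer.b_mask, 0.0))


layer = SparseLinear(16, 8, sparsity=0.5)
freeze_masked_weights(layer)

# Caveat: an optimizer with weight decay or momentum can still move the
# masked entries, so they may need to be re-zeroed after each step:
#   with torch.no_grad():
#       layer.w.masked_fill_(layer.w_mask, 0.0)
#       layer.b.masked_fill_(layer.b_mask, 0.0)
```

I also looked at `torch.nn.utils.prune.custom_from_mask`, but as far as I understand it reapplies the mask in a forward pre-hook, which is exactly the per-call masking I am trying to avoid.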