Hi! I am trying to implement a weight dropout mechanism that drops weights from a feed-forward layer and keeps them at zero during training. What is the best way to do this?
```python
import torch


class SparseLinear(torch.nn.Module):
    """Linear layer whose dense weight matrix is sparsely initialized
    and whose zero entries are kept at zero during training."""

    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        w = torch.empty(out_features, in_features)
        torch.nn.init.sparse_(w, sparsity=sparsity)
        self.w = torch.nn.Parameter(w)

        # sparse_ only supports 2D tensors and zeroes a fraction of *rows*
        # per column, so a (1, out_features) bias would be zeroed entirely.
        # Initialize as (out_features, 1) instead and squeeze to 1D.
        b = torch.empty(out_features, 1)
        torch.nn.init.sparse_(b, sparsity=sparsity)
        self.b = torch.nn.Parameter(b.squeeze(1))

        # Register the masks as buffers so they follow .to(device) and are
        # saved in the state dict, but are never trained.
        self.register_buffer("w_mask", self.w == 0)
        self.register_buffer("b_mask", self.b == 0)

    def forward(self, inputs):
        # Re-apply the masks on every call, in case the optimizer has
        # nudged masked entries away from zero.
        return torch.nn.functional.linear(
            inputs,
            self.w.masked_fill(self.w_mask, 0.0),
            bias=self.b.masked_fill(self.b_mask, 0.0),
        )
```
This is my current solution. However, I don't like that it performs the masking on every forward call; I would prefer a solution that prevents the optimizer from modifying the masked weights in the first place.
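For reference, the closest I have come to that is zeroing the gradients of the masked entries with `Tensor.register_hook`: with plain SGD the masked weights then start at zero and never receive an update, so the `masked_fill` in `forward` could be dropped. I am not sure this is a complete answer, though, since weight decay and momentum are applied inside the optimizer and could still move the masked entries. A minimal sketch of the idea, assuming the `SparseLinear` class above (the helper name `freeze_masked_weights` is just mine):

```python
def freeze_masked_weights(layer):
    # register_hook runs whenever the gradient w.r.t. the parameter is
    # computed and may return a modified gradient; here we zero the
    # gradient at the masked positions.
    layer.w.register_hook(lambda g: g.masked_fill(layer.w_mask, 0.0))
    layer.b.register_hook(lambda g: g.masked_fill(layer.b_mask, 0.0))


layer = SparseLinear(16, 8, sparsity=0.5)
freeze_masked_weights(layer)

# Caveat: an optimizer with weight decay or momentum can still move the
# masked entries, so they may need to be re-zeroed after each step:
#   with torch.no_grad():
#       layer.w.masked_fill_(layer.w_mask, 0.0)
#       layer.b.masked_fill_(layer.b_mask, 0.0)
```

I also looked at `torch.nn.utils.prune.custom_from_mask`, but as far as I understand it reapplies the mask in a forward pre-hook, which is exactly the per-call masking I am trying to avoid.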