If you have a mask already, then you could just do an element-wise multiply between the mask and the weights on every forward pass, I would think? I don't imagine an element-wise multiply would be too slow to do every forward.
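Roughly what I have in mind for that first option, as a minimal sketch (the class name and the use of torch.nn.functional.linear are just my choices here, not the only way to do it):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiplyOnForwardLinear(nn.Module):
        """Re-applies the mask to the weight on every forward pass."""
        def __init__(self, in_features, out_features, mask):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)
            # mask has the same shape as the weight: (out_features, in_features)
            self.register_buffer('mask', mask)

        def forward(self, x):
            # element-wise multiply on every forward; autograd then gives the
            # masked weight entries zero gradient automatically
            return F.linear(x, self.linear.weight * self.mask, self.linear.bias)

Since the mask is re-applied inside forward, no hooks are needed, at the cost of one extra element-wise multiply per call.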
Alternatively, you could do something like this:

    import torch
    import torch.nn as nn

    class MaskedLinear(nn.Module):
        def __init__(self, in_features, out_features, mask):
            super(MaskedLinear, self).__init__()
            self.linear = nn.Linear(in_features, out_features)
            # mask has the same shape as the weight: (out_features, in_features)
            self.register_buffer('mask', mask)
            with torch.no_grad():
                self.linear.weight *= mask  # zero out the masked weights first
            # hook on the weight tensor itself so the masked gradients won't propagate
            # (a module backward hook only sees gradients w.r.t. inputs/outputs, not weights)
            self.handle = self.linear.weight.register_hook(lambda grad: grad * self.mask)

        def forward(self, x):
            return self.linear(x)
That way you can still have a fast forward pass with minimal overhead on the backward pass?
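If it helps, here is a quick sanity check I'd expect to work with the hooked version above, assuming a 0/1 float mask of the same shape as the weight (the sizes below are arbitrary):

    # masked positions should keep zero gradients after backward
    mask = (torch.rand(4, 8) > 0.5).float()   # shape (out_features, in_features)
    layer = MaskedLinear(8, 4, mask)
    layer(torch.randn(2, 8)).sum().backward()
    print((layer.linear.weight.grad[mask == 0] == 0).all())  # expect True

And if you ever need to stop masking the gradients, that's what keeping self.handle around is for: calling self.handle.remove() detaches the hook.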