Custom connections in neural network layers

If you already have a mask, then you could just do an element-wise multiply between the mask and the weights on every forward pass, I would think? I don’t imagine an element-wise multiply would be too slow to do on every forward.
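For instance, here is a minimal sketch of that idea (the class name MaskedLinearForward is just something I made up for illustration; the mask is assumed to have the same shape as linear.weight, i.e. (out_features, in_features)):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinearForward(nn.Module):
    def __init__(self, in_features, out_features, mask):
        super(MaskedLinearForward, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # mask has the same shape as linear.weight: (out_features, in_features)
        self.register_buffer('mask', mask)

    def forward(self, x):
        # multiply the mask into the weights on every forward pass;
        # autograd then zeros the masked weight gradients for free
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)

Alternatively, if you want to keep the forward pass as a plain nn.Linear, you could do something like this: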

import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, mask):
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # mask has the same shape as linear.weight: (out_features, in_features)
        self.register_buffer('mask', mask)
        with torch.no_grad():
            self.linear.weight *= self.mask  # zero out the masked weights first
        # zero the masked entries of the weight gradient so those connections never update
        self.handle = self.linear.weight.register_hook(lambda grad: grad * self.mask)

    def forward(self, x):
        return self.linear(x)

That way you can still have a fast forward pass with only minimal overhead on the backward.
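A quick usage sketch (the mask pattern and sizes here are just made up for illustration):

import torch

# hypothetical example: a random sparsity pattern, shape (out_features, in_features)
mask = (torch.rand(3, 5) > 0.5).float()

layer = MaskedLinear(5, 3, mask)
out = layer(torch.randn(8, 5))     # batch of 8 inputs
out.sum().backward()
print(layer.linear.weight.grad)    # gradients are zero wherever the mask is zero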
