Gradient masking in register_backward_hook for custom connectivity - efficient implementation

@jonathanEphrath Here is my implementation, inspired by @jukiewiczm’s code.

import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    def __init__(self, in_dim, out_dim, indices_mask):
        """
       :param in_features: number of input features
       :param out_features: number of output features
       :param indices_mask: list of two lists containing indices for dimensions 0 and 1, used to create the mask
       """
        super(MaskedLinear, self).__init__()
 
        def backward_hook(grad):
            # Clone because a hook must not modify its gradient argument in place
            out = grad.clone()
            out[self.mask] = 0
            return out
 
        self.linear = nn.Linear(in_dim, out_dim).cuda()
        self.mask = torch.ones([out_dim, in_dim], dtype=torch.bool).cuda()
        self.mask[indices_mask] = False  # unmask the allowed connections
        self.linear.weight.data[self.mask] = 0  # zero out the masked weights
        self.linear.weight.register_hook(backward_hook)  # hook to zero out gradients of masked weights
 
    def forward(self, input):
        return self.linear(input)
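
For completeness, here is a minimal usage sketch (not part of the original post). The in_dim, out_dim, and indices_mask values are made up for illustration; indices_mask is passed as a tuple of (dim 0, dim 1) index lists, and a CUDA device is assumed because the layer above moves its parameters to the GPU.

in_dim, out_dim = 4, 3
indices_mask = ([0, 1, 2], [0, 1, 2])  # keep only the "diagonal" connections trainable

layer = MaskedLinear(in_dim, out_dim, indices_mask)
x = torch.randn(8, in_dim).cuda()
layer(x).sum().backward()

# The hook zeroes the gradients of masked weights, so they stay at zero after optimizer steps
print(layer.linear.weight.grad[layer.mask])  # tensor of zeros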