Gradient masking in register_backward_hook for custom connectivity - efficient implementation

Hello there,
I’m trying to create a non-fully-connected layer (a custom connectivity layer) by implementing a custom module that wraps nn.Linear. To do this, I zero out the unwanted weights at the beginning and then try to zero out the corresponding gradients with register_backward_hook, using a mask of the connections I want. Based on several topics on the forums, I came up with something like this (I tried several things, so the current version is probably pretty dumb):

class MaskedLinear(nn.Module):
    def _zero_grad_mask(self, module, grad_input, grad_output):
        new_grad = Variable(
            torch.Tensor(grad_input[2].shape).cuda().masked_scatter_(self.indices_mask_tensor, grad_input[2].data))
        new_out = (grad_input[0], None, new_grad)
        return new_out

    def __init__(self, in_features, out_features, indices_mask):
        """
        :param in_features: number of input features
        :param out_features: number of output features
        :param indices_mask: list of two lists containing indices for dimensions 0 and 1, used to create the mask
        """
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        indices_mask = [indices_mask[1], indices_mask[0]]  # gradient and weights seem to be transposed
        good_weights = self.linear.weight.data[indices_mask]
        self.linear.weight.data[:] = 0  # to zero it out first
        self.linear.weight.data[indices_mask] = good_weights
        self.indices_mask_tensor = torch.ByteTensor(out_features, in_features).cuda()
        self.indices_mask_tensor[indices_mask] = 1
        self.handle = self.linear.register_backward_hook(self._zero_grad_mask)  # to make sure gradients won't propagate

    def forward(self, input):
        return self.linear(input)

Now the problem is that when I try to run this module on the GPU, I run out of memory almost instantly, most likely because a new Tensor is created on every backward step (when I used .clone() instead, it still ran out of memory, just a bit later). When I tried to modify the existing gradient Variable in place, the solution did not seem to work at all, which would make sense, as the register_backward_hook documentation says that the input should not be modified.

Could anyone point me towards a proper implementation of such a module?

Another thing I don’t understand is that my grad_input is a 3-element tuple, with the 2nd element being None and the 3rd element being what I believe I need to modify. The register_backward_hook documentation says that the input and output can be tuples if a module has multiple inputs/outputs, but that doesn’t seem to be the case here. Why does it look like that?

I hope it’s fine if I tag the people from the similar topic.
@theQmech, @bzcheeseman, @shirindora, @mariob6, did you have any luck implementing this? Could you provide some hints?
Thank you in advance.

Hi,
Did you manage to build a custom net that runs efficiently?
I’m trying to do the same.

Hello,
Unfortunately, I did not. I gave up on the idea. I’m in contact with someone who’s trying to do the same, though, so when I learn more I will let you know (or he will).

@jonathanEphrath Here is my implementation which was inspired by @jukiewiczm’s code.

class MaskedLinear(nn.Module):
    def __init__(self, in_dim, out_dim, indices_mask):
        """
        :param in_dim: number of input features
        :param out_dim: number of output features
        :param indices_mask: list of two lists containing indices for dimensions 0 and 1, used to create the mask
        """
        super(MaskedLinear, self).__init__()
 
        def backward_hook(grad):
            # Clone due to not being allowed to modify in-place gradients
            out = grad.clone()
            out[self.mask] = 0
            return out
 
        self.linear = nn.Linear(in_dim, out_dim).cuda()
        self.mask = torch.ones([out_dim, in_dim]).byte().cuda()
        self.mask[indices_mask] = 0 # create mask
        self.linear.weight.data[self.mask] = 0 # zero out bad weights
        self.linear.weight.register_hook(backward_hook) # hook to zero out bad gradients
 
    def forward(self, input):
        return self.linear(input)
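
For anyone who wants to check this approach without a GPU, here is a minimal, device-agnostic sketch of the same idea: a hook registered on the weight tensor with Tensor.register_hook that zeroes the gradient of the disconnected entries. The class name HookMaskedLinear, the keep buffer, and the toy indices are made up for illustration and are not from the code above. It also assumes that indices_mask[0] indexes output units and indices_mask[1] indexes input features, matching the [out_dim, in_dim] mask used above.

```python
import torch
import torch.nn as nn


class HookMaskedLinear(nn.Module):
    """Hypothetical CPU/GPU-agnostic variant of the weight-hook approach."""

    def __init__(self, in_dim, out_dim, indices_mask):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

        # keep[i, j] == 1 where a connection is allowed, 0 elsewhere.
        # Assumption: indices_mask = [output indices, input indices].
        keep = torch.zeros(out_dim, in_dim)
        keep[indices_mask[0], indices_mask[1]] = 1.0
        # A buffer follows the module through .to(device) / .cuda().
        self.register_buffer("keep", keep)

        # Zero the disconnected weights once, outside of autograd.
        with torch.no_grad():
            self.linear.weight.mul_(self.keep)

        # Tensor hooks return a new gradient instead of editing it in place.
        self.linear.weight.register_hook(lambda grad: grad * self.keep)

    def forward(self, input):
        return self.linear(input)


# Toy check: only connections (out 0, in 0) and (out 1, in 2) are allowed.
torch.manual_seed(0)
layer = HookMaskedLinear(3, 2, [[0, 1], [0, 2]])
layer(torch.randn(4, 3)).sum().backward()
weight_grad = layer.linear.weight.grad
```

After the backward pass, weight_grad should be zero everywhere except at the two allowed connections, so an optimizer step without weight decay or momentum carried over from earlier state leaves the disconnected weights at zero.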

@jamesproud I tested it out, and it looks like you don’t need to mask the gradients. If you mask the weights in every forward pass, then those gradient values will not have an effect on the learned model.

It also saves compute time, since masking the gradients takes a bit of time.

@Kale-ab_Tessera, I agree with you that this should suffice. How do you modify the forward pass accordingly?
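
One possible way to do this, as a rough sketch only (this is not @Kale-ab_Tessera’s code, which was not posted), is to multiply the weight by a fixed 0/1 mask inside forward and pass the result to torch.nn.functional.linear. The class name ForwardMaskedLinear and the keep buffer are hypothetical, and the sketch assumes indices_mask[0] indexes output units and indices_mask[1] indexes input features.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ForwardMaskedLinear(nn.Module):
    """Hypothetical sketch: mask the weight on every forward pass."""

    def __init__(self, in_dim, out_dim, indices_mask):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

        # keep[i, j] == 1 where a connection is allowed, 0 elsewhere.
        # Assumption: indices_mask = [output indices, input indices].
        keep = torch.zeros(out_dim, in_dim)
        keep[indices_mask[0], indices_mask[1]] = 1.0
        # A buffer follows the module through .to(device) / .cuda().
        self.register_buffer("keep", keep)

    def forward(self, input):
        # Disconnected entries are multiplied by 0, so they never reach the
        # output, and autograd gives them a zero gradient automatically.
        masked_weight = self.linear.weight * self.keep
        return F.linear(input, masked_weight, self.linear.bias)


# Toy check: perturbing a disconnected weight must not change the output.
torch.manual_seed(0)
fm = ForwardMaskedLinear(3, 2, [[0, 1], [0, 2]])
x = torch.randn(4, 3)
out_before = fm(x)
with torch.no_grad():
    fm.linear.weight[0, 1] += 5.0  # (out 0, in 1) is not an allowed connection
out_after = fm(x)
out_after.sum().backward()
masked_grad = fm.linear.weight.grad
```

With this variant no hook is needed: the stored values of the disconnected weights may be nonzero, but they are ignored in the forward pass and receive zero gradient. The cost is one extra elementwise multiply per forward call.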