I’m trying to create a non-fully-connected layer (a custom-connectivity layer) by implementing a custom module that wraps nn.Linear. To do this, I zero out the masked weights at the beginning and then try to zero out the corresponding gradients with register_backward_hook, using a mask of the connections I want to keep. Based on several topics on the forums, I came up with something like this (I’ve tried several things, so the current version is probably pretty dumb):
```python
class MaskedLinear(nn.Module):
    def _zero_grad_mask(self, module, grad_input, grad_output):
        new_grad = Variable(torch.Tensor(grad_input.shape).cuda()
                            .masked_scatter_(self.indices_mask_tensor, grad_input.data))
        new_out = (grad_input, None, new_grad)
        return new_out

    def __init__(self, in_features, out_features, indices_mask):
        """
        :param in_features: number of input features
        :param out_features: number of output features
        :param indices_mask: list of two lists containing indices for dimensions 0 and 1,
            used to create the mask
        """
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        indices_mask = [indices_mask, indices_mask]  # gradient and weights seem to be transposed
        good_weights = self.linear.weight.data[indices_mask]
        self.linear.weight.data[:] = 0  # to zero it out first
        self.linear.weight.data[indices_mask] = good_weights
        self.indices_mask_tensor = torch.ByteTensor(out_features, in_features).cuda()
        self.indices_mask_tensor[indices_mask] = 1
        # to make sure gradients won't propagate through the masked connections
        self.handle = self.linear.register_backward_hook(self._zero_grad_mask)

    def forward(self, input):
        return self.linear(input)
```
Now the problem is that when I run this module on the GPU, I run out of memory almost instantly, most likely because a new tensor is created on every backward step (when I used .clone() instead, it still ran out of memory, just a bit later). When I tried to modify the existing gradient Variable in place instead, the solution did not seem to work at all, which would make sense, as the register_backward_hook documentation says the input should not be modified.
Could anyone advise me on a proper implementation of such a module?
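For reference, here is the sort of thing I’m hoping for — a sketch of the same masking idea using a hook on the weight tensor itself (weight.register_hook) rather than a module backward hook, which I assume avoids allocating a fresh tensor every backward step. The class name and the dense (out_features, in_features) mask format are my own for this sketch, not from my code above, and I’m not sure this is the right approach:

```python
import torch
import torch.nn as nn

class MaskedLinearSketch(nn.Module):
    """Sketch: keep only the connections where mask == 1, for weights and gradients."""

    def __init__(self, in_features, out_features, mask):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # mask: float tensor of shape (out_features, in_features) with 0/1 entries
        self.register_buffer("mask", mask)
        # zero out the masked weights once, up front
        with torch.no_grad():
            self.linear.weight.mul_(self.mask)
        # hook on the parameter itself: multiplies weight.grad by the mask,
        # so masked connections never receive gradient
        self.linear.weight.register_hook(lambda grad: grad * self.mask)

    def forward(self, x):
        return self.linear(x)

# toy connectivity: 3 outputs, 4 inputs, one connection per output
mask = torch.zeros(3, 4)
mask[0, 0] = mask[1, 2] = mask[2, 3] = 1.0
layer = MaskedLinearSketch(4, 3, mask)
out = layer(torch.randn(2, 4))
out.sum().backward()
# both the weights and their gradients are exactly zero outside the mask
```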
Another thing I don’t understand is that my grad_input is a 3-element tuple, with the 2nd element being None and the 3rd element being what I believe I need to modify. The register_backward_hook documentation says that grad_input and grad_output can be tuples if a module has multiple inputs/outputs, but that doesn’t seem to be the case here. Why does it look like that?
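To make this second question concrete, here is a minimal CPU repro of how I’m observing those tuples (the variable names are mine; I just stash what the hook receives and print the tuple lengths rather than assume them):

```python
import torch
import torch.nn as nn

# minimal repro: inspect what register_backward_hook actually passes for nn.Linear
lin = nn.Linear(4, 3)
seen = {}

def inspect_hook(module, grad_input, grad_output):
    seen["grad_input"] = grad_input
    seen["grad_output"] = grad_output

lin.register_backward_hook(inspect_hook)  # deprecated in recent PyTorch, but it's what I'm using
lin(torch.randn(2, 4)).sum().backward()

print(type(seen["grad_input"]), len(seen["grad_input"]))
print(type(seen["grad_output"]), len(seen["grad_output"]))
```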