Hello there,
I’m trying to create a non-fully-connected layer (a custom-connectivity layer) by implementing a custom module that wraps nn.Linear. To do this, I zero out the weights at the beginning and then try to zero out the corresponding gradients with register_backward_hook, using the mask of connections I want. Based on several topics on the forums, I came up with something like this (I tried several things, so it’s probably pretty dumb at the moment):
import torch
import torch.nn as nn
from torch.autograd import Variable


class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, indices_mask):
        """
        :param in_features: number of input features
        :param out_features: number of output features
        :param indices_mask: list of two lists containing indices for dimensions 0 and 1, used to create the mask
        """
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        indices_mask = [indices_mask[1], indices_mask[0]]  # gradient and weights seem to be transposed
        good_weights = self.linear.weight.data[indices_mask]
        self.linear.weight.data[:] = 0  # to zero it out first
        self.linear.weight.data[indices_mask] = good_weights
        self.indices_mask_tensor = torch.ByteTensor(out_features, in_features).cuda()
        self.indices_mask_tensor[indices_mask] = 1
        self.handle = self.linear.register_backward_hook(self._zero_grad_mask)  # to make sure gradients won't propagate

    def _zero_grad_mask(self, module, grad_input, grad_output):
        # build a fresh gradient that only keeps the entries selected by the mask
        new_grad = Variable(
            torch.Tensor(grad_input[2].shape).cuda().masked_scatter_(
                self.indices_mask_tensor, grad_input[2].data))
        new_out = (grad_input[0], None, new_grad)
        return new_out

    def forward(self, input):
        return self.linear(input)
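For reference, I call it roughly like this (toy, hypothetical sizes and connectivity, just to show the indices_mask format; I use a square layer here so the transposition question doesn’t get in the way):

import torch
from torch.autograd import Variable

# hypothetical toy connectivity: two lists of equal length,
# indexing dimensions 0 and 1 of the connectivity pattern
indices_mask = [[0, 1, 2, 3],
                [0, 0, 1, 2]]

layer = MaskedLinear(in_features=4, out_features=4, indices_mask=indices_mask).cuda()
x = Variable(torch.randn(8, 4).cuda())
out = layer(x)
out.sum().backward()  # the backward hook fires here; in my training loop memory keeps growing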
Now the problem is that when I try to run this module on the GPU, I run out of memory almost instantly, most likely because a new Tensor is created at every backward step (when I used .clone() instead, it still ran out of memory, just a bit later). When I tried to modify the existing gradient Variable in place, the approach did not seem to work at all, which would make sense, as the register_backward_hook documentation says the input should not be modified.
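(To clarify what I mean by modifying the existing gradient Variable: an in-place variant of the hook, roughly along these lines — just a sketch, not my exact code:)

def _zero_grad_mask_inplace(self, module, grad_input, grad_output):
    # zero the gradient entries outside the connectivity mask, in place
    # (the docs say grad_input should not be modified, which is presumably why this misbehaves)
    grad_input[2].data.masked_fill_(self.indices_mask_tensor == 0, 0)
    return grad_input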
Could anyone advise me on a proper implementation of such a module?
Another thing I don’t understand is that my grad_input is a 3-element tuple, with the 2nd element being None and the 3rd element being what I believe I need to modify. The register_backward_hook documentation says that grad_input and grad_output can be tuples if the module has multiple inputs/outputs, but that doesn’t seem to be the case here. Why does it look like that?
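(In case it is relevant, this is roughly how I inspected that structure — a throwaway debug hook, names are just placeholders:)

def _debug_hook(module, grad_input, grad_output):
    # print the shapes (or None) of everything passed to the backward hook
    print([g.size() if g is not None else None for g in grad_input])
    print([g.size() if g is not None else None for g in grad_output])

# attached temporarily with: layer.linear.register_backward_hook(_debug_hook)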