@jonathanEphrath Here is my implementation, which was inspired by @jukiewiczm’s code.
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    def __init__(self, in_dim, out_dim, indices_mask):
        """
        :param in_dim: number of input features
        :param out_dim: number of output features
        :param indices_mask: tuple of two index lists (for dimensions 0 and 1)
            selecting the connections that stay trainable
        """
        super(MaskedLinear, self).__init__()

        def backward_hook(grad):
            # Clone because in-place modification of gradients is not allowed
            out = grad.clone()
            out[self.mask] = 0
            return out

        self.linear = nn.Linear(in_dim, out_dim).cuda()
        # Boolean mask (byte masks are deprecated for indexing);
        # True marks the weights to disable
        self.mask = torch.ones([out_dim, in_dim], dtype=torch.bool).cuda()
        self.mask[indices_mask] = False  # keep only the listed connections
        with torch.no_grad():
            self.linear.weight[self.mask] = 0  # zero out bad weights
        self.linear.weight.register_hook(backward_hook)  # hook to zero out bad gradients

    def forward(self, x):
        return self.linear(x)
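For reference, here is a minimal usage sketch. The layer sizes and index lists are made up for illustration, and a CUDA device is assumed since the class above moves everything to the GPU:

# Hypothetical example: a 3 -> 2 layer where only three connections train
rows = [0, 0, 1]  # output indices of the allowed connections
cols = [0, 1, 2]  # input indices of the allowed connections
layer = MaskedLinear(3, 2, (rows, cols))

x = torch.randn(4, 3).cuda()
layer(x).sum().backward()

# Masked weights stay zero, and the hook zeroes their gradients too
print(layer.linear.weight)
print(layer.linear.weight.grad)

Because the hook zeroes the masked gradients, any optimizer whose update depends only on the gradient will never revive the disabled weights.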