I am trying to build my own subclass of `nn.Linear` called `MaskedLinear` that is supposed to set some weights to 0 and keep the others. The input is a matrix of shape (1024, 1024) and the mask is a matrix of the same size filled with ones and zeros (the first n columns are 1 and the remaining columns are 0).

Here is my implementation so far:

```python
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, input, masks):
        # `input` is of shape (batch_size, 1024, 1024)
        # Expand the tensor to enable multiplication with self.weight
        print("input ", input.shape)
        print("self.weight ", self.weight.shape)
        return F.linear(input, masks * self.weight, self.bias)
```

And here is the error I get:

```
Cell In [36], line 14, in MaskedLinear.forward(self, input, masks)
     13 print("self.weight ", self.weight.shape)
---> 14 return F.linear(input, masks * self.weight, self.bias)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
```

The corresponding layer is:

```python
MaskedLinear(1024*1024, 128)
```

When I print the shapes of `input`, `masks`, and `self.weight`, it looks like:

```
input  torch.Size([8, 1048576])
self.weight  torch.Size([128, 1048576])
```

First, I don't understand why the batch dimension is not handled.
Second question: is this the right way to mask the weights?
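The failure can be reproduced in isolation (with made-up small shapes): `F.linear` insists on a 2D weight, so broadcasting a batched mask against the weight produces a 3D tensor, which triggers the same `RuntimeError`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)           # (batch, in_features)
w2d = torch.randn(4, 16)         # (out_features, in_features): accepted
print(F.linear(x, w2d).shape)    # torch.Size([8, 4])

# A mask with a batch dimension, e.g. (8, 1, 16), broadcast against the
# (4, 16) weight yields an (8, 4, 16) "weight", which F.linear rejects:
masks = torch.ones(8, 1, 16)
try:
    F.linear(x, masks * w2d)
except RuntimeError as e:
    print("RuntimeError:", e)
```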

So, do you want to mask the weight or the input?
The weight necessarily has to be 2D in `F.linear`, but you could use `matmul` or `einsum` or whatever fits your taste (plus adding the bias) to replace `linear`.
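For example, a sketch of such a replacement (assuming a per-sample mask of shape `(batch, out_features, in_features)`; all shapes here are made up for illustration):

```python
import torch

batch, in_f, out_f = 8, 16, 4
x = torch.randn(batch, in_f)
weight = torch.randn(out_f, in_f, requires_grad=True)
bias = torch.zeros(out_f, requires_grad=True)
# hypothetical per-sample binary mask on the weight
masks = torch.randint(0, 2, (batch, out_f, in_f)).float()

# out[b, o] = sum_i x[b, i] * (masks[b, o, i] * weight[o, i]) + bias[o]
out = torch.einsum('bi,boi->bo', x, masks * weight) + bias
print(out.shape)  # torch.Size([8, 4])
```

With an all-ones mask this reduces exactly to `F.linear(x, weight, bias)`.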

Best regards

Thomas

I don't want the network to learn the padded sections of the matrix (cf. the screenshot of two partial examples below: the true images are on top, the predicted ones below. Dark blue represents padding, i.e. "false" in the mask).
I use a custom MaskedMSE loss function, to which I can pass my mask so that the network does not backpropagate errors on padded sections. The missing piece is applying the mask inside the layers of the network (only `nn.Linear`).
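For reference, a minimal sketch of such a masked MSE (my guess at the idea, not the actual implementation; `mask` is assumed to be 1 on valid entries and 0 on padding):

```python
import torch

def masked_mse(pred, target, mask):
    # Squared errors on padded entries are zeroed out, so they
    # contribute neither to the loss nor to its gradients.
    sq_err = (pred - target) ** 2 * mask
    # Average only over the valid (unmasked) entries.
    return sq_err.sum() / mask.sum().clamp(min=1)

pred = torch.zeros(2, 3, requires_grad=True)
target = torch.ones(2, 3)
mask = torch.tensor([[1., 1., 0.], [1., 0., 0.]])
loss = masked_mse(pred, target, mask)  # averaged over the 3 valid entries
loss.backward()
print(pred.grad[mask == 0])  # zero gradient on the padded entries
```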

I am not very familiar with `matmul` and `einsum`…

I cannot comment on the design, but if the input (not the weight) is masked to 0 in the forward pass, the gradient contribution to the weights applied to that masked part should be 0, because the backward pass matrix-multiplies the gradient of the output with the input (in the right order, with the right transpositions) to get the weight gradient.
If you have NaNs/infs that won't work, though, and you need to deal with those.
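This can be checked numerically with a small sketch (made-up shapes; the last three input columns stand in for padding):

```python
import torch

x = torch.randn(8, 6)
mask = torch.tensor([1., 1., 1., 0., 0., 0.])  # last 3 columns are "padding"
w = torch.randn(4, 6, requires_grad=True)

out = (x * mask) @ w.t()   # mask the *input*, keep the weight dense
out.sum().backward()

# The weight gradient is grad_out.t() @ (x * mask), so the columns of w
# corresponding to masked input entries receive exactly zero gradient:
print(w.grad[:, 3:].abs().max())  # tensor(0.)
```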

Best regards

Thomas
