I am trying to build my own subclass of nn.Linear, called MaskedLinear, that is supposed to set some weights to 0 and keep the others. The input is a matrix of shape (1024, 1024), and the mask is a matrix of the same size filled with ones and zeros (the first n columns are all 1 and the remaining columns are all 0).
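For concreteness, a mask of the shape described could be built like this (n is a placeholder for the actual cutoff column, not a value from the original setup):

```python
import torch

n = 256  # hypothetical cutoff: the first n columns are kept
mask = torch.zeros(1024, 1024)
mask[:, :n] = 1.0  # first n columns are ones, the remaining columns stay zero
```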
Here is my implementation so far:
class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, input, masks):
        # `input` is of shape (batch_size, 1024*1024)
        masks = masks.view(-1, 1, 1024*1024)
        # Expand the mask to enable multiplication with self.weight
        masks = masks.expand(-1, self.weight.size(0), 1024*1024)
        print("input ", input.shape)
        print("masks ", masks.shape)
        print("self.weight ", self.weight.shape)
        return F.linear(input, masks * self.weight, self.bias)
And here is the error I get:
Cell In [36], line 14, in MaskedLinear.forward(self, input, masks)
12 print("masks ", masks.shape)
13 print("self.weight ", self.weight.shape)
---> 14 return F.linear(input, masks * self.weight, self.bias)
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
The corresponding layer is:
MaskedLinear(1024*1024, 128)
When I print the shapes of input, masks and self.weight, I get:
input torch.Size([8, 1048576])
masks torch.Size([8, 128, 1048576])
self.weight torch.Size([128, 1048576])
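A small reproduction (with a tiny stand-in for the real in_features, so the sizes here are illustrative only) shows that the broadcasting in masks * self.weight produces a 3-D tensor, which is what F.linear is then given as the weight:

```python
import torch

batch, out_f, in_f = 8, 128, 16          # tiny in_f stand-in for 1024*1024
weight = torch.randn(out_f, in_f)        # same shape as self.weight
masks = torch.ones(batch, 1, in_f).expand(batch, out_f, in_f)
masked = masks * weight                  # weight broadcasts against masks
print(masked.shape)                      # -> a 3-D tensor per batch element
```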
First, I don't understand why the batch dimension is not handled.
Second, is this the right way to mask the weights?