I am trying to build my own subclass of `nn.Linear`, called `MaskedLinear`, that is supposed to zero out some weights and keep the others. The input is a matrix of shape (1024, 1024), and the mask is a matrix of the same size filled with ones and zeros (the first n columns of the matrix are 1 and the remaining columns are 0).
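For reference, the mask is built along these lines (`n = 512` is just a placeholder value here; in my actual code it varies):

```python
import torch

n = 512  # placeholder cutoff for this example
mask = torch.zeros(1024, 1024)
mask[:, :n] = 1.0  # first n columns are 1, the rest stay 0
```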

Here is my implementation so far:

```python
class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, input, masks):
        # `input` is of shape (batch_size, 1024*1024)
        masks = masks.view(-1, 1, 1024 * 1024)
        # Expand the mask to enable multiplication with self.weight
        masks = masks.expand(-1, self.weight.size(0), 1024 * 1024)
        print("input ", input.shape)
        print("masks ", masks.shape)
        print("self.weight ", self.weight.shape)
        return F.linear(input, masks * self.weight, self.bias)
```

And here is the error I get:

```
Cell In [36], line 14, in MaskedLinear.forward(self, input, masks)
12 print("masks ", masks.shape)
13 print("self.weight ", self.weight.shape)
---> 14 return F.linear(input, masks * self.weight, self.bias)
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
```

The corresponding layer is created with:

```
MaskedLinear(1024*1024, 128)
```

When I print the shapes of `input`, `masks`, and `self.weight`, I get:

```
input torch.Size([8, 1048576])
masks torch.Size([8, 128, 1048576])
self.weight torch.Size([128, 1048576])
```
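For what it's worth, the error seems reproducible outside my class by passing any 3D weight to `F.linear` (the shapes below are made up for the example, much smaller than my real ones):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)       # a 2D input batch
w3d = torch.randn(8, 4, 16)  # a 3D "weight", like masks * self.weight above
try:
    F.linear(x, w3d)
except RuntimeError as e:
    print(e)  # prints the RuntimeError message
```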

First, I don't understand why the batch dimension is not handled here.

Second, is this the right way to mask the weights?