How to pass a different mask for each batch element in nn.TransformerEncoder's forward function

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x.permute(1, 0, 2)                    # -> [seq_len, batch_size, embedding_dim]
        x = x + self.pe[:x.size(0)]               # add positional encoding along the sequence dimension
        return self.dropout(x).permute(1, 0, 2)   # back to [batch_size, seq_len, embedding_dim]
    
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.pos = PositionalEncoding(256, max_len=1000)  # max sequence length
        self.cls1 = nn.Parameter(torch.rand(1, 256))      # (1, enc_dim)
        self.cls2 = nn.Parameter(torch.rand(1, 256))

    def forward(self, x):
        x = self.pos(x)
        # M is built by some logic; it needs to have shape (4, 9, 9) as described below
        M = ...
        x = self.transformer_encoder(x, mask=M.logical_not(), src_key_padding_mask=None)  # pad.logical_not()
        return x
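
To make the shapes concrete, the model is meant to be called roughly like this (the input below is random and only illustrates the expected shapes; M inside forward() still has to be built before this actually runs):

model = Model()
x = torch.rand(4, 9, 256)   # (batch_size, seq_len, d_model)
out = model(x)              # expected output shape: (4, 9, 256)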

Say the batch size is 4.
For reasons specific to my task, I need to apply a different mask to each input in the batch, i.e. 4 masks for 4 inputs.

One example of such a mask is shown below (assume the sequence length is 9, so each mask is a 9 x 9 matrix; it implements a kind of restrictive self-attention):

tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])
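
For illustration, a mask like the one above could be built from per-sample segment lengths; here it corresponds to segments of length 4 and 3 followed by two padding positions. build_block_mask below is just a hypothetical helper to show the idea, not part of my actual code:

def build_block_mask(segment_lengths, seq_len):
    # 1.0 where a position may attend to another position (same segment), 0.0 elsewhere
    mask = torch.zeros(seq_len, seq_len)
    start = 0
    for length in segment_lengths:
        mask[start:start + length, start:start + length] = 1.0
        start += length
    return mask

# The matrix above: segments of length 4 and 3, the last 2 positions are padding.
# Stacking one such mask per sample gives the (4, 9, 9) tensor I am trying to pass.
M = torch.stack([build_block_mask([4, 3], 9) for _ in range(4)])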

When I try to pass 4 such masks as a single tensor of shape (4, 9, 9), I get the following error:
RuntimeError: The shape of the 3D attn_mask is torch.Size([4, 9, 9]), but should be (32, 9, 9).
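
From the error message it looks like the 3D mask is expected per attention head rather than per sample, i.e. of shape (batch_size * nhead, seq_len, seq_len) = (4 * 8, 9, 9) = (32, 9, 9). The sketch below would produce that shape by repeating each sample's mask for every head, but I am not sure whether this is the intended way to handle per-sample masks:

# inside forward(), this would replace the encoder call above
nhead = 8                                             # must match nhead of the encoder layer
M_bool = M.logical_not()                              # True = attention is blocked
M_per_head = M_bool.repeat_interleave(nhead, dim=0)   # (4, 9, 9) -> (4*8, 9, 9) = (32, 9, 9)
x = self.transformer_encoder(x, mask=M_per_head)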

What should I do to achieve what I am expecting?