Implementing a mask layer with minimal code changes

I wanted to implement a mask layer with minimal code changes, and my original solution is shown below.

import torch
import torch.nn as nn


class Mask(nn.Module):
    def __init__(self, embed_dim, num_features):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = num_features

        # learnable token that replaces the masked positions
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        # binary mask of shape (batch_size, num_features), set externally before forward
        self.mask = None

    def forward(self, x, mask=None):
        assert self.mask is not None or mask is not None, \
            "'mask' must be set before forward is called or must be passed to forward"
        if mask is None:
            mask = self.mask
        assert mask.shape[0] == x.shape[0], \
            f"batch size mismatch: mask has {mask.shape[0]}, x has {x.shape[0]}"
        assert mask.shape[1] == self.num_features, \
            f"num_features mismatch: mask has {mask.shape[1]}, expected {self.num_features}"
        assert ((mask == 0) | (mask == 1)).all(), "mask is not binary"
        # overwrite positions where mask == 0 with the mask token (in place)
        x[mask == 0] = self.mask_token
        return x
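
For illustration, the layer can also be called on its own with an explicit mask argument; the shapes below are just small placeholder values I picked for the example:

x = torch.randn(4, 8, 16)            # (batch_size, num_features, embed_dim)
mask = torch.randint(0, 2, (4, 8))   # binary: 0 = replace with mask_token
layer = Mask(embed_dim=16, num_features=8)
out = layer(x, mask)                 # same shape as x, zero positions replaced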

# model initialization
model = Bert()
# insert the mask layer in front of the module of interest
model.embedder = nn.Sequential(Mask(1500, 200), model.embedder)
# during the forward pass, set the mask to use before calling `forward` of the model
model.embedder[0].mask = torch.ones(1500, 200)
y = model(x)

However, this solution becomes problematic when the module is run with data parallelism. Because each input to the module is split along the batch dimension, the batch dimensions of mask and x no longer match: the batch dimension of x shrinks to the per-device chunk size, while the batch dimension of the stored mask stays at the full batch size, so the shape assertion fails. Do you have any other ideas for solving this issue?
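
For reference, here is a minimal sketch of how the mismatch shows up; ToyEmbedder and all of the shapes are placeholders I made up for the example, and actually reproducing the error needs a machine with at least two GPUs:

import torch
import torch.nn as nn

class ToyEmbedder(nn.Module):
    # stand-in for model.embedder, just to keep the example self-contained
    def __init__(self, embed_dim):
        super().__init__()
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        return self.proj(x)

embed_dim, num_features, batch_size = 16, 8, 32
embedder = nn.Sequential(Mask(embed_dim, num_features), ToyEmbedder(embed_dim)).cuda()
embedder[0].mask = torch.ones(batch_size, num_features, device="cuda")

dp_embedder = nn.DataParallel(embedder)  # each replica keeps the same full-batch mask
x = torch.randn(batch_size, num_features, embed_dim, device="cuda")
# DataParallel scatters x into chunks of batch_size / n_gpus rows,
# but self.mask still has batch_size rows, so the batch-size assert fails here
y = dp_embedder(x)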