Implementing a mask layer with minimal code changes

I wanted to implement a mask layer with minimal code changes, and my original solution was as follows.

import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self, embed_dim, num_features):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = num_features

        # learnable token written into the masked-out positions
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
        # may be set externally before the forward pass
        self.mask = None

    def forward(self, x, mask=None):
        if mask is None:
            mask = self.mask
        assert mask is not None, "'mask' must be set before forward is called or must be passed to forward"
        assert mask.shape[0] == x.shape[0], "batch_size is not equal"
        assert mask.shape[1] == self.num_features, "num_features is not equal"
        assert ((mask == 0) | (mask == 1)).all(), "mask is not binary"
        # replace features where mask == 0 with the mask token
        x[mask == 0] = self.mask_token
        return x

# model initialization:
# insert the mask layer in front of the module of interest
model.embedder = nn.Sequential(Mask(1500, 200), model.embedder)
# before calling the model's `forward`, set the mask to use,
# e.g. model.embedder[0].mask = mask
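
For concreteness, a minimal usage sketch (the input shape and the `model.embedder` attribute are assumptions carried over from the snippet above):

x = torch.randn(8, 200, 1500)  # (batch, num_features, embed_dim)
# binary mask: 1 keeps a feature, 0 replaces it with the mask token
model.embedder[0].mask = torch.randint(0, 2, (8, 200))
out = model(x)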

However, this solution becomes problematic when the module is run with data parallelism. Since each input to the module is split along the batch dimension, the batch dimensions of mask and x no longer match: the batch dimension of x shrinks to the per-device chunk size, while the batch dimension of the stored mask stays the same, as the sketch below illustrates. Do you have any other ideas for solving this issue?
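
For reference, a minimal sketch of the failure mode, assuming `nn.DataParallel` over two GPUs (devices and shapes are made up). Plain attributes such as `self.mask` are carried over whole to each replica, while tensors passed to `forward` are scattered along the batch dimension:

mask_layer = Mask(1500, 200)
parallel = nn.DataParallel(mask_layer, device_ids=[0, 1])

x = torch.randn(8, 200, 1500).cuda()         # scattered into chunks of 4 per GPU
mask_layer.mask = torch.ones(8, 200).cuda()  # each replica sees the full batch of 8

# inside each replica, x.shape[0] == 4 but mask.shape[0] == 8,
# so the "batch_size is not equal" assert fires
out = parallel(x)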