I wanted to implement a mask layer with a minimal change to the code, and my original solution was as follows.
```python
import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self, embed_dim, num_features):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = num_features
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
        self.mask = None

    def forward(self, x, mask=None):
        assert self.mask is not None or mask is not None, \
            "'mask' must be set before forward is called or must be passed to forward"
        if mask is None:
            mask = self.mask
        assert mask.shape[0] == x.shape[0], "batch_size is not equal"
        assert mask.shape[1] == self.num_features, "num_features is not equal"
        assert ((mask == 0) | (mask == 1)).all(), "mask is not binary"
        # replace every masked-out position with the learnable mask token
        x[mask == 0] = self.mask_token
        return x

# model initialization
model = Bert()
# insert the mask layer in front of the module of interest
model.embedder = nn.Sequential(Mask(1500, 200), model.embedder)
# during the forward pass, set the mask before calling the model's forward;
# note the [0] index, which reaches the Mask layer inside the Sequential
model.embedder[0].mask = torch.ones(1500, 200)
y = model(x)
```
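For reference, this is how I exercise the layer on its own (a standalone sketch; the batch size of 4 and the `(batch, num_features, embed_dim)` layout of `x` are my assumptions about the intended shapes):

```python
import torch

layer = Mask(embed_dim=1500, num_features=200)
x = torch.randn(4, 200, 1500)         # (batch, num_features, embed_dim)
mask = torch.randint(0, 2, (4, 200))  # binary mask, (batch, num_features)
out = layer(x, mask=mask)
# every masked-out position is replaced by the (initially zero) mask token
assert (out[mask == 0] == layer.mask_token).all()
```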
However, this solution becomes problematic when the module is run with data parallelism. Because the input to each replica is split along the batch dimension, the batch dimensions of the mask and `x` no longer match: the batch dimension of `x` shrinks to the per-replica chunk size, while the batch dimension of the stored mask stays the full size. Do you have any other ideas for solving this issue?
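To make the failure concrete, here is a minimal sketch of the situation as I understand it (using `nn.DataParallel` as the parallelism wrapper is an assumption on my part, and the batch size of 8 and the `nn.Identity` stand-in for the embedder are made up for illustration):

```python
import torch
import torch.nn as nn

wrapped = nn.Sequential(Mask(1500, 200), nn.Identity())  # Identity stands in for the embedder
wrapped[0].mask = torch.ones(8, 200)  # mask for the full batch, stored as module state
parallel = nn.DataParallel(wrapped)   # assume 2 GPUs are visible

x = torch.randn(8, 200, 1500).cuda()  # (batch, num_features, embed_dim)
# DataParallel scatters x into per-GPU chunks of batch size 4, but the stored
# mask is carried over to each replica whole, so inside forward
# mask.shape[0] == 8 while x.shape[0] == 4 and the batch-size assert fails.
y = parallel(x)
```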