Implementing a mask layer with minimal code changes

I wanted to implement a mask layer with minimal code changes, and my original solution was as follows.

import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self, embed_dim, num_features):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = num_features

        # learnable token written into the masked-out positions
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
        # may be set externally before the forward pass
        self.mask = None

    def forward(self, x, mask=None):
        if mask is None:
            mask = self.mask
        assert mask is not None, "'mask' must be set before forward is called or must be passed to forward"
        assert mask.shape[0] == x.shape[0], "batch_size is not equal"
        assert mask.shape[1] == self.num_features, "num_features is not equal"
        assert ((mask == 0) | (mask == 1)).all(), "mask is not binary"
        # replace features where mask == 0 with the mask token
        x[mask == 0] = self.mask_token
        return x

# model initialization:
# insert the mask layer in front of the module of interest
model.embedder = nn.Sequential(Mask(1500, 200), model.embedder)
# before calling the model's `forward`, set the mask to use,
# e.g. model.embedder[0].mask = mask
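
For concreteness, a minimal usage sketch (the input shape and the `model.embedder` attribute are assumptions carried over from the snippet above):

x = torch.randn(8, 200, 1500)  # (batch, num_features, embed_dim)
# binary mask: 1 keeps a feature, 0 replaces it with the mask token
model.embedder[0].mask = torch.randint(0, 2, (8, 200))
out = model(x)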

However, this solution becomes problematic when the module is run with data parallelism. Since each input to the module is split along the batch dimension, the batch dimensions of mask and x no longer match: the batch dimension of x shrinks to the per-device chunk size, while the batch dimension of the stored mask stays the same, as the sketch below illustrates. Do you have any other ideas for solving this issue?
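
For reference, a minimal sketch of the failure mode, assuming `nn.DataParallel` over two GPUs (devices and shapes are made up). Plain attributes such as `self.mask` are carried over whole to each replica, while tensors passed to `forward` are scattered along the batch dimension:

mask_layer = Mask(1500, 200)
parallel = nn.DataParallel(mask_layer, device_ids=[0, 1])

x = torch.randn(8, 200, 1500).cuda()         # scattered into chunks of 4 per GPU
mask_layer.mask = torch.ones(8, 200).cuda()  # each replica sees the full batch of 8

# inside each replica, x.shape[0] == 4 but mask.shape[0] == 8,
# so the "batch_size is not equal" assert fires
out = parallel(x)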