I wrote a solution that does this efficiently; the approach is explained in the comments. Let me know if you find any bugs.
import torch

def masked_batchnorm1d_forward(x, mask, bn):
    """x is the input tensor of shape [batch_size, n_channels, time_length]
    mask is of shape [batch_size, 1, time_length]
    bn is a BatchNorm1d object
    """
    if not bn.training:
        # In eval mode BatchNorm1d uses its running statistics, so the padded
        # positions cannot contaminate them and we can apply it directly.
        return bn(x)
    # In each example of the batch, we can have a different number of masked elements
    # along the time axis, so the valid elements would form a jagged array.
    # However, notice that BatchNorm1d treats the batch and time axes the same:
    # it computes its statistics over both. This means we can merge the time axis
    # into the batch axis and feed BatchNorm1d a tensor of shape
    # [n_valid_timesteps_in_whole_batch, n_channels, 1], as if the time axis had length 1.
    #
    # So the plan is:
    # 1. Move the time axis next to the batch axis (to the second place).
    # 2. Merge the batch and time axes using reshape.
    # 3. Create a dummy time axis of size 1 at the end.
    # 4. Select the valid time steps using the mask.
    # 5. Apply BatchNorm1d to the valid time steps only.
    # 6. Scatter the normalized values back to their original positions
    #    (masked-out positions keep their input values).
    # 7. Unmerge the batch and time axes.
    # 8. Move the time axis back to the end.
    n_feature_channels = x.shape[1]
    time_length = x.shape[2]
    reshaped = x.permute(0, 2, 1).reshape(-1, n_feature_channels, 1)  # steps 1-3
    reshaped_mask = mask.reshape(-1, 1, 1) > 0
    selected = torch.masked_select(reshaped, reshaped_mask).reshape(-1, n_feature_channels, 1)  # step 4
    batchnormed = bn(selected)  # step 5
    scattered = reshaped.masked_scatter(reshaped_mask, batchnormed)  # step 6
    backshaped = scattered.reshape(-1, time_length, n_feature_channels).permute(0, 2, 1)  # steps 7-8
    return backshaped
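As a quick sanity check, here is a small usage sketch (the batch size, channel count, and per-example lengths are made up for illustration; the function is repeated so the snippet runs standalone). It verifies two properties of the approach: masked-out positions keep their original values, because masked_scatter only writes where the mask is True, and over the valid positions each channel comes out with (near-)zero mean, as batch norm should produce in training mode.

```python
import torch
import torch.nn as nn

def masked_batchnorm1d_forward(x, mask, bn):
    # Same function as above, repeated so this snippet is self-contained.
    if not bn.training:
        return bn(x)
    n_feature_channels = x.shape[1]
    time_length = x.shape[2]
    reshaped = x.permute(0, 2, 1).reshape(-1, n_feature_channels, 1)
    reshaped_mask = mask.reshape(-1, 1, 1) > 0
    selected = torch.masked_select(reshaped, reshaped_mask).reshape(-1, n_feature_channels, 1)
    batchnormed = bn(selected)
    scattered = reshaped.masked_scatter(reshaped_mask, batchnormed)
    return scattered.reshape(-1, time_length, n_feature_channels).permute(0, 2, 1)

torch.manual_seed(0)
batch_size, n_channels, time_length = 4, 8, 10
lengths = torch.tensor([10, 7, 5, 3])  # hypothetical valid lengths per example
mask = (torch.arange(time_length)[None, None, :] < lengths[:, None, None]).float()
x = torch.randn(batch_size, n_channels, time_length)
bn = nn.BatchNorm1d(n_channels)  # bn.training is True by default

out = masked_batchnorm1d_forward(x, mask, bn)
print(out.shape)  # same shape as x

# Masked-out (padded) positions pass through unchanged.
valid = mask.bool().expand_as(x)
print(torch.allclose(out[~valid], x[~valid]))

# Per-channel mean over the valid positions is ~0 (gamma=1, beta=0 at init).
per_channel_mean = (out * mask).sum(dim=(0, 2)) / mask.sum()
print(per_channel_mean.abs().max())
```

Note that `bn`'s running statistics are updated only from the valid time steps, which is the whole point of masking before the normalization rather than after.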