Batchnorm for different sized samples in batch

I wrote a fast solution to do this; the approach is explained in the comments in the code. Let me know if you find any bugs.

    def masked_batchnorm1d_forward(x, mask, bn):
        """x is the input tensor of shape [batch_size, n_channels, time_length]
        mask is of shape [batch_size, 1, time_length]
        bn is a BatchNorm1d object
        """
        # In eval mode, BatchNorm1d uses its running statistics, so no masking is needed.
        if not bn.training:
            return bn(x)

        # In each example of the batch, we can have a different number of masked elements
        # along the time axis. It would have to be represented as a jagged array.
        # However, notice that the batch and time axes are handled the same way in BatchNorm1d.
        # This means we can merge the time axis into the batch axis and feed BatchNorm1d
        # with a tensor of shape [n_valid_timesteps_in_whole_batch, n_channels, 1],
        # as if the time axis had length 1.
        #
        # So the plan is:
        #  1. Move the time axis next to the batch axis, into second place
        #  2. Merge the batch and time axes using reshape
        #  3. Create a dummy time axis at the end with size 1.
        #  4. Select the valid time steps using the mask
        #  5. Apply BatchNorm1d to the valid time steps
        #  6. Scatter the resulting values to the corresponding positions
        #  7. Unmerge the batch and time axes
        #  8. Move the time axis to the end 

        n_feature_channels = x.shape[1]
        time_length = x.shape[2]
        reshaped = x.permute(0, 2, 1).reshape(-1, n_feature_channels, 1)
        reshaped_mask = mask.reshape(-1, 1, 1) > 0
        selected = torch.masked_select(reshaped, reshaped_mask).reshape(-1, n_feature_channels, 1)
        batchnormed = bn(selected)
        scattered = reshaped.masked_scatter(reshaped_mask, batchnormed)
        backshaped = scattered.reshape(-1, time_length, n_feature_channels).permute(0, 2, 1)
        return backshaped
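Here is a self-contained sketch of how this could be used on a padded batch (the sequence lengths, shapes, and mask construction below are made up for illustration; the function body is the one from above):

```python
import torch
import torch.nn as nn

def masked_batchnorm1d_forward(x, mask, bn):
    """Apply bn only to the timesteps where mask > 0; padded positions pass through."""
    if not bn.training:
        return bn(x)
    n_feature_channels = x.shape[1]
    time_length = x.shape[2]
    reshaped = x.permute(0, 2, 1).reshape(-1, n_feature_channels, 1)
    reshaped_mask = mask.reshape(-1, 1, 1) > 0
    selected = torch.masked_select(reshaped, reshaped_mask).reshape(-1, n_feature_channels, 1)
    batchnormed = bn(selected)
    scattered = reshaped.masked_scatter(reshaped_mask, batchnormed)
    return scattered.reshape(-1, time_length, n_feature_channels).permute(0, 2, 1)

# Hypothetical batch: 3 sequences of lengths 5, 3, 2, zero-padded to time_length 5.
batch_size, n_channels, time_length = 3, 4, 5
torch.manual_seed(0)
x = torch.randn(batch_size, n_channels, time_length)
lengths = torch.tensor([5, 3, 2])
# mask[b, 0, t] is 1 for valid timesteps and 0 for padding
mask = (torch.arange(time_length)[None, None, :] < lengths[:, None, None]).float()

bn = nn.BatchNorm1d(n_channels)
bn.train()
out = masked_batchnorm1d_forward(x, mask, bn)
# out has the same shape as x; the statistics are computed over the
# 5 + 3 + 2 = 10 valid timesteps only, and padded positions keep their input values.
```

Note that padded positions are left untouched (they keep the values from `x`), so if downstream layers are not themselves masked you may want to zero them out afterwards, e.g. `out * mask`.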