Implementing a scriptable channel masking transform for multi-channel input

Hi PyTorch community,

I’m trying to implement a channel masking transform that runs on the GPU for multi-channel spectrogram input. It should look something like this:

import torch


class ChannelMasking(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, specgrams, mask_param):

        # specgrams: 4-D spectrogram (batch, channel, freq, time)
        # mask_param: Number of channels to be masked will be uniformly sampled from [0, mask_param]
        # say batch = 4, channel = 3, freq = 64, time = 16, mask_param = 2

        # Goal: build a (batch, channel) boolean mask. Hard-coded here only to show the target.
        mask = torch.tensor([
            [True, False, True],    # batch 0: spectro-temporal values in channels 0 and 2 become zero
            [False, False, True],   # batch 1: spectro-temporal values in channel 2 become zero
            [False, False, False],  # batch 2: nothing changes
            [False, False, False],  # batch 3: nothing changes
        ], device=specgrams.device)
        mask = mask[..., None, None]  # (batch, channel, 1, 1), broadcasts over freq and time

        specgrams = specgrams.masked_fill(mask, 0)
        return specgrams

I’m stuck on building the mask: it’s tricky to choose the number of channels to mask, and which channels they are, independently for each batch item without using a for-loop.

I tried to build one with torchaudio.functional.mask_along_axis_iid, but it doesn’t seem appropriate: that function always masks adjacent bins (I don’t want to mask consecutive channels every time), and the chance of masking bins at the edges (the first and last channel) is extremely low.
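To make the target behaviour concrete, here is a minimal sketch of one loop-free direction that might work: draw a random score per (batch, channel), turn the scores into per-row ranks with a double argsort, sample the number of channels to mask per batch item uniformly from [0, mask_param], and mask every channel whose rank is below that count. The helper name make_channel_mask is hypothetical (it is not part of torch or torchaudio), and whether this scripts cleanly under torch.jit.script is not verified here.

```python
import torch


def make_channel_mask(batch, channel, mask_param, device=None):
    """Hypothetical helper: returns a (batch, channel) boolean mask."""
    # Number of channels to mask for each batch item, uniform over [0, mask_param].
    num_masked = torch.randint(0, mask_param + 1, (batch, 1), device=device)
    # One random score per channel. Applying argsort twice converts the scores
    # into ranks 0..channel-1, i.e. an independent random permutation per row.
    scores = torch.rand((batch, channel), device=device)
    ranks = scores.argsort(dim=1).argsort(dim=1)
    # A channel is masked when its rank is below the count sampled for its row,
    # so masked channels need not be adjacent and edge channels are equally likely.
    return ranks < num_masked


# Example with batch = 4, channel = 3, freq = 64, time = 16, mask_param = 2.
specgrams = torch.randn(4, 3, 64, 16)
mask = make_channel_mask(4, 3, mask_param=2, device=specgrams.device)
masked = specgrams.masked_fill(mask[:, :, None, None], 0.0)
```

Because the ranks in each row are a random permutation, every channel has the same chance of being picked, and each batch item gets its own count and its own set of channels.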

Thank you in advance!