How to mask a sequence with randomly placed blocks of random length?

I’m looking to mask the input to a sequence model with independently, randomly placed blocks of random length. Here is a function prototype with a naive reference implementation of what I want:

import random

def mask(x, prob, max_length, batch_dim=0, seq_dim=1, mask_value=0):
    """Returns a new tensor like x, with randomly placed blocks set to mask_value.

    For each batch index i and each sequence position j, with probability prob
    draw a length l uniformly from {1, ..., max_length} and set every element
    whose index is i along batch_dim and lies in [j, j+l) along seq_dim to
    mask_value.
    """
    ret = x.clone()
    for i in range(x.size(batch_dim)):
        for j in range(x.size(seq_dim)):
            if random.random() < prob:
                l = random.randint(1, max_length)      # inclusive on both ends
                idx = [slice(None)] * ret.dim()
                idx[batch_dim] = i                     # fix the batch index
                idx[seq_dim] = slice(j, j + l)         # block along the sequence
                ret[tuple(idx)] = mask_value
    return ret
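For concreteness, I’d call it on a (batch, time, features) tensor like this (the shapes and numbers are just illustrative):

import torch

x = torch.randn(4, 100, 16)              # (batch, time, features)
y = mask(x, prob=0.02, max_length=10)    # same shape, some time blocks zeroed out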

I am looking for a way to do this without explicitly looping over the tensor as above, which is inelegant at best and could cause a speed issue for the model (I haven’t profiled it, though).

I’m also open to suggestions of variant masking strategies that might work equally well but be easier to implement if this one is hard.

Your use case sounds similar to torchvision.transforms.RandomErasing, which randomly selects a rectangular region in an image and erases its pixels, so maybe you could reuse their approach for your temporal data.
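For reference, the basic usage looks something like this; it operates on a single tensor image of shape (C, H, W), and the parameter values here are just illustrative:

import torch
from torchvision import transforms

# With probability p, erase one random rectangle whose area and aspect ratio
# fall in the given ranges, filling it with `value`.
erase = transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0)

img = torch.rand(3, 64, 64)   # dummy (C, H, W) image tensor
erased = erase(img)           # same shape, possibly with one rectangle zeroed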

Thanks. It looks like this code masks a single rectangle in a single image, which is easy to do.

More specifically:

Looping over the whole sequence to choose block start points would probably be fine if it were done in C++, but it is likely slow in Python, so I thought there might be a vectorized alternative. Another approach would be to draw the number of blocks from a binomial distribution, sample that many start indices (I'm not sure which function does this), and then loop only over those indices; that is probably faster, but it would still be a Python loop. In either case, the actual masking can be done with array index notation as long as the batch and time dimensions come first (which can be arranged by transposing, and that should be cheap).
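One fully vectorized possibility that occurs to me (untested sketch, and it only handles the default batch_dim=0, seq_dim=1 layout): sample Bernoulli block starts and candidate lengths for every position, turn each start into its block end position, and take a running maximum over time with torch.cummax; a position is masked whenever the running maximum of block ends exceeds it. The name mask_blocks is just something I made up for the sketch.

import torch

def mask_blocks(x, prob, max_length, mask_value=0):
    """Vectorized sketch: x has shape (batch, time, ...)."""
    B, T = x.shape[0], x.shape[1]
    device = x.device

    # Each position independently starts a block with probability prob,
    # with a candidate length drawn uniformly from {1, ..., max_length}.
    starts = torch.rand(B, T, device=device) < prob
    lengths = torch.randint(1, max_length + 1, (B, T), device=device)

    # Position t is masked iff some start j <= t has j + length_j > t.
    # Encode each start as its (exclusive) block end, 0 where there is no start,
    # and take a running maximum over time.
    positions = torch.arange(T, device=device)
    block_ends = torch.where(starts, positions + lengths, torch.zeros_like(lengths))
    running_end = torch.cummax(block_ends, dim=1).values
    masked = running_end > positions                       # (B, T) bool

    # Broadcast over any trailing feature dimensions and fill.
    masked = masked.reshape(B, T, *([1] * (x.dim() - 2)))
    return x.masked_fill(masked, mask_value)

The running maximum also handles overlapping blocks automatically, since only the furthest block end matters at each position.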

The torchvision.transforms code only works on one image at a time; it is meant for data loaders. That might be OK for my case, but ideally the masking would work as a layer.
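Something like this thin wrapper is what I have in mind for the layer version (again just a sketch; it assumes a vectorized mask_blocks like the one above and only applies the mask in training mode):

from torch import nn

class BlockMask(nn.Module):
    """Sketch of a masking layer wrapping the (hypothetical) mask_blocks above."""

    def __init__(self, prob, max_length, mask_value=0):
        super().__init__()
        self.prob = prob
        self.max_length = max_length
        self.mask_value = mask_value

    def forward(self, x):
        # Only mask during training, like dropout.
        if not self.training:
            return x
        return mask_blocks(x, self.prob, self.max_length, self.mask_value)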