Speeding up mask creation

Hey everybody!

I have a newbie question: I am creating an attention mask for an NLP project. I essentially need a square (Words x Words) mask for each document (all documents have the same number of words), but with a twist: I want a separate mask for each of the n word categories. So what I need is a mask of shape (Batch x Categories x Words x Words) that, for each category, efficiently selects only the words belonging to that category. I have a minimal solution, but it uses two nested loops, which is very slow, and I wondered whether someone more versed in tensor broadcasting could help me out.
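
Stated elementwise, the loop below builds an outer product of per-category membership indicators, which is what makes a broadcasting solution possible. The notation is my own: $c_{b,i}$ is the category of word $i$ in document $b$, $k$ is a category index, and $[\cdot]$ is 1 when the condition holds and 0 otherwise.

```latex
m_{b,k,i} = [\, c_{b,i} = k \,], \qquad
\mathrm{mask}_{b,k,i,j} = m_{b,k,i} \, m_{b,k,j}
```

For example, for document 0 and category 1, $m_{0,1,\cdot} = (1, 0, 1, 0)$, so the outer product has ones exactly at $(0,0)$, $(0,2)$, $(2,0)$ and $(2,2)$, which is the second block of the desired output.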

Here is a minimal example with 2 documents, each containing 4 words that belong to one of 3 categories.

```python
import torch

# Category of each word, shape (Batch x Words).
# E.g. the first word of document 0 belongs to category 1.
categories = torch.tensor([
    [1, 0, 1, 2],
    [0, 2, 0, 0],
], dtype=torch.long)

# Values to be masked: Batch x Categories x Words x Words
values = torch.rand(2, 3, 4, 4)

# Mask: Batch x Categories x Words x Words
mask = torch.zeros(2, 3, 4, 4)
num_documents = mask.size(0)
num_categories = mask.size(1)

for document in range(num_documents):
    for category in range(num_categories):
        # Boolean vector marking the words of this document that are in this category.
        token_in_category = categories[document] == category
        # In the rows of those words, set exactly the columns of those words to 1
        # (the boolean vector is broadcast across every selected row).
        mask[document, category, token_in_category] = token_in_category.float()

print(mask)
```
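
A loop-free alternative, sketched under the assumption that the category ids are the integers 0 to num_categories - 1 (not benchmarked): compare each word's category with every category id via broadcasting to get a (Batch x Categories x Words) membership tensor, then take the outer product of that tensor with itself along the word axis. The function name `build_category_mask` is a hypothetical helper, not a PyTorch API.

```python
import torch


def build_category_mask(categories: torch.Tensor, num_categories: int) -> torch.Tensor:
    """Return a (Batch, Categories, Words, Words) float mask without Python loops.

    mask[b, k, i, j] == 1 iff words i and j of document b both belong to category k.
    """
    # (1, C, 1): category ids, shaped so they broadcast against the word axis.
    category_ids = torch.arange(num_categories, device=categories.device).view(1, -1, 1)
    # (B, 1, W) == (1, C, 1) -> (B, C, W): is word i of document b in category k?
    in_category = categories.unsqueeze(1) == category_ids
    # Per (document, category) outer product: (B, C, W, 1) & (B, C, 1, W) -> (B, C, W, W).
    mask = in_category.unsqueeze(-1) & in_category.unsqueeze(-2)
    return mask.float()


categories = torch.tensor([
    [1, 0, 1, 2],
    [0, 2, 0, 0],
])
mask = build_category_mask(categories, num_categories=3)
print(mask.shape)  # torch.Size([2, 3, 4, 4])
```

The membership tensor could equally be obtained with `torch.nn.functional.one_hot(categories, num_classes=3).permute(0, 2, 1).bool()`. Dropping the final `.float()` and keeping the mask as `torch.bool` uses a quarter of the memory of float32, which may matter because the mask grows with Categories x Words².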

Desired output:

```
tensor([[[[0., 0., 0., 0.],
          [0., 1., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[1., 0., 1., 0.],
          [0., 0., 0., 0.],
          [1., 0., 1., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 1.]]],


        [[[1., 0., 1., 1.],
          [0., 0., 0., 0.],
          [1., 0., 1., 1.],
          [1., 0., 1., 1.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 1., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
```
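
The `values` tensor in the example is never used by the loop. Assuming it stands in for attention scores, here is a hedged sketch of two common ways to apply the mask: multiplying to zero out entries outside each category, or filling them with `-inf` before a softmax. The helper `apply_category_mask` is hypothetical. Note that with the `-inf` approach, any row with no allowed entry (a word outside the category, or a category with no words such as category 1 of document 1) becomes all `-inf`, which softmax turns into NaN, so the sketch replaces those with 0.

```python
import torch


def apply_category_mask(values: torch.Tensor, mask: torch.Tensor):
    """Hypothetical helper: apply a (B, C, W, W) 0/1 mask to same-shaped scores."""
    keep = mask.bool()
    # Option 1: zero out every score whose word pair is not in the category.
    zeroed = values * keep
    # Option 2: attention-style masking. Masked logits become -inf so softmax gives them 0;
    # fully masked rows come out as NaN and are replaced with 0.
    probs = torch.softmax(values.masked_fill(~keep, float("-inf")), dim=-1).nan_to_num(0.0)
    return zeroed, probs


# Build the mask from the example with broadcasting (same idea as the loop).
categories = torch.tensor([[1, 0, 1, 2], [0, 2, 0, 0]])
in_category = categories.unsqueeze(1) == torch.arange(3).view(1, -1, 1)
mask = (in_category.unsqueeze(-1) & in_category.unsqueeze(-2)).float()

values = torch.rand(2, 3, 4, 4)
zeroed, probs = apply_category_mask(values, mask)
```

Which option fits depends on how the scores are consumed downstream; for a softmax over keys, the second keeps each allowed row summing to 1.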

Thanks in advance, appreciate any help :blush: