Faster block diagonal attention mask


I am trying to implement a block diagonal attention mask. My input is the list of block sizes, and the sizes differ from sample to sample within the batch. The blocks are usually small (size between 1 and 10). My current solution loops through the batch to create the block diagonal mask for each sample, then stacks them:

import torch

all_sizes = [[2, 0], [1, 1]]

all_masks = []
for sizes in all_sizes:
    mask = torch.block_diag(*[torch.ones((s, s)) for s in sizes])
    all_masks.append(mask)

attention_mask = torch.stack(all_masks, dim=0)
# Output:
# tensor([[[1., 1.],
#          [1., 1.]],
#         [[1., 0.],
#          [0., 1.]]])

However, since my batch size may get large, this results in many calls to torch.block_diag, which takes too much time. I have tried a few alternative solutions, but none was faster than this one.

Do you have any idea how I could rewrite this operation to speed it up? In particular, given that I am not creating an arbitrary block diagonal tensor, but an attention mask whose diagonal blocks contain only ones?
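For concreteness, here is one vectorized direction I have been considering (only a sketch, not benchmarked): assign each position the index of the block it belongs to, then build the mask by comparing labels pairwise with broadcasting. It still loops once per sample to build the labels, but replaces every torch.block_diag call with a cheap torch.repeat_interleave.

```python
import torch

all_sizes = [[2, 0], [1, 1]]

# One block label per position: position i of sample b belongs to block labels[b, i].
# Assumes all samples pad out to the same total length (true here: 2 + 0 = 1 + 1).
labels = torch.stack([
    torch.repeat_interleave(torch.arange(len(sizes)), torch.tensor(sizes))
    for sizes in all_sizes
])  # shape (batch, seq_len)

# Two positions attend to each other iff they share a block label.
attention_mask = (labels[:, :, None] == labels[:, None, :]).float()
```

This produces the same (2, 2, 2) mask as the loop above, but I don't know if the label comparison is actually the fastest way to exploit the structure.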

Thanks a lot!