Reducing memory cost during `masked_fill`

Hi! I keep hitting OOM at this line of MultiheadAttention when training a Transformer:

attn_weights = attn_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)

Why would this happen? attn_weights and key_padding_mask already exist at this point. If this line really is allocating a lot of memory, how should I reduce the cost? Also, would it improve memory usage if I added a bunch of del statements and torch.cuda.empty_cache() calls inside MultiheadAttention?
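
For concreteness, here is a small self-contained sketch of the two variants I am comparing (the shapes and the masked_fill_ idea are my own guesses, not code taken from MultiheadAttention):

import torch

# Made-up shapes, just to illustrate the comparison.
bsz, num_heads, tgt_len, src_len = 32, 8, 512, 512
attn_weights = torch.randn(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)

# Out-of-place: masked_fill returns a brand-new tensor the same size as
# attn_weights, so there are briefly two full copies of the attention weights.
masked = attn_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)

# In-place alternative I am wondering about: masked_fill_ writes into
# attn_weights directly, so no second copy is allocated.
attn_weights.masked_fill_(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)

Would switching to the in-place version like this actually help, or does autograd need the extra copy anyway?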