Hi! I keep running into OOM at this line of MultiheadAttention when training a Transformer:
```python
attn_weights = attn_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
    float("-inf"),
)
```
Why would this happen?
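For context, here is a minimal sketch of the shapes at that point; the sizes are assumptions standing in for my real config, not measurements:

```python
import torch

# Hypothetical sizes -- assumptions, not my actual run.
bsz, num_heads, tgt_len, src_len = 64, 16, 1024, 1024

# At the masked_fill, attn_weights has been reshaped to
# (bsz, num_heads, tgt_len, src_len) so that key_padding_mask,
# shape (bsz, src_len), broadcasts after the two unsqueeze calls.
# device="meta" allocates no memory; this only does the arithmetic.
attn_weights = torch.empty(bsz, num_heads, tgt_len, src_len, device="meta")

gib = attn_weights.numel() * attn_weights.element_size() / 1024**3
print(f"one copy of attn_weights: {gib:.1f} GiB")  # ~4 GiB at fp32

# masked_fill (no trailing underscore) is out-of-place: it returns a
# fresh tensor of the same size, so this one line briefly holds two
# such copies -- roughly 8 GiB here, per layer, per forward pass.
```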
The key_padding_mask tensors are all already provided. If this line is really what is consuming so much memory, how should I reduce the cost? One thing I was considering is doing the fill in place, as sketched below. Also, would it improve the memory performance if I use a bunch of
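Here is the in-place variant I had in mind (just a sketch with a hypothetical helper name, not verified inside MultiheadAttention; masked_fill_ is only legal when autograd does not need the pre-mask values):

```python
import torch

def apply_key_padding_mask_(attn_weights, key_padding_mask):
    # attn_weights: (bsz, num_heads, tgt_len, src_len)
    # key_padding_mask: (bsz, src_len), True at padding positions.
    # masked_fill_ writes into attn_weights instead of allocating a
    # second full-size tensor; autograd raises a version-counter error
    # if the pre-mask values were saved for backward, so this would
    # need checking in the real model.
    return attn_weights.masked_fill_(
        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
        float("-inf"),
    )
```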