Hi! I keep running into OOM at this line of MultiheadAttention when training a Transformer:
```python
attn_weights = attn_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
    float("-inf"),
)
```
Why would this happen?
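For context, here is a minimal sketch of the shapes at that point; the sizes are assumptions standing in for my real config, not measurements:

```python
import torch

# Hypothetical sizes -- assumptions, not my actual run.
bsz, num_heads, tgt_len, src_len = 64, 16, 1024, 1024

# At the masked_fill, attn_weights has been reshaped to
# (bsz, num_heads, tgt_len, src_len) so that key_padding_mask,
# shape (bsz, src_len), broadcasts after the two unsqueeze calls.
# device="meta" allocates no memory; this only does the arithmetic.
attn_weights = torch.empty(bsz, num_heads, tgt_len, src_len, device="meta")

gib = attn_weights.numel() * attn_weights.element_size() / 1024**3
print(f"one copy of attn_weights: {gib:.1f} GiB")  # ~4 GiB at fp32

# masked_fill (no trailing underscore) is out-of-place: it returns a
# fresh tensor of the same size, so this one line briefly holds two
# such copies -- roughly 8 GiB here, per layer, per forward pass.
```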
The key_padding_mask tensors are all already provided. If this line is really what is consuming so much memory, how should I reduce the cost? One thing I was considering is doing the fill in place, as sketched below. Also, would it improve the memory performance if I use a bunch of
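Here is the in-place variant I had in mind (just a sketch with a hypothetical helper name, not verified inside MultiheadAttention; masked_fill_ is only legal when autograd does not need the pre-mask values):

```python
import torch

def apply_key_padding_mask_(attn_weights, key_padding_mask):
    # attn_weights: (bsz, num_heads, tgt_len, src_len)
    # key_padding_mask: (bsz, src_len), True at padding positions.
    # masked_fill_ writes into attn_weights instead of allocating a
    # second full-size tensor; autograd raises a version-counter error
    # if the pre-mask values were saved for backward, so this would
    # need checking in the real model.
    return attn_weights.masked_fill_(
        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
        float("-inf"),
    )
```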