Memory_mask in nn.Transformer

I’m implementing training code for a Transformer model using nn.Transformer.

In the documentation, there is an optional memory_mask argument. I’ve read the docs, but I don’t understand the purpose of this argument.

Could you explain what memory_mask is?

Additionally, is there any example code that uses the nn.Transformer module?


It’s an attention mask applied to the second input of the transformer decoder layer. Within the encoder-decoder architecture, it works on the output of the transformer encoder, which we call the “memory”.
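For example, here is a minimal sketch (the sizes are arbitrary) of passing memory_mask to nn.Transformer: it must have shape (T, S), i.e. (target length, source length), with 0.0 where attention is allowed and float('-inf') where it is blocked:

import torch
import torch.nn as nn

S, T, N, E = 10, 7, 4, 512                 # source len, target len, batch size, embed dim
model = nn.Transformer(d_model=E, nhead=8)

src = torch.rand(S, N, E)                  # encoder input
tgt = torch.rand(T, N, E)                  # decoder input
memory_mask = torch.zeros(T, S)            # all zeros: every memory position may be attended to

out = model(src, tgt, memory_mask=memory_mask)   # out: (T, N, E)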

Then why is the shape of memory_mask (T, S)?


Because the memory_mask is applied in the decoder’s multihead_attn (cross-attention) layer: the queries come from the target sequence (length T) and the keys/values come from the encoder memory (length S), so the attention scores, and hence the mask, have shape (T, S).
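As a rough, standalone sketch (using nn.MultiheadAttention directly rather than the decoder layer’s internals; sizes are arbitrary), the attention-weight matrix of that cross-attention is (T, S), which is why the mask has to match:

import torch
import torch.nn as nn

S, T, N, E = 10, 7, 4, 512
cross_attn = nn.MultiheadAttention(embed_dim=E, num_heads=8)

tgt = torch.rand(T, N, E)       # queries come from the target sequence
memory = torch.rand(S, N, E)    # keys/values come from the encoder output ("memory")
mask = torch.zeros(T, S)        # must match the (T, S) attention score matrix

out, weights = cross_attn(tgt, memory, memory, attn_mask=mask)
print(weights.shape)            # torch.Size([4, 7, 10]), i.e. (N, T, S)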

I am wondering how to generate the memory mask. The generate_square_subsequent_mask function can only generate square masks, but memory_mask requires shape (T, S). Is there a built-in function in nn.Transformer for this? Thank you!


Maybe change it to something like this:

import torch

def _generate_subsequent_mask(tgt_sz, src_sz):
    # Rectangular analogue of generate_square_subsequent_mask: returns a (tgt_sz, src_sz)
    # float mask with 0.0 where attention is allowed and -inf where it is blocked.
    mask = (torch.triu(torch.ones(src_sz, tgt_sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
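For example (a hedged sketch assuming the helper above; the sizes are arbitrary), the rectangular mask can then be passed as memory_mask next to the usual causal tgt_mask:

import torch
import torch.nn as nn

S, T, N, E = 10, 7, 4, 512
model = nn.Transformer(d_model=E, nhead=8)
src, tgt = torch.rand(S, N, E), torch.rand(T, N, E)

tgt_mask = model.generate_square_subsequent_mask(T)   # (T, T) causal mask for self-attention
memory_mask = _generate_subsequent_mask(T, S)         # (T, S) mask from the helper above

out = model(src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask)

That said, in most encoder-decoder setups you can simply leave memory_mask as None, since the decoder is normally allowed to attend to the whole encoder output.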