nn.Transformer explanation

If you switch to the transformer encoder and use the triangular src_mask, you should be able to predict the next word, just like in this example.
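Something along these lines (a minimal sketch of the idea; the vocabulary size, dimensions, and layer counts are placeholders, not taken from the example above):

```python
import torch
import torch.nn as nn

vocab_size, d_model, nhead, num_layers = 1000, 512, 8, 2
seq_len, batch_size = 10, 4

embed = nn.Embedding(vocab_size, d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
lm_head = nn.Linear(d_model, vocab_size)

src = torch.randint(0, vocab_size, (seq_len, batch_size))  # (S, N) token ids

# Triangular (causal) mask: -inf above the diagonal, so position i
# can only attend to positions <= i
src_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

hidden = encoder(embed(src), mask=src_mask)  # (S, N, d_model)
logits = lm_head(hidden)                     # (S, N, vocab_size)
# logits[i] is the prediction for the token at position i + 1
```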

With just an encoder, wouldn’t the size of the output be limited to the size of src? That is, if I have a sentence "the cow jumped over the moon" (28 characters), then the maximum length of the predicted title is 28 characters. But with the encoder-decoder, the sentence can be any length and the output can be any length, which is what I want.

Thanks for your helpful comments here! I am very grateful to see someone who knows what to do. I would like to apply left-to-right causal attention so that I get a representation for each time point in my time series that I can use to make predictions. Do you know of any successful examples of applying this left-to-right causal attention?

For the transformer encoder, the output sequence has the same size as the input (a.k.a. src).
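You can see that directly from the shapes (random tensors, purely illustrative):

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(layer, num_layers=2)

src = torch.rand(28, 4, 512)  # (S=28, N=4, E=512)
out = encoder(src)
print(out.shape)              # torch.Size([28, 4, 512]) -- same length as src
```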

Do you think the word language modeling task supports this?

That means the attention mask is still not set up properly. Could you send a code snippet?

I’m still confused by tgt_key_padding_mask. As described in the doc, tgt_key_padding_mask’s dimension is (N, T). In translation tasks, the decoder inputs need to mask both future words and pads. However, it seems that tgt_key_padding_mask can only mask pads. Does nn.TransformerDecoder mask the future words in the source code (I haven’t found that)? Maybe tgt_mask (T, T) can mask the future words, but it doesn’t work when the input is a batch.

To mask future tokens, you should use src_mask, tgt_mask, and memory_mask.
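For example, something like this (a rough sketch with made-up sizes): tgt_mask of shape (T, T) hides future positions and is broadcast over every sequence in the batch, while tgt_key_padding_mask of shape (N, T) hides the pad positions per sequence.

```python
import torch
import torch.nn as nn

S, T, N, E = 10, 7, 4, 512
model = nn.Transformer(d_model=E, nhead=8)

src = torch.rand(S, N, E)  # (S, N, E)
tgt = torch.rand(T, N, E)  # (T, N, E)

# Causal mask for the decoder: -inf above the diagonal, 0 elsewhere
tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

# Padding mask: True marks positions the attention should ignore
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
tgt_key_padding_mask[:, -2:] = True  # pretend the last two tgt tokens are pads

out = model(src, tgt,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask)
print(out.shape)  # torch.Size([7, 4, 512])
```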

Hi, I do not understand why both src and tgt are required for nn.Transformer.

Let’s say for the machine translation use case, I understand that during training, src and tgt are two different languages. But during testing, given an input, we predict an output and do not have tgt. If so, what should the tgt input be? The start-of-sentence token (e.g. <sos>)?


Yes. The tgt input will be, as you rightly said, <sos>.
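At inference you then feed the model’s own predictions back in, one token at a time. A rough sketch of greedy decoding (the model wrapper, sos_idx, and eos_idx are hypothetical names, not from this thread):

```python
import torch

def greedy_decode(model, src, sos_idx, eos_idx, max_len=50):
    """Assumes `model` wraps nn.Transformer and returns per-position logits."""
    tgt_ids = torch.tensor([[sos_idx]])              # (T=1, N=1): just <sos>
    for _ in range(max_len):
        logits = model(src, tgt_ids)                 # (T, 1, vocab_size)
        next_id = logits[-1, 0].argmax().view(1, 1)  # most likely next token
        tgt_ids = torch.cat([tgt_ids, next_id], dim=0)
        if next_id.item() == eos_idx:
            break
    return tgt_ids.squeeze(1)  # generated ids, starting with <sos>
```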

Thanks, but src_mask only works when the input is a single sequence, not a batch.