nn.Transformer explanation

I am having a hard time making the new transformer work. The following code has unexpected (to me) output: the gradients for the model parameters are zero, so the optimizer step is of no use. The documentation for this module is not as explanatory as that for others like RNN. If someone can explain how to make an encoder-decoder Transformer work, that would be great.

code

import torch
print(torch.__version__)
X = torch.tensor([[[95.0]], [[100.0]], [[105.0]], [[110.0]], [[115.0]]])  # src, shape (S=5, N=1, E=1)
y = torch.tensor([[[120.0]]])  # tgt, shape (T=1, N=1, E=1)
print(X.shape, y.shape)
print(X.requires_grad, y.requires_grad)
model = torch.nn.Transformer(d_model=1, nhead=1, dim_feedforward=100, dropout=0)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
parms = [j for j in model.parameters()][:3]
model.train()
optimizer.zero_grad()
y_pred = model(X,y)
print(y_pred)
print(y)
print(y_pred.requires_grad)
print(y_pred._grad)
loss = criterion(y_pred, y)
print(loss)
for i in parms: print(i._grad)
loss.backward()
print(y_pred._grad)
for i in parms: print(i._grad)

output

1.2.0
torch.Size([5, 1, 1]) torch.Size([1, 1, 1])
False False
tensor([[[-5.7748e-11]]], grad_fn=<NativeLayerNormBackward>)
tensor([[[120.]]])
True
None
tensor(14400., grad_fn=<MseLossBackward>)
None
None
None
None
tensor([[0.],
        [0.],
        [0.]])
tensor([0., 0., 0.])
tensor([[0.]])

Expected output: non-zero gradients for the model parameters.
I am trying to train an encoder-decoder to complete a sequence of numbers, for example
input = 95, 100, 105, 110, 115 with the corresponding output = 120.
I don't know much about transformers, which is why I tried this.
Also, the example on the documentation site is not making sense to me:

>>> transformer_model = nn.Transformer(src_vocab, tgt_vocab)
>>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12)

What are src_vocab and tgt_vocab? I know what they mean, but which parameters of the nn.Transformer constructor do they correspond to?


I'm also having a problem here. Can someone kind and smart please explain this?

'What are src_vocab and tgt_vocab? I know what they mean, but which parameters of the nn.Transformer constructor do they correspond to?'

I'm having the same problem, but for the example part I guess it is a mistake on their side.
nn.Transformer doesn't take source and target vocab sizes, since it only implements the transformer itself, without the embedding layer on the input data and without the linear layer on the output of the decoder.
To make it work, d_model should be your embedding size; call an embedding layer on the source and on the target, and pass the output of the transformer through a linear layer that maps to the target vocab size:

self.embed_src = nn.Embedding(src_vocab, emb_dim)
self.embed_trg = nn.Embedding(trg_vocab, emb_dim)
self.model = nn.Transformer(d_model=emb_dim, nhead=heads, num_encoder_layers=N, num_decoder_layers=N)
self.out_linear = nn.Linear(emb_dim, trg_vocab)

For the forward function, it should be:

src = self.embed_src(src)
trg = self.embed_trg(trg)
output = self.model(src, trg)
output = self.out_linear(output)
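
Putting those pieces together, a minimal sketch of such a wrapper module might look like the following (the class name, default hyperparameters and shapes here are only illustrative, not the one true way to do it):

import torch
import torch.nn as nn

class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab, trg_vocab, emb_dim=512, heads=8, N=6):
        super().__init__()
        self.embed_src = nn.Embedding(src_vocab, emb_dim)
        self.embed_trg = nn.Embedding(trg_vocab, emb_dim)
        self.model = nn.Transformer(d_model=emb_dim, nhead=heads,
                                    num_encoder_layers=N, num_decoder_layers=N)
        self.out_linear = nn.Linear(emb_dim, trg_vocab)

    def forward(self, src, trg):
        # src: (S, N) and trg: (T, N) token indices
        src = self.embed_src(src)        # (S, N, emb_dim)
        trg = self.embed_trg(trg)        # (T, N, emb_dim)
        output = self.model(src, trg)    # (T, N, emb_dim)
        return self.out_linear(output)   # (T, N, trg_vocab)

# e.g. logits = Seq2SeqTransformer(5000, 5000)(torch.randint(0, 5000, (10, 32)),
#                                              torch.randint(0, 5000, (20, 32)))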

There is a typo in the doc and a PR is going to fix it. You don't need src_vocab and tgt_vocab to instantiate the transformer module. Here is a simple example:

import torch
import torch.nn as nn
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)  # d_model defaults to 512
src = torch.rand((10, 32, 512))    # (S, N, E) = (source length, batch size, d_model)
tgt = torch.rand((20, 32, 512))    # (T, N, E) = (target length, batch size, d_model)
out = transformer_model(src, tgt)  # (T, N, E)

A full example applying nn.Transformer to the word language model can be found here. The example uses both the RNN and transformer modules.


Can someone explain the src and src_mask shapes of the transformer?
For example, I have a tokenized text sentence with max_len=128.
This sentence goes through nn.Embedding(src_vocab=5000, emb_dim=128).
The output of the embedding will be a tensor with shape (N, 128, 128), where N=batch_size.
The transformer docs say that the src input and src_mask have shapes:
src: (S, N, E) and src_mask: (S, S),
where S is the source sequence length, N is the batch size, and E is the feature number.
Should I make some changes to the embedding output to use it as input to the transformer layer?
I'm a bit confused :confused: .


S is the source sequence length; N is the batch size; E is the number of features (a.k.a. the embedding dimension in your case).

If you send token indices of shape (S, N) to the embedding layer (vocab size 5000, embedding dim 128), the output will be in the shape of (S, N, 128). Then you don't need to make any changes in order to feed it to the transformer layer. The src_mask is just a square matrix which is used to filter the attention weights.

See example here
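
For instance, a minimal shape check along those lines, assuming the token tensor is already in sequence-first (S, N) layout (the numbers follow the question above and are only illustrative):

import torch
import torch.nn as nn

S, N = 128, 32
emb = nn.Embedding(5000, 128)             # vocab size 5000, embedding dim 128
tokens = torch.randint(0, 5000, (S, N))   # token indices, shape (S, N)
src = emb(tokens)                         # shape (S, N, 128) == (S, N, E)
model = nn.Transformer(d_model=128, nhead=8)
tgt = torch.rand(20, N, 128)              # dummy target, shape (T, N, E)
out = model(src, tgt)                     # shape (20, N, 128)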


Thanks for your reply!!
I'm a bit confused by this embedding layer output. I'll try to explain:

My sentences have size torch.Size([128]).
So, if I'm using a batch size of 32, the tensor will have size
torch.Size([32, 128]) -> shape = (N, S).
When I send this tensor to the embedding layer (with src_vocab=5000 and emb_dim=128), the output will have size
torch.Size([32, 128, 128]) -> shape = (N, S, E).
This is confusing me: should I permute the first and second dimensions to get shape = (S, N, E)?


Yep, you should transpose your input after the embedding layer.

For nn.Transformer, we chose the shape to be (S, N, E), while some NLP people use (N, S, E). There is nothing right or wrong about either, and switching between the two shapes is fine.
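
In code, that transpose is just one line; a quick sketch with the batch-first shapes from the question above:

import torch
import torch.nn as nn

emb = nn.Embedding(5000, 128)                # src_vocab=5000, emb_dim=128
tokens = torch.randint(0, 5000, (32, 128))   # (N, S) batch-first token indices
src = emb(tokens)                            # (N, S, E) = (32, 128, 128)
src = src.transpose(0, 1)                    # (S, N, E) = (128, 32, 128), as nn.Transformer expects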


Thank you!!! :smiley:


Hi, I'm a bit confused about src_mask and src_key_padding_mask. The explanations in the PyTorch docs are:
src_mask – the additive mask for the src sequence (optional).
src_key_padding_mask – the ByteTensor mask for src keys per batch (optional).
In my opinion, src_mask's dimension is (S, S), where S is the max source length in the batch, so I would need to send a src_mask of shape (N, S, S) to the Transformer. I don't know if I understand that correctly. I also don't understand the explanation of src_key_padding_mask in the docs; this is confusing me.
For the provided example code,
output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
the [src/tgt/memory]_key_padding_mask arguments are left as None by default, and I'm a little confused about this.


@LiHaibo
First, both masks work on the dot product of query and key in the "Scaled Dot-Product Attention" layer.
src_mask works on a matrix with a dimension of (S, S) and adds '-inf' to individual positions. src_key_padding_mask is more like a padding marker, which masks specific tokens in the src sequence (i.e., the entire column/row of the attention matrix for that token is set to '-inf').


@zhangguanheng66 Thanks for the explanation.
Just to check whether I understand correctly: we should provide the sequence padding mask in src_key_padding_mask, and its dimension would be (N, S), where N is the batch size and S is the sequence length. I am confused about what the content of src_key_padding_mask should be: will it be a -inf/0 matrix or a boolean matrix with True/False?

The padding mask is (N, S) with boolean True/False. src_mask is (S, S) with float('-inf') and float(0.0). There is a note about this in the PyTorch nn.Transformer docs.
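
As a small sketch of those shapes and dtypes, assuming a recent PyTorch where boolean padding masks are supported (the sizes and mask values below are only an illustration):

import torch
import torch.nn as nn

S, N, E = 5, 2, 8
# Boolean padding mask, shape (N, S): True marks padded positions that should be ignored.
src_key_padding_mask = torch.tensor([[False, False, False, True,  True],
                                     [False, False, False, False, False]])
# Float additive mask, shape (S, S): 0.0 where attention is allowed, -inf where it is blocked
# (built here like a causal mask: -inf strictly above the diagonal).
src_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)

model = nn.Transformer(d_model=E, nhead=2)
src = torch.rand(S, N, E)
tgt = torch.rand(3, N, E)
out = model(src, tgt, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)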

Hi @zhangguanheng66, @akashs, @LiHaibo

Can you please tell me what the difference is between the two sets of masks, viz. *_mask and *_key_padding_mask?

From the documentation in the source code, this is what I could deduce. But I am not very confident and hence would really appreciate it if you can correct me:

  • src_mask, tgt_mask and memory_mask should be used when we want to apply the same mask to all the sequences in the given batch.
  • src_key_padding_mask, tgt_key_padding_mask and memory_key_padding_mask should be used when we want to specify different masks for different samples in the given batch. Also, the way you specify these masks is slightly different from the previous ones.

My question is: do both sets of masks achieve the same purpose? And should we be using only one of them?

For instance, if you want to create a Seq2Seq Transformer model with both TransformerEncoder and TransformerDecoder, is it OK if I only specify src_mask, tgt_mask and memory_mask?


@shahensha To your questions: key_padding_mask controls which key positions each batch item is allowed to attend to. This is most commonly used to avoid attending to padding elements. attn_mask controls which key positions each query position is allowed to attend to. This is useful for doing left-to-right (causal) attention, where we enforce that query positions are only allowed to attend to keys to their left.
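
For illustration, here is a rough sketch of what that causal attn_mask looks like for a length-4 target, using generate_square_subsequent_mask (called on a Transformer instance for compatibility with older PyTorch versions):

import torch.nn as nn

model = nn.Transformer(d_model=8, nhead=2)
tgt_mask = model.generate_square_subsequent_mask(4)
# tgt_mask is a (4, 4) float matrix with 0.0 on and below the diagonal and -inf above it:
# query position i (row i) may attend to key positions 0..i, never to positions on its right.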


Thank you @zhangguanheng66

I finally understood all the different masks in the API. But for some reason, my system is not able to work well at inference time. The loss goes down nicely, but at inference it just produces garbage values.

I'm having a hard time understanding how to use nn.Transformer, too, even after reading through this thread, the tutorial, this github issue, and the example language model. My model seems to do nothing but copy the target sequence, no matter what I do.

The task is to predict the title of an article, given a sentence from the article. It's just a test task for a similar task I would like to do. The sentence and the title are both of varying length. To facilitate batching, I use the data loader's collate_fn to pad every sentence in a batch to the length of the longest sentence in the batch, and the same for titles. While using nn.Transformer, I make the sentence the src, and the title the tgt.

I include a padding mask for both src and tgt, which has True values wherever I padded a sentence. I also include a tgt_mask generated by generate_square_subsequent_mask to make it so that the decoder can't look ahead in a sequence while it's predicting. Since the model was still copying everything, I also included a square mask for the src, but that didn't help anything.

I feel that I'm missing something very obvious. Can anybody help?

Looping in @zhangguanheng66 who seems to know a lot about this.

For your first part, it seems that you are not setting up attn_mask correctly.

Wow, thanks for the quick reply.

Which attn_mask is that? Both source and target masks should be pretty standard.

Here's how I'm using it, where self.base is just a model that returns embeddings for inp (src) and tgt, where src_mask and tgt_mask are the standard upper triangle matrices, and src/tgt_key_padding_mask are as I described previously:

inp_emb, tgt_emb = self.base(inputs, targets)
# We get inputs and targets in (N, S, E) and (N, T, E), and nn.Transformer requires (S, N, E) and (T, N, E), so we transpose them
inp_emb = inp_emb.transpose(0, 1)
tgt_emb = tgt_emb.transpose(0, 1)

hdn = self.transformer(inp_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=inp_padding_mask, tgt_key_padding_mask=tgt_padding_mask)

out = self.head(hdn)
out = out.transpose(0, 1)

loss_fct = nn.CrossEntropyLoss()
out_view = out.contiguous().view(-1, self.vocab_size)
tgt_view = targets.view(-1)
loss = loss_fct(out_view, tgt_view)

Could the transposes be throwing it off?

Well, I was right; I was indeed missing something very obvious. To anyone who comes after me and has a similar problem: the reason why my network was only copying results was that my training strategy was wrong. I was passing in targets to the decoder and calculating loss based on how similar its output was to those targets. If you think about it, I was asking the decoder to behave like an auto-encoder, to reproduce exactly what I passed in. That's not very difficult for a transformer decoder to do, so it learned to copy very quickly, even with masks. Doing this also makes it impossible to perform inference, since the decoder never learned how to generate anything new.

How, you might ask, do you fix this? The solution for me was a couple steps:

  1. To add special start and end tokens to every target; e.g. [ 'h', 'e', 'l', 'l', 'o'] became [ <start>, 'h', 'e', 'l', 'l', 'o', <end>] (since it's a character model, my start and end tokens are actually unicode tokens)
  2. To add an additional loop in the training loop that starts with a target of length 1 and passes incrementally larger targets until it passes the entire target. Then calculate loss based on how similar the output is to the target shifted left by one. (I also do backpropagation each time; I'm not sure if that's correct or if the losses should be aggregated over the whole sub-loop.) E.g. [<start>] goes in, ['h'] is expected. Then [<start>, 'h'] goes in, ['h', 'e'] is expected. And so on. The last iteration is [<start>, 'h', 'e', 'l', 'l', 'o'], with ['h', 'e', 'l', 'l', 'o', <end>] expected. This particular way of training is called teacher forcing (a rough sketch of this loop is shown after this list). It also sets us up nicely to perform inference.
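
A rough, self-contained sketch of that incremental training loop (the toy vocabulary, token ids and hyperparameters are made up purely for illustration; this mirrors the scheme described above rather than the most efficient way to do teacher forcing):

import torch
import torch.nn as nn

vocab_size, emb_dim = 30, 16
model = nn.Transformer(d_model=emb_dim, nhead=2, num_encoder_layers=2, num_decoder_layers=2)
embed = nn.Embedding(vocab_size, emb_dim)
head = nn.Linear(emb_dim, vocab_size)
params = list(model.parameters()) + list(embed.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.CrossEntropyLoss()

src = torch.randint(0, vocab_size, (10, 1))              # (S, N=1) source token ids
tgt = torch.tensor([[1], [5], [6], [7], [7], [8], [2]])  # <start>=1, 'h','e','l','l','o', <end>=2; shape (T, 1)

for i in range(1, tgt.size(0)):
    decoder_in = tgt[:i]       # <start> plus the first i-1 real tokens
    expected = tgt[1:i + 1]    # the same tokens shifted left by one
    tgt_mask = torch.triu(torch.full((i, i), float('-inf')), diagonal=1)   # causal mask
    out = model(embed(src), embed(decoder_in), tgt_mask=tgt_mask)          # (i, 1, emb_dim)
    logits = head(out)                                                     # (i, 1, vocab_size)
    loss = criterion(logits.reshape(-1, vocab_size), expected.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()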

Inference (answering this issue now) then happens by simply passing the hidden state from the encoder and the [<start>] token to the decoder. Since the model has been trained to output a single token when a single <start> token is passed in, it should (hopefully) output the correct first token of our output sequence. Then, we can take that token, append it to our <start> token, and pass that as input to the decoder. Now it should generate two tokens. We repeat this process until the <end> token is generated, and then we stop. This is known as greedy decoding. Both teacher forcing and greedy decoding are used to train Google's T5, so they're viable today. There is, however, a method called beam search that gets better results, but takes much longer to generate.
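
A minimal greedy-decoding sketch continuing the toy setup from the training sketch above (start id 1, end id 2, and max_len are assumptions for illustration):

import torch

start_id, end_id, max_len = 1, 2, 20
generated = torch.tensor([[start_id]])        # (1, 1): decoder input starts with <start>
with torch.no_grad():
    src_emb = embed(src)                      # embed the same toy source as above
    for _ in range(max_len):
        t = generated.size(0)
        tgt_mask = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        out = model(src_emb, embed(generated), tgt_mask=tgt_mask)          # (t, 1, emb_dim)
        next_token = head(out[-1]).argmax(dim=-1, keepdim=True)            # greedy pick, shape (1, 1)
        generated = torch.cat([generated, next_token], dim=0)
        if next_token.item() == end_id:
            break

Note that this re-runs the full encoder-decoder on the growing prefix at every step; splitting the call into one pass through model.encoder and repeated passes through model.decoder would avoid that, but the loop above keeps the sketch simple.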
