How to use nn.TransformerDecoder() at inference time

Hello. I am using the nn.TransformerDecoder() module to train a language model. During training, the model uses the target tgt and tgt_mask, so at each step the decoder sees the true previous labels.
However, for text generation (at inference time), the model shouldn’t use the true labels, but the ones it predicted in the previous steps. Can we do that with nn.TransformerDecoder()? Or should I reimplement the module to add the newly predicted labels at each step?


Were you able to get an answer on this? Is there an example?


There is an example using the nn.Transformer module for a word language model here, or a tutorial here.

The example you posted is for the TransformerEncoder. I’m looking for an explicit usage example of the TransformerDecoder.
Can you provide one?


So for TransformerDecoder, what are your target and memory sequences?

Yes. You need to initialize the target with a tensor of SOS (start of sentence) tokens. At each step, you append the prediction to this tensor, as explained in the paper Attention Is All You Need:

At each step the model is auto-regressive[10], consuming the previously generated symbols as additional input when generating the next

import torch
import torch.nn as nn

device = 'cpu'
d_model = 768
bs = 4
sos_idx = 0
vocab_size = 10000
input_len = 10
output_len = 12

# Define the model (these modules default to batch_first=False,
# i.e. tensors are laid out as (seq_len, batch, d_model))
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4)

encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                num_layers=6).to(device)

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4)

decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                num_layers=6).to(device)

decoder_emb = nn.Embedding(vocab_size, d_model).to(device)

predictor = nn.Linear(d_model, vocab_size).to(device)

for module in (encoder, decoder, decoder_emb, predictor):
    module.eval()  # disable dropout during generation

# for a single batch x, sequence-first
x = torch.randn(input_len, bs, d_model, device=device)

with torch.no_grad():
    encoder_output = encoder(x)  # (input_len, bs, d_model)

    # initialize the input of the decoder with sos_idx (start-of-sentence token idx)
    output = torch.full((bs, output_len), sos_idx, dtype=torch.long, device=device)
    for t in range(1, output_len):
        tgt_emb = decoder_emb(output[:, :t]).transpose(0, 1)  # (t, bs, d_model)
        # causal mask, same values as generate_square_subsequent_mask(t):
        # -inf above the diagonal; use it as is, without transposing
        tgt_mask = torch.triu(
            torch.full((t, t), float('-inf'), device=device), diagonal=1)
        decoder_output = decoder(tgt=tgt_emb,
                                 memory=encoder_output,
                                 tgt_mask=tgt_mask)  # (t, bs, d_model)

        # greedy choice from the last position's logits
        pred_proba_t = predictor(decoder_output[-1])  # (bs, vocab_size)
        output[:, t] = pred_proba_t.argmax(dim=-1)

# output: (bs, output_len)
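A quick way to see why the causal mask should be used exactly as generate_square_subsequent_mask returns it, and not transposed. The small helper below rebuilds the same float mask with torch.triu (the helper name causal_mask is ours), and the all-zero score matrix is purely illustrative:

```python
import torch

def causal_mask(sz):
    # Same values as generate_square_subsequent_mask(sz): 0.0 on and below the
    # diagonal (allowed), -inf above it (future positions, blocked)
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

scores = torch.zeros(3, 3)  # equal raw attention scores, for illustration only
good = torch.softmax(scores + causal_mask(3), dim=-1)
bad = torch.softmax(scores + causal_mask(3).transpose(0, 1), dim=-1)

# row i holds the attention weights of query position i
print(good)  # row 0 sees only position 0; row 2 spreads over positions 0..2
print(bad)   # row 0 sees positions 0..2 (the future); row 2 sees only itself
```

With the transposed mask, every position attends to later tokens and the last position (the one used to predict the next token) attends only to itself.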

My final model looks like this:

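import torch
import torch.nn as nn

def causal_mask(sz, device):
    # -inf above the diagonal: position i may only attend to positions <= i
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

class Transformer(nn.Module):
    def __init__(self, d_model=768, nhead=4, num_layers=6, vocab_size=10000, sos_idx=0):
        super().__init__()
        self.sos_idx = sos_idx
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead), num_layers)
        self.decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead), num_layers)
        self.decoder_emb = nn.Embedding(vocab_size, d_model)
        self.predictor = nn.Linear(d_model, vocab_size)

    def forward(self, x, y=None, mode='train', output_len=12):
        # x: (input_len, bs, d_model), sequence-first; y: (bs, tgt_len) token ids
        encoder_output = self.encoder(x)

        if mode == 'train':
            tgt_emb = self.decoder_emb(y).transpose(0, 1)  # (tgt_len, bs, d_model)
            tgt_mask = causal_mask(tgt_emb.size(0), x.device)
            decoder_output = self.decoder(tgt=tgt_emb,
                                          tgt_mask=tgt_mask,
                                          memory=encoder_output)
            return self.predictor(decoder_output).permute(1, 2, 0)  # (bs, vocab, tgt_len)
        elif mode == 'generate':
            # the solution I gave before (greedy decoding); call under torch.no_grad()
            output = torch.full((x.size(1), output_len), self.sos_idx,
                                dtype=torch.long, device=x.device)
            for t in range(1, output_len):
                tgt_emb = self.decoder_emb(output[:, :t]).transpose(0, 1)
                decoder_output = self.decoder(tgt=tgt_emb, memory=encoder_output,
                                              tgt_mask=causal_mask(t, x.device))
                output[:, t] = self.predictor(decoder_output[-1]).argmax(dim=-1)
            return output  # (bs, output_len)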


Do you need tgt_mask at inference time? I am wondering whether the for loop at inference does the same job as the subsequent mask. Also, have you done an experiment with nn.TransformerDecoder? I have some trouble using it: at inference, it always returns the same token. I'd really appreciate it if you can help! Thank you!

I am learning seq2seq by working on a machine translation problem, from DE to EN.

Train and validation worked well. I got acceptable losses, and printing out the source, target and predicted sentence is acceptable too.

But for test (generation), I appended the previously generated symbols as additional inputs for generating the next symbol, and the output repeats the same symbol again and again.

For example, the output I got is:

<sos> themselves themselves themselves themselves themselves themselves themselves themselves themselves 

Any idea what might have gone wrong?

Example output after every generated symbol:

tensor([2])
tensor([  2, 700])
tensor([  2, 700, 700])
tensor([  2, 700, 700, 700])
tensor([  2, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700])
tensor([  2, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700])

I suspect you wouldn’t need a tgt_mask during inference, since there would be nothing to mask, only a shifted SOS token.
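A small experiment suggests the mask still matters once the decoder has more than one layer. For a single layer, the last row of the causal mask blocks nothing, so the last position's output is the same with or without it. With stacked layers, the last position in layer 2 attends to layer-1 states of earlier positions, and without the mask those states have already looked at later tokens, which is not what the model saw during training. The sizes and random inputs below are arbitrary assumptions chosen for speed:

```python
import torch
import torch.nn as nn

def causal_mask(sz):
    # -inf above the diagonal: position i may attend only to positions <= i
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

torch.manual_seed(0)
d_model, nhead, t, bs, src_len = 16, 2, 5, 3, 7  # small arbitrary sizes
tgt = torch.randn(t, bs, d_model)           # embedded prefix generated so far (T, N, E)
memory = torch.randn(src_len, bs, d_model)  # encoder output (S, N, E)

same_last_step = {}
for num_layers in (1, 2):
    layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=32, dropout=0.0)
    decoder = nn.TransformerDecoder(layer, num_layers).eval()
    with torch.no_grad():
        # only the last position is used to predict the next token
        masked = decoder(tgt, memory, tgt_mask=causal_mask(t))[-1]
        unmasked = decoder(tgt, memory)[-1]
    same_last_step[num_layers] = torch.allclose(masked, unmasked, atol=1e-5)

print(same_last_step)  # expected: {1: True, 2: False}
```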

Can you post this code somewhere where I could see it?

Did you happen to solve this issue? I am getting the same repeated values as well… Strange, since I made sure that I accounted for the required shift in the output for the decoder…

Hi, I’m facing the same problem. How did you ensure the shift? I’m not sure how to do it. I’m using padded sequences as the target input, e.g. target_tensor = [1,2,3,4,5,99,0,0,0,0,0,0] (1 is the SOS token and 99 is the EOS token; cutting the last element doesn’t make sense here, since it’s a padding token). I tried to replace the EOS token with padding via target_tensor[target_tensor == 99] = 0, which results in a gradient error.
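One common way to build the shift for padded targets, sketched here under the assumption that the padding index is 0 and using the token ids from the example above: slice the sequence instead of editing tokens in place, and tell the loss to ignore padded labels. Cutting the last element is harmless, because it is either padding or the EOS of the longest sequence, and any input position whose label is padding is ignored by the loss. The vocabulary size and the random logits standing in for the model output are hypothetical:

```python
import torch
import torch.nn as nn

PAD, SOS, EOS = 0, 1, 99  # token ids from the example above
target_tensor = torch.tensor([[1, 2, 3, 4, 5, 99, 0, 0, 0, 0, 0, 0]])  # (N, T)

# Shift by slicing; nothing is modified in place
tgt_inp = target_tensor[:, :-1]  # decoder input, starts with <sos>
tgt_out = target_tensor[:, 1:]   # labels: the next token at every position

# Positions whose label is <pad> contribute nothing to the loss
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

torch.manual_seed(0)
vocab_size = 100  # hypothetical vocabulary size
logits = torch.randn(1, tgt_inp.size(1), vocab_size)  # stand-in for model(src, tgt_inp)
loss = criterion(logits.reshape(-1, vocab_size), tgt_out.reshape(-1))
```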

Also, @Agnes, could you provide more of the code? I’m not sure where this code should go: in __init__, or outside the model?

Looking at this example, none of the things @agnes wrote seem necessary? Is there any explanation for that?

Adding this discussion, which describes a similar approach to @agnes’s.


@asdf11 Can you check whether your training outputs give values that do not repeat? That was an issue I was facing. Repeated outputs during training are still a problem for me, and I suspect that ‘tgt_mask’ doesn’t really work.


Does anyone have a code snippet showing how to use the TransformerDecoder properly?

Regards
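A compact greedy decoder, sketched under several assumptions: a batch_first nn.Transformer wrapped in a hypothetical Seq2Seq class with separate source/target embeddings and an output projection, hypothetical sizes and special-token ids, and no positional encoding (left out for brevity; a real model needs it). It encodes the source once through the transformer's encoder submodule, reuses that memory at every step, pads rows that already produced EOS, and stops early when every row is finished:

```python
import torch
import torch.nn as nn

# Hypothetical sizes and special-token ids, for illustration only
VOCAB, D_MODEL, PAD, SOS, EOS = 50, 32, 0, 1, 2

class Seq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        self.src_emb = nn.Embedding(VOCAB, D_MODEL, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(VOCAB, D_MODEL, padding_idx=PAD)
        self.transformer = nn.Transformer(D_MODEL, nhead=4, num_encoder_layers=2,
                                          num_decoder_layers=2, dim_feedforward=64,
                                          batch_first=True)
        self.fc = nn.Linear(D_MODEL, VOCAB)

@torch.no_grad()
def greedy_decode(model, src, max_len=20):
    """src: (N, S) token ids -> (N, <= max_len) token ids starting with SOS."""
    model.eval()
    src_pad = src == PAD  # (N, S), True where padded
    # Encode once; the memory is reused at every decoding step
    memory = model.transformer.encoder(model.src_emb(src), src_key_padding_mask=src_pad)
    ys = torch.full((src.size(0), 1), SOS, dtype=torch.long)
    done = torch.zeros(src.size(0), dtype=torch.bool)
    for _ in range(max_len - 1):
        t = ys.size(1)
        tgt_mask = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        out = model.transformer.decoder(model.tgt_emb(ys), memory, tgt_mask=tgt_mask,
                                        memory_key_padding_mask=src_pad)  # (N, t, E)
        next_tok = model.fc(out[:, -1]).argmax(dim=-1)  # greedy pick for the last position
        next_tok = next_tok.masked_fill(done, PAD)      # finished rows only receive padding
        ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
        done |= next_tok == EOS
        if done.all():
            break
    return ys
```

Usage would look like `greedy_decode(model, src)` with `src` a padded (N, S) LongTensor; with an untrained model the output is of course meaningless.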


@mathematicsofpaul, one thing I just noticed after banging my head against this for the last day: I removed the PositionalEncoding layer (from https://pytorch.org/tutorials/beginner/transformer_tutorial.html) and that “fixed” the problem. Before that, it simply repeated the same token (in both training and validation).

I’m not sure whether there’s a bug in it. I suspect it’s necessary for a real transformer (I’m just trying to get it to work with a couple of training cases, using the same cases for validation), so I guess I’ll have to look more closely at what it’s doing (and what it’s supposed to do).

@david.waterworth thanks for posting! Woah, I believe the Positional Encoding module is a must, since it is what gives order to the words you input into the model. If you feed in a sentence without positional encoding, you’re basically giving the model a bag of words instead of a sequence of words. Regardless, there might be a bug in the Positional Encoding. Could you please post your code in a Google Colab notebook, or simply post your model here?
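The bag-of-words point can be checked directly: without positional information, self-attention is permutation-equivariant, so reversing the word order just reverses the outputs. The sketch below uses a dropout-free sinusoidal encoding written in the style of the tutorial’s PositionalEncoding (not copied from it), with small arbitrary sizes and random "embeddings":

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal encoding in the style of the PyTorch tutorial, without dropout.
    Expects sequence-first input: (seq_len, batch, d_model)."""
    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)  # a buffer, not a trainable parameter

    def forward(self, x):
        return x + self.pe[:x.size(0)]

torch.manual_seed(0)
d_model, seq_len, bs = 16, 6, 2
layer = nn.TransformerEncoderLayer(d_model, nhead=2, dim_feedforward=32, dropout=0.0).eval()
pos_enc = PositionalEncoding(d_model)

x = torch.randn(seq_len, bs, d_model)  # stand-in for embedded tokens
perm = torch.tensor([5, 4, 3, 2, 1, 0])  # reverse the word order

with torch.no_grad():
    # Without positional information, reordering the input only reorders the output
    no_pe_equiv = torch.allclose(layer(x)[perm], layer(x[perm]), atol=1e-5)
    # With it, each token's output depends on where the token sits
    pe_equiv = torch.allclose(layer(pos_enc(x))[perm], layer(pos_enc(x[perm])), atol=1e-5)

print(no_pe_equiv, pe_equiv)  # expected: True False
```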

Yeah @mathematicsofpaul, I doubt my model will work on a real dataset without it. This is how I implemented my forward function. I’m going to spend tomorrow looking more closely at the PositionalEncoding.

import math
from einops import rearrange  # rearrange comes from the einops package

def forward(self, src, tgt):

    # create padding masks
    src_key_padding_mask = (src == self.src_vocab.stoi['<pad>'])  # N S
    tgt_key_padding_mask = (tgt == self.tgt_vocab.stoi['<pad>'])  # N T
    memory_key_padding_mask = src_key_padding_mask.clone()        # N S

    # prevent attention of future output tokens (mask must be on the inputs' device)
    tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[-1]).to(tgt.device)  # T T

    src = rearrange(src, 'n s -> s n')
    tgt = rearrange(tgt, 'n t -> t n')

    src = self.embed_src(src) * math.sqrt(self.d_model) # S, N, E
    tgt = self.embed_tgt(tgt) * math.sqrt(self.d_model) # T, N, E

    output = self.transformer(
        src, tgt,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_key_padding_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=memory_key_padding_mask)  # T, N, E

    output = rearrange(output, 't n e -> n t e')
    return self.fc(output)

and my training and validation steps are the same:

        # reshape inputs
        src = rearrange(batch.object_name, 's n -> n s')
        tgt = rearrange(batch.tags, 't n -> n t')   

        # Create tgt_inp and tgt_out (which is tgt_inp but shifted by 1)
        tgt_inp, tgt_out = tgt[:,:-1], tgt[:,1:]

        optim.zero_grad()
        out = model(src, tgt_inp)
        #predicted = F.log_softmax(out, dim=-1).argmax(dim=-1)
        loss = criterion(rearrange(out, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

        # Backpropagate and update optim
        loss.backward()
        optim.step()

@david.waterworth What did you set for your self.transformer? I assume nn.Transformer(). Great, let me know about the Positional Encoding! I too have been banging my head against this.

Also, check out this very informative thread.

Yes it’s nn.Transformer()

@mathematicsofpaul I don’t think the problem was with the position encoding. I think it was the initialisation. I’d copied the code below and applied it to the full model, and it appears to mess something up: when I removed it and added the position encoding back, training converged to a much lower loss and the decoder stopped repeating the same token over and over.

for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_normal_(p)

Perhaps the code overwrites the position encoding, since it’s stored in a buffer? Although the docstring for ‘register_buffer’ does say it’s for model state that isn’t a parameter. The other option is that the embedding weights are initialised to zero, so perhaps overwriting them with non-zeros messes with that layer?
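Both hypotheses can be checked directly. Module.parameters() does not yield registered buffers, so the loop cannot touch a positional-encoding buffer, and nn.Embedding weights start from N(0, 1) rather than zero. What the loop does change is the embedding’s scale: xavier_normal_ on a (vocab_size, d_model) weight uses std sqrt(2 / (vocab_size + d_model)). The sketch below uses a hypothetical TinyModel with assumed sizes (vocab 10000, d_model 512); one plausible but unverified consequence is that token embeddings become small relative to a positional encoding whose entries lie in [-1, 1]:

```python
import math
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # hypothetical stand-in: an embedding plus a fixed buffer, like the tutorial's 'pe'
    def __init__(self, vocab_size=10000, d_model=512):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.register_buffer('pe', torch.randn(10, d_model))

torch.manual_seed(0)
model = TinyModel()
pe_before = model.pe.clone()
print(model.emb.weight.std().item())  # default init is N(0, 1): roughly 1.0

# The initialization loop quoted above
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_normal_(p)

print(torch.equal(model.pe, pe_before))  # buffers are not yielded by parameters()
print(model.emb.weight.std().item())     # about sqrt(2 / (10000 + 512)), i.e. ~0.0138
print(model.emb.weight.std().item() * math.sqrt(512))  # ~0.31 after the sqrt(d_model) scaling
```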