How to train nn.Transformer without Teacher Forcing?

I have this Transformer here:

self.src_word_embed = nn.Embedding(num_embeddings=num_words, embedding_dim=dim_model)
self.pos_embed = PositionalEncoding(dim_model=dim_model, max_len=max_seq_len, dropout=DROPOUT)
self.tgt_word_embed = nn.Embedding(num_embeddings=num_words, embedding_dim=dim_model)
self.transformer = nn.Transformer(d_model=dim_model, nhead=heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
self.out = nn.Linear(dim_model, num_words)
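For reference, a forward method wiring these modules together would look roughly like this (a minimal sketch based on the call below, assuming batch_first=True as in the constructor and leaving the padding masks out):

def forward(self, src_seq, tgt_seq, tgt_mask=None):
    # embed token ids and add positional encodings
    src = self.pos_embed(self.src_word_embed(src_seq))  # [batch_size, src_len, dim_model]
    tgt = self.pos_embed(self.tgt_word_embed(tgt_seq))  # [batch_size, tgt_len, dim_model]
    # full encoder-decoder pass; the causal mask stops the decoder from attending ahead
    dec_outputs = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
    # project decoder states to vocabulary logits
    return self.out(dec_outputs)  # [batch_size, tgt_len, num_words]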

When training, I use

outputs = model(src_seq=src_seq, tgt_seq=tgt_seq, tgt_mask=tgt_mask)
# size = [batch_size, tgt_len, num_words]
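Here tgt_seq, tgt_y, and tgt_mask follow the usual shifted-target teacher-forcing setup; a minimal sketch of that preparation, assuming tgt is the full padded target batch including the start and end tokens:

tgt_seq = tgt[:, :-1]   # decoder input: everything except the last token
tgt_y = tgt[:, 1:]      # loss target: everything except the first token
tgt_len = tgt_seq.size(1)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(device)  # [tgt_len, tgt_len]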

The training loss rapidly drops to 1e-5, but when I use greedy decoding for inference the results aren't that good. So how can I use greedy decoding to train the model, given that the torch.argmax() used in greedy decoding won't keep grad_fn?

I have my own implementation here, but it's very slow:

        if use_teacher_forcing:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt_len, device=device)
            # size = [tgt_len, tgt_len]
            outputs = model(src_seq=src_seq, tgt_seq=tgt_seq, tgt_mask=tgt_mask)
        else:
            for iii in range(inputs.size(0)):
                # Get encoder output
                src_seq = torch.unsqueeze(inputs[iii, :], 0)
                src_key_padding_mask = (src_seq == PAD_TOKEN).to(device)
                src_word_embed = model.src_word_embed(src_seq)
                src = model.pos_embed(src_word_embed)
                enc_outputs = model.transformer.encoder(src=src, src_key_padding_mask=src_key_padding_mask)
                # Initialize decoder output
                dec_result = torch.Tensor([[START_TOKEN]]).to(torch.int64).to(device)
                for iiii in range(tgt_len):
                    tgt_word_embed = model.tgt_word_embed(dec_result)
                    dec_input = model.pos_embed(tgt_word_embed)
                    dec_outputs = model.transformer.decoder(tgt=dec_input, memory=enc_outputs)
                    projected = model.out(dec_outputs)
                    # size = [1, dec_result.size(1), num_words]

                    prob = F.softmax(torch.squeeze(projected, 0), dim=-1)
                    idx = torch.argmax(prob, dim=-1)
                    next_symbol = idx[-1]
                    dec_result = torch.cat([dec_result, torch.Tensor([[next_symbol]]).to(src_seq.dtype).to(device)], -1)

                    if iiii == 0:
                        output = projected[:, -1, :]
                    else:
                        output = torch.cat((output, projected[:, -1, :]), 0)
                if iii == 0:
                    outputs = torch.unsqueeze(output, 0)
                else:
                    outputs = torch.cat((outputs, torch.unsqueeze(output, 0)), 0)

        loss = loss_fn(outputs.view(-1, outputs.size(-1)), tgt_y.view(-1))  # CrossEntropy
        loss.backward()
        optimizer.step()

Are there any methods to keep grad_fn after argmax, or to avoid calling torch.cat that many times?

argmax will never keep the grad_fn because it returns integer indices (but you said this is for inference? And you're not using dec_result in the loss anyway?)
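You can check this directly (a tiny standalone example, not from your code):

import torch

logits = torch.randn(2, 5, requires_grad=True)
idx = torch.argmax(logits, dim=-1)
print(idx.dtype, idx.requires_grad, idx.grad_fn)  # torch.int64 False None -- nothing to backprop through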

To avoid calling cat that many times, you could just append to a list and cat a single time at the very end.
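Something like this (a toy sketch with random tensors standing in for your per-step decoder logits, just to show the pattern):

import torch

num_words, tgt_len = 10, 7
steps = []
for t in range(tgt_len):
    projected = torch.randn(1, t + 2, num_words)  # stand-in for `projected` at step t
    steps.append(projected[:, -1, :])             # keep only the newest position, [1, num_words]
output = torch.cat(steps, dim=0)                  # single cat: [tgt_len, num_words]
print(output.shape)

The same idea applies to the outer loop over the batch: collect each sample's output in a list and call torch.stack once at the end instead of growing outputs with cat every iteration.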