nn.TransformerDecoder and nn.TransformerEncoder inference doesn't work

I am using a customized Transformer with nn.TransformerEncoder and nn.TransformerDecoder layers. It seems that nn.TransformerDecoder doesn't support the inference process (generation/testing), i.e. feeding token ids one by one with a fixed memory generated by the nn.TransformerEncoder layer. Is there a tutorial I can refer to? I didn't find one in the official documentation. Thank you in advance for your help!
The encoder is BERT plus nn.TransformerEncoder, and the decoder is nn.TransformerDecoder.
Here's my code:
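
For context, my understanding of the usual autoregressive pattern with a plain nn.TransformerDecoder is roughly the following (a minimal sketch, not my actual model; the embedding and output projection here are placeholders):

    import torch
    import torch.nn as nn

    d_model, vocab_size = 512, 30522
    embed = nn.Embedding(vocab_size, d_model)          # placeholder token embedding
    decoder = nn.TransformerDecoder(
        nn.TransformerDecoderLayer(d_model=d_model, nhead=8), num_layers=6)
    out_proj = nn.Linear(d_model, vocab_size)          # placeholder generator head

    memory = torch.randn(20, 1, d_model)               # fixed encoder output: [src_len, bs, d_model]
    ys = torch.full((1, 1), 101, dtype=torch.long)     # BOS token, shape [tgt_len, bs]

    with torch.no_grad():
        for _ in range(10):
            # Causal mask: -inf above the diagonal so step i only sees steps <= i
            sz = ys.size(0)
            tgt_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
            out = decoder(embed(ys), memory, tgt_mask=tgt_mask)
            next_token = out_proj(out[-1]).argmax(dim=-1, keepdim=True)  # greedy pick, [bs, 1]
            ys = torch.cat([ys, next_token.T], dim=0)  # append along the time dimension

Is this the intended way to use nn.TransformerDecoder at test time, or is there a dedicated incremental-decoding API?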

    def greedy_inference(self,
                         src_input_ids,
                         max_tgt_seq_len,
                         src_token_type_ids,
                         src_attention_mask,
                         tgt_input_ids=None,
                         inference=True):
        with torch.no_grad():
            # Derive the batch size from the source instead of hard-coding it
            batch_size = src_input_ids.shape[0]

            if tgt_input_ids is not None:
                tgt_seq_len = tgt_input_ids.shape[1]
            else:
                assert max_tgt_seq_len is not None, 'Target sequence length is not defined'
                tgt_seq_len = max_tgt_seq_len

            # Encode the source once; this memory stays fixed during decoding
            encoder_vec = self.bert_encoder(src_input_ids, src_token_type_ids, src_attention_mask)
            # Create the initial BOS tokens: shape tensor([[101]]) for batch size 1
            generated_seq = torch.full((batch_size, 1), Constants.BOS, dtype=torch.long, device=self.device)

            for i in range(1, tgt_seq_len):
                # Re-run the decoder over the whole generated prefix at every step
                decoder_outputs = self.bert_decoder(encoder_vec, generated_seq, inference)
                _, lsm_score = self.generator(decoder_outputs)
                # lsm_score shape is [num_tokens, bs, vocab_size]; keep the last position
                last_scores = lsm_score[-1, :, :]
                # Greedy step: take the token with the largest probability
                _, generated_token_index = torch.topk(last_scores, 1, dim=-1)
                # Concatenate the generated token with the sequence
                generated_seq = torch.cat((generated_seq, generated_token_index), dim=-1)

        return generated_seq
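
Inside `self.bert_decoder` I assume the generated prefix needs a causal target mask at every step; a minimal helper would look like this (a sketch; `subsequent_mask` is a name I made up, not part of my model):

    def subsequent_mask(sz, device=None):
        # Float mask with -inf above the diagonal, so position i can only
        # attend to positions <= i; passed as tgt_mask to nn.TransformerDecoder
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

which I would call as `tgt_mask = subsequent_mask(generated_seq.size(1), self.device)` before each decoder step. Is something like this required at inference time as well, or only during training?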