PyTorch Transformer model for language generation gives a sequence of pad_index as output

I am using a Transformer encoder and decoder to generate questions. The encoder encodes the text input, the encoded output is combined with question-type features, and the result is given to the decoder. The model uses teacher forcing during training.
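
For context, here is a condensed, self-contained sketch of the data flow I am trying to implement. It uses random tensors and made-up sizes (B, S, T, E, NQ, V) instead of my real modules and data, just to show the intended shapes; the full code is further below.

[code]
import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, T, E, NQ, V = 4, 12, 9, 32, 5, 100  # batch, src len, tgt len, embed dim, num q-types, tgt vocab

encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(E, nhead=4, batch_first=True), num_layers=2)
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(E, nhead=4, batch_first=True), num_layers=2)
type_enc = nn.Linear(NQ, E)
fc_out = nn.Linear(E, V)

src_emb = torch.randn(B, S, E)   # embedded + position-encoded source (label triplets)
memory = encoder(src_emb)        # (B, S, E)

qtypes = torch.randint(0, NQ, (B,))
qtype_enc = type_enc(F.one_hot(qtypes, NQ).float()).unsqueeze(1)  # (B, 1, E)
memory = memory * qtype_enc      # fuse question-type features into the encoder memory

tgt_emb = torch.randn(B, T, E)   # embedded + position-encoded target prefix (teacher forcing)
tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)  # causal mask
logits = fc_out(decoder(tgt_emb, memory, tgt_mask))                   # (B, T, V)
print(logits.shape)  # torch.Size([4, 9, 100])
[/code]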

The problem:
After the first batch, the model generates only sequences of the pad_index as output.

[code]

import math
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import (TransformerDecoder, TransformerDecoderLayer,
                      TransformerEncoder, TransformerEncoderLayer)
from tqdm import tqdm

# TripletModel, args, model, optimizer and trainloader are defined elsewhere in my project.


def train(epoch):
	model.train()
	train_loss = 0.0
	for idx, data in tqdm(enumerate(trainloader), total=len(trainloader)):

		data[1:] = [d.cuda() for d in data[1:]]
		loss = model(data[1:])

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()

	train_loss = train_loss/len(trainloader)
	print ("E: %d | L: %.2E"%(epoch, train_loss))
	
	
class PositionalEncoding(nn.Module):
    def __init__(self,
	         emb_size,
	         dropout,
	         maxlen):
	super(PositionalEncoding, self).__init__()
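	# Standard sinusoidal positional encoding: even dims use sin(pos / 10000^(2i/d)), odd dims use cos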
	den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
	pos = torch.arange(0, maxlen).reshape(maxlen, 1)
	pos_embedding = torch.zeros((maxlen, emb_size))
	pos_embedding[:, 0::2] = torch.sin(pos * den)
	pos_embedding[:, 1::2] = torch.cos(pos * den)
	pos_embedding = pos_embedding.unsqueeze(0)  # Needed when batch_size is the 0th dim

	self.dropout = nn.Dropout(dropout)
	self.register_buffer('pos_embedding', pos_embedding)  # buffer: moved by .to()/.cuda() and saved in state_dict, but not a trainable parameter

    def forward(self, token_embedding):
	return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1)])


class WordEncoding(nn.Module):
    def __init__(self, vocab_size, emb_size):
	super().__init__()
	self.embedding = nn.Embedding(vocab_size, emb_size)
	self.emb_size = emb_size

    def forward(self, tokens):
	# make embeddings relatively larger
	return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class TransformerModel(nn.Module):
    def __init__(
	self,
	args,
	num_q_types,
	use_glove,
	num_heads,
	num_encoder_layers,
	num_decoder_layers,
	feedforward_dim,
	dropout
    ):
	super().__init__()

	self.num_q_types = num_q_types
	self.use_glove = use_glove
	self.teacher_forcing_ratio = args.teacher_forcing_ratio
	self.max_q_len = args.max_q_len

	self.bos_index, self.eos_index, self.pad_index, self.unk_index = TripletModel.get_spl_word_index(args)
	self.src_pad_index = 0

	emb_file = os.path.join(args.data_dir, args.domain, "{}_top_q_embeds.emb".format(args.domain))
	emb_dict = torch.load(emb_file)

	src_embeds, label_start_idx, attr_val_dict = TripletModel.get_label_embeddings(args, emb_dict)
	src_vocab_size = len(src_embeds)

	tgt_embeds = TripletModel.get_word_embeddings(args, emb_dict)
	tgt_vocab_size = len(tgt_embeds)

	# Word and positional embeddings
	self.src_word_embedding = WordEncoding(src_vocab_size, args.embed_dim)  # (max_hop_len*3, embed_dim)
	self.tgt_word_embedding = WordEncoding(tgt_vocab_size, args.embed_dim)  # (max_q_len, embed_dim)

	self.src_position_encoding = PositionalEncoding(args.embed_dim, dropout, args.max_hop_len*3)
	self.tgt_position_encoding = PositionalEncoding(args.embed_dim, dropout, args.max_q_len-1)

	encoder_layer = TransformerEncoderLayer(d_model=args.embed_dim, nhead=num_heads, dim_feedforward=feedforward_dim, batch_first=True)
	self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
	decoder_layer = TransformerDecoderLayer(d_model=args.embed_dim, nhead=num_heads, dim_feedforward=feedforward_dim, batch_first=True)
	self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

	self.fc_out = nn.Linear(args.embed_dim, tgt_vocab_size)
	self.dropout = nn.Dropout(dropout)
	self.type_enc = nn.Linear(self.num_q_types, args.embed_dim)

    def create_src_mask(self, src):
	src_seq_len = src.shape[1]
	src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)
	src_padding_mask = (src == self.src_pad_index)

	return src_mask, src_padding_mask

    def create_tgt_mask(self, tgt):
	def generate_square_subsequent_mask(sz):
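	    # Causal mask: 0.0 on and below the diagonal, -inf above it, so position i cannot attend to later positions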
	    mask = (torch.triu(torch.ones((sz, sz), device=tgt.device)) == 1).transpose(0, 1)
	    mask = mask.float().masked_fill(mask == False, float('-inf')).masked_fill(mask == True, float(0.0))
	    return mask

	tgt_seq_len = tgt.shape[1]
	tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
	tgt_padding_mask = (tgt == self.pad_index)

	return tgt_mask, tgt_padding_mask

    def forward(self, x):
	if self.training:
	    loss = self.train_forward(x)
	    return loss
	else:
	    with torch.no_grad():
	        pred = self.val_forward(x)
	        return pred

    def train_forward(self, x):

	# Get the input
	label_idxs, vision_feats, q_types, gt_q = x[0], x[1], x[2], x[3]

	# pred_q = self.forward_process(label_idxs, vision_feats, q_types, gt_q[:, :-1])  # input is excluding the last word
	pred_q = self.forward_process(label_idxs, vision_feats, q_types, gt_q[:, 1:])  # input is excluding the first word

	pred_q = pred_q.reshape(-1, pred_q.shape[-1])
	gt_q = gt_q[:, 1:].reshape(-1)  # output is excluding the first word

	# loss = F.nll_loss(pred_q, gt_q)
	loss = F.cross_entropy(pred_q, gt_q)

	return loss

    def val_forward(self, x):
	label_idxs, vision_feats, q_types = x[0], x[1], x[2]
	pred_q = self.forward_process(label_idxs, vision_feats, q_types)

	# pred_q = torch.argmax(pred_q, dim=2)

	return pred_q

    def forward_process(self, label_embeds, vision_feats, qtypes, gt_q=None):

	# label_embeds = label_embeds.resize(label_embeds.shape[0], self.attention_dim, self.embed_dim)
	# vision_feats = vision_feats.resize(vision_feats.shape[0], self.attention_dim, self.vision_feats_dim)

	# Encoder
	src = label_embeds.reshape(label_embeds.shape[0], -1)  # batch_size, max_hop_len*3
	src_mask, src_padding_mask = self.create_src_mask(src)  # (max_hop_len*3, max_hop_len*3), (batch_size, max_hop_len*3)
	src_emb = self.src_position_encoding(self.src_word_embedding(src))  # (batch_size, max_hop_len*3, embed_dim)

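	# nn.TransformerEncoder.forward positional args: (src, mask, src_key_padding_mask)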
	memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask) # (batch_size, max_hop_len*3, embed_dim)

	# Encode qtype
	qtypes = (F.one_hot(qtypes, self.num_q_types)).type(torch.FloatTensor).to(memory.device).unsqueeze(dim=1)  # (batch_size, 1, num_q_types)
	qtype_enc = self.type_enc(qtypes)  # (batch_size, 1, embed_dim)
	memory = memory * qtype_enc  # (batch_size, max_hop_len*3, embed_dim)

	q_preds = None

	# Decoder
	if self.training:
	    use_teacher_forcing = random.random() < self.teacher_forcing_ratio
	else:
	    use_teacher_forcing = False
	# print("Teacher forcing", use_teacher_forcing)

	# Starting with <BOS>
	word_seq = torch.tensor([self.bos_index] * label_embeds.shape[0], device=label_embeds.device).unsqueeze(dim=1)  # (batch_size, 1)

	for i in range(self.max_q_len-1):
	    q_len = i+1
	    # print(q_len)

	    tgt_emb = self.tgt_position_encoding(self.tgt_word_embedding(word_seq))  # (batch_size, q_len, embed_dim)

	    if self.training:
	        tgt_mask, tgt_padding_mask = self.create_tgt_mask(gt_q[:, :q_len])  # (q_len, q_len), (batch_size, q_len)
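	        # nn.TransformerDecoder.forward positional args: (tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)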
	        preds = self.transformer_decoder(tgt_emb, memory, tgt_mask, None, tgt_padding_mask, src_padding_mask)  # (batch_size, q_len, embed_dim)
	    else:
	        tgt_mask, _ = self.create_tgt_mask(word_seq)  # (q_len, q_len), (batch_size, q_len)
	        preds = self.transformer_decoder(tgt_emb, memory, tgt_mask)  # (batch_size, q_len, embed_dim)

	    # preds = preds[:, -1, :]  # (batch_size, embed_dim)
	    # preds = self.fc_out(preds)  # (batch_size, vocab_size)
	    # preds = F.softmax(preds, dim=-1)  # (batch_size, vocab_size)
	    # preds = preds.unsqueeze(dim=1)   # (batch_size, 1, vocab_size)

	    preds = self.fc_out(preds)  # (batch_size, q_len, vocab_size)
	    preds = F.softmax(preds, dim=-1)  # (batch_size, q_len, vocab_size)

	    if i < self.max_q_len-2 and use_teacher_forcing:
	        # The next word is the GT word if use_teacher_forcing else the predicted word by the model
	        word_out = gt_q[:, i+1].unsqueeze(1)
	    else:
	        # word_out = preds[:, -1, :].argmax(dim=-1).unsqueeze(1)  # (batch_size, 1)
	        word_out = preds.argmax(dim=-1)[:, -1].unsqueeze(1)  # (batch_size, 1)

	    # q_preds = preds if q_preds is None else torch.cat([q_preds, preds], dim=1)
	    q_preds = preds

	    word_seq = torch.cat([word_seq, word_out], dim=1)
	    print(word_seq[:1])  # debug: running decoded sequence for the first example in the batch

	return q_preds if self.training else word_seq[:, 1:] # Don't use the <BOS>

[/code]

The output below shows the word-index sequence generated after each decoding iteration (the print(word_seq[:1]) in forward_process) for the first example of every batch.

[code]
0%| | 0/308 [00:00<?, ?it/s]
tensor([[0, 501]], device='cuda:0')
tensor([[0, 501, 84]], device='cuda:0')
tensor([[0, 501, 84, 494]], device='cuda:0')
tensor([[0, 501, 84, 494, 691]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42, 710]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42, 710, 164]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42, 710, 164, 42]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42, 710, 164, 42, 657]], device='cuda:0')
tensor([[0, 501, 84, 494, 691, 42, 42, 189, 42, 42, 494, 42, 710, 42, 154, 494, 691, 42, 494, 511, 691, 42, 710, 164, 42, 657, 189]], device='cuda:0')
0%| | 1/308 [00:01<05:41, 1.11s/it]
tensor([[0, 2]], device='cuda:0')
tensor([[0, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
1%| | 2/308 [00:02<05:45, 1.13s/it]

[/code]