CNN-LSTM Image Captioner: Testing Output Problems

def forward(self, features, captions):  # handles both teacher-forced training and greedy inference
    # e.g. input feature -> (b, 768), caption embeddings -> (b, seq_length, embed_dim=768)
    features = features.unsqueeze(1)
    # features -> (b, 1, 768)
    h_0= torch.zeros(self.num_layers, features.shape[0], self.hidden_size).to(device)
    c_0 = torch.zeros(self.num_layers, features.shape[0], self.hidden_size).to(device)
    # training path: teacher forcing
    if captions is not None:
        # prepend the image feature to the caption embeddings along the sequence dimension
        all_embeddings = torch.cat((features, captions[:, :-1, :]), dim=1)  # captions shifted (last token dropped)
        # (b, 1, embed_dim) cat (b, seq_length-1, embed_dim) -> (b, seq_length, embed_dim)
        #LSTM
        lstm_out, _ = self.lstm(all_embeddings)  # (b, seq_length, hidden_size)

        # linear layer projects each hidden state to vocabulary logits
        prob_dist = self.linear(lstm_out)  # (b, seq_length, vocab_size)
        # no softmax here; CrossEntropyLoss applies it
        return prob_dist
    
    else:
        # inference path: greedy decoding, one step at a time
        inputs = features
        states = (h_0, c_0)
        decoder_output = torch.zeros(features.shape[0], self.word_max, self.vocab_size)
        # decoder_output (b, word_max, vocab_size) stores the predictions for the batch
        for t in range(self.word_max):
            # inputs = (b,1, embed_dim = 768)
            output_step, states = self.lstm(inputs,states)
            # output_step = b,1,hidden_size 
            logits = self.linear(output_step)
            # logits (b,1, vocab_size)
            pred_word = self.greedy_word(logits)
            # (b, 1, vocab_size=3000) one-hot, stored in decoder_output for the test-time loss
            decoder_output[:, t:t+1, :] = pred_word
            # fill in the t-th word for each caption
            
            # embed the predicted word with BERT to feed into the next LSTM step
            input_ids = logits.argmax(dim=-1).to(device)  # (b, 1)
            bert_outputs = self.bert_model(input_ids, attention_mask=None)
            inputs = bert_outputs.last_hidden_state
            # (b,1, 768)
        return decoder_output
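
For reference, greedy_word (not shown in the snippet) just one-hots the argmax of the step logits so the result can slot into decoder_output; it is roughly equivalent to:

def greedy_word(self, logits):
    # logits (b, 1, vocab_size) -> one-hot (b, 1, vocab_size)
    ids = logits.argmax(dim=-1)  # (b, 1)
    return torch.nn.functional.one_hot(ids, num_classes=self.vocab_size).float()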

Training goes well (90%+ accuracy), but at test time I get nonsensical sentences like "on there there there there." I am using BERT embeddings for the captions. I'm new to both this forum and machine learning, so any help would be appreciated.
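
In case it helps, this is roughly how I drive the two phases (variable names here are just illustrative, not my exact script):

criterion = torch.nn.CrossEntropyLoss()

# training: teacher forcing with the BERT caption embeddings
logits = model(image_features, caption_embeddings)    # (b, seq_length, vocab_size)
loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

# testing: no captions, so forward runs the greedy decoding loop
model.eval()
with torch.no_grad():
    decoder_output = model(image_features, None)      # (b, word_max, vocab_size)
    predicted_ids = decoder_output.argmax(dim=-1)     # (b, word_max)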