How to extract one probability per caption from a simple CNN+LSTM captioning model?

I’m trying to extract multiple sentences from a simple CNN+LSTM captioning model, with one probability per caption. Instead, I’m getting multiple probabilities (one per decoding step) for a single sentence (caption).

Main model code (the sampling loop of my decoder):

    sampled_ids = []
    inputs = features.unsqueeze(1)                           # inputs: (batch_size, 1, embed_size)
    for i in range(self.max_seg_length):
        hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
        outputs = self.linear(hiddens.squeeze(1))            # outputs: (batch_size, vocab_size)
        _, predicted = outputs.max(1)                        # predicted: (batch_size)

        # The added code starts here: per-step softmax over the vocabulary
        sm = torch.nn.Softmax(dim=1)
        outputs1 = sm(outputs)                               # outputs1: (1, 9956) for batch_size 1
        top1_prob, top1_label = torch.topk(outputs1, 1)      # top-1 probability and word id
        # print(top1_label)
        print(top1_prob)

        sampled_ids.append(predicted)
        inputs = self.embed(predicted)                       # inputs: (batch_size, embed_size)
        inputs = inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)
    sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
    return sampled_ids

Result: the printed top-1 probability at each decoding step, for a single caption:

    tensor([ 0.9998])
    tensor([ 0.4791])
    tensor([ 0.3699])
    tensor([ 0.9963])
    tensor([ 0.5529])
    tensor([ 0.1465])
    tensor([ 0.2513])
    tensor([ 0.9950])
    tensor([ 0.7264])
    tensor([ 0.9951])
    tensor([ 0.3070])
    tensor([ 0.9992])
    tensor([ 0.4416])
    tensor([ 0.9996])
    tensor([ 0.4754])
    tensor([ 0.6424])
    tensor([ 0.9996])
    tensor([ 0.5170])
    tensor([ 0.5675])
    tensor([ 0.9996])

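To be clear about what I want: a single score for the whole caption, not one value per step. My idea is to combine the per-step top-1 probabilities, e.g. by summing their logs along the greedy path. A standalone sketch of that idea (random logits stand in for my decoder’s outputs, so the numbers themselves are meaningless):

    import math
    import torch

    # Sketch only: aggregate per-step top-1 probabilities into one caption-level score.
    torch.manual_seed(0)
    vocab_size = 9956
    max_seq_length = 20

    caption_log_prob = 0.0
    for step in range(max_seq_length):
        logits = torch.randn(1, vocab_size)            # stand-in for self.linear(hiddens.squeeze(1))
        probs = torch.softmax(logits, dim=1)           # (1, vocab_size)
        top1_prob, top1_label = torch.topk(probs, 1)   # both (1, 1)
        caption_log_prob += torch.log(top1_prob).item()

    print(caption_log_prob)             # one log-probability for the whole caption
    print(math.exp(caption_log_prob))   # or as a (very small) probability

Summing log-probabilities is just one way to score the greedy path; a length-normalized average would be another option.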
Can you also print outputs1’s shape?

    print(outputs1.shape)
    (1, 9956)

That’s odd. torch.topk works on the last dimension by default, so it should return probabilities of shape (1, 1).
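For example, on a (1, vocab_size)-shaped tensor:

    import torch

    probs = torch.softmax(torch.randn(1, 9956), dim=1)
    top1_prob, top1_label = torch.topk(probs, 1)   # topk over the last dim by default
    print(top1_prob.shape)                         # torch.Size([1, 1])
    print(top1_label.shape)                        # torch.Size([1, 1])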

Could you please check top1_prob’s dimensions as well?

You’re right, top1_prob has shape (1, 1), but the probabilities still look wrong. Each pair below is the top-1 probability and the step index, followed by top1_prob’s shape:
    (0.9997780919075012, 1)
    (1, 1)
    (0.47908642888069153, 2)
    (1, 1)
    (0.36988458037376404, 3)
    (1, 1)
    (0.996319055557251, 4)
    (1, 1)
    (0.5528988838195801, 5)
    (1, 1)
    (0.1464884728193283, 6)
    (1, 1)
    (0.2512759864330292, 7)
    (1, 1)
    (0.9949786067008972, 8)
    (1, 1)
    (0.7263955473899841, 9)
    (1, 1)
    (0.9951314926147461, 10)
    (1, 1)
    (0.3069855272769928, 11)
    (1, 1)
    (0.999177873134613, 12)
    (1, 1)
    (0.44164711236953735, 13)
    (1, 1)
    (0.9995712637901306, 14)
    (1, 1)
    (0.47538724541664124, 15)
    (1, 1)
    (0.6424259543418884, 16)
    (1, 1)
    (0.999626874923706, 17)
    (1, 1)
    (0.5170174837112427, 18)
    (1, 1)
    (0.5674574971199036, 19)
    (1, 1)
    (0.9995905756950378, 20)

Here is similar code for a vanilla captioning model, essentially the same greedy loop without the softmax/topk printing (sketched below).
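A minimal sketch of such a vanilla greedy sample(), assuming the same self.lstm / self.linear / self.embed attributes as in the decoder above:

    def sample(self, features, states=None):
        """Greedy decoding: returns word ids of shape (batch_size, max_seq_length)."""
        sampled_ids = []
        inputs = features.unsqueeze(1)                   # (batch_size, 1, embed_size)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)  # (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))    # (batch_size, vocab_size)
            _, predicted = outputs.max(1)                # (batch_size)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)  # (batch_size, 1, embed_size)
        return torch.stack(sampled_ids, 1)               # (batch_size, max_seq_length)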