Inputs for LSTM with mini-batches for next word prediction

Hello!

Could you please tell me how to calculate the loss function for next-word prediction?

Here are all the steps:

For example, I have N sentences and mini_batch_size = 2.

  1. I get a mini-batch of sentences, for example:

    [ 6, 7, 8 ]
    [ 1, 2, 3, 4, 5 ]

  2. Sort the mini-batch by length (longest first):

    [ 1, 2, 3, 4, 5 ]
    [ 6, 7, 8 ]

  3. Split each sentence into X and Y, and get a list of lengths:

    X = sentence[:-1]
    Y = sentence[1:]

    X = [[ 1, 2, 3, 4 ],
         [ 6, 7 ]]

    Y = [[ 2, 3, 4, 5 ],
         [ 7, 8 ]]

    X_lengths = [4, 2]
    
  4. Pad the sentences with Pad_Index (0 in my case):

    X = [[ 1, 2, 3, 4 ],
         [ 6, 7, 0, 0 ]]

    Y = [[ 2, 3, 4, 5 ],
         [ 7, 8, 0, 0 ]]
    
  5. model.zero_grad()

  6. Get the embeddings:

    word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

    X = word_embeddings(X)

    X.size() = torch.Size([mini_batch_size, seq_length, embedding_dim])

  7. Transpose:

    X = torch.transpose(X, 0, 1)

    X.size() = torch.Size([seq_length, mini_batch_size, embedding_dim])

  8. Pack by:

    X = torch.nn.utils.rnn.pack_padded_sequence(X, X_lengths)

  9. Initialise the hidden state and the cell state:

    h_t.size() = torch.Size([layers_dim, mini_batch_size, hidden_dim])
    h_c.size() = torch.Size([layers_dim, mini_batch_size, hidden_dim])

  10. LSTM:

    lstm_out, (h_t, h_c) = lstm(X, (h_t, h_c))

  11. Pad by:

    lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)

    lstm_out.size() = torch.Size([seq_length, mini_batch_size, hidden_dim])

  12. Linear layer:

    fc = nn.Linear(hidden_dim, vocab_size)

    linear_out = fc(lstm_out)

    linear_out.size() = torch.Size([seq_length, mini_batch_size, vocab_size])

  13. Log-softmax over the vocabulary dimension (the last one):

    Y_hat = F.log_softmax(linear_out, dim=2)

    Y_hat.size() = torch.Size([seq_length, mini_batch_size, vocab_size])
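
To make the steps above concrete, here is a minimal runnable sketch of them. The sizes (vocab_size = 10, embedding_dim = 4, hidden_dim = 6, one layer) are placeholders picked only for illustration, pad_sequence is just one convenient way to do step 4, and the zero initial states are an assumption. Step 5 is left out because the sketch has no optimizer, and the log-softmax is taken over the last (vocabulary) dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)

# Placeholder sizes, chosen only for illustration.
vocab_size, embedding_dim, hidden_dim, layers_dim = 10, 4, 6, 1
PAD_INDEX = 0

# Steps 1-2: take a mini-batch and sort it by length, longest first.
mini_batch = [[6, 7, 8], [1, 2, 3, 4, 5]]
mini_batch.sort(key=len, reverse=True)
mini_batch_size = len(mini_batch)

# Step 3: split each sentence into inputs X and next-word targets Y.
X_list = [torch.tensor(s[:-1]) for s in mini_batch]
Y_list = [torch.tensor(s[1:]) for s in mini_batch]
X_lengths = [len(x) for x in X_list]  # lengths of the unpadded inputs

# Step 4: pad to the longest sequence -> [mini_batch_size, seq_length].
X = pad_sequence(X_list, batch_first=True, padding_value=PAD_INDEX)
Y = pad_sequence(Y_list, batch_first=True, padding_value=PAD_INDEX)

word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_INDEX)
lstm = nn.LSTM(embedding_dim, hidden_dim, layers_dim)
fc = nn.Linear(hidden_dim, vocab_size)

# Steps 6-8: embed, move the time axis to dim 0, then pack.
emb = torch.transpose(word_embeddings(X), 0, 1)  # [seq_length, batch, embedding_dim]
packed = pack_padded_sequence(emb, X_lengths)

# Step 9: initial hidden and cell states (zeros are an assumption here).
h_t = torch.zeros(layers_dim, mini_batch_size, hidden_dim)
h_c = torch.zeros(layers_dim, mini_batch_size, hidden_dim)

# Steps 10-11: run the LSTM and pad the packed output back to a dense tensor.
packed_out, (h_t, h_c) = lstm(packed, (h_t, h_c))
lstm_out, out_lengths = pad_packed_sequence(packed_out)

# Steps 12-13: project to vocabulary size, log-softmax over the vocabulary axis.
Y_hat = F.log_softmax(fc(lstm_out), dim=2)
print(Y_hat.size())  # [seq_length, mini_batch_size, vocab_size]
```

With these inputs, Y_hat comes out time-major while Y is still batch-major, which is exactly the mismatch asked about below.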

So, the question is:

How do I get the loss, if:

    Y_hat -> torch.Size([seq_length, mini_batch_size, vocab_size])
    Y     -> torch.Size([mini_batch_size, seq_length])
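
One approach I could imagine, as a minimal sketch and not necessarily the intended one: transpose Y into the same time-major layout as Y_hat, flatten both, and use nn.NLLLoss with ignore_index set to the pad index so padded positions do not contribute. The random tensor below is only a stand-in for the real model output, and the sizes are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, mini_batch_size, vocab_size = 4, 2, 10  # placeholder sizes
PAD_INDEX = 0

# Stand-in for the model output: log-probabilities over the vocabulary.
Y_hat = F.log_softmax(torch.randn(seq_length, mini_batch_size, vocab_size), dim=2)
Y = torch.tensor([[2, 3, 4, 5],
                  [7, 8, 0, 0]])  # [mini_batch_size, seq_length]

# Bring Y into the time-major layout of Y_hat, then flatten both the same way.
targets = Y.transpose(0, 1).reshape(-1)    # [seq_length * mini_batch_size]
log_probs = Y_hat.reshape(-1, vocab_size)  # [seq_length * mini_batch_size, vocab_size]

# ignore_index skips the padded targets; the mean is taken over real tokens only.
loss_fn = nn.NLLLoss(ignore_index=PAD_INDEX)
loss = loss_fn(log_probs, targets)
print(loss.item())
```

This relies on index 0 being reserved for padding, so it can never be a real target word. Since Y_hat is already log-softmaxed, NLLLoss is the matching criterion; nn.CrossEntropyLoss would instead expect the raw linear_out.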

And am I doing the previous steps correctly?

Thanks in advance!