How does batch_first work in GRU/LSTM?

So I created a GRU for a sentiment classification task.

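import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, out_size, n_layers, bi, drop, pad_tok):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_tok)
        self.gru = nn.GRU(emb_size, hid_size, num_layers=n_layers, bidirectional=bi, dropout=drop, batch_first=True)
        # hid_size * 2 assumes bi=True (forward and backward final states are concatenated)
        self.out = nn.Linear(hid_size * 2, out_size)
        self.dropout = nn.Dropout(drop)

    def forward(self, x, text_lengths):
        print("Input shape:", x.shape)
        if 0 in text_lengths:
            # pack_padded_sequence requires every length to be > 0
            raise ValueError("text_lengths contains a zero-length sequence")
        embedd = self.dropout(self.embedding(x))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedd, text_lengths)
        out, hidden = self.gru(packed_embedded)
        print("Hidden shape:", hidden.shape)
        # concatenate the last layer's forward (-2) and backward (-1) final hidden states
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        return self.out(hidden.squeeze(0))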

So in both cases, whether I set batch_first to True or False, the model works and trains fine. But it shouldn't: setting it to True means the model expects the first dimension to be the batch size, which isn't the case for me (it's the sequence length). The funny thing is that the model seems to perform a little better when I set it to True, which again I don't think it should. Can someone explain what's happening behind this?
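
One thing worth checking for this particular model: the GRU never receives a padded tensor, it receives a PackedSequence. The sketch below uses made-up toy sizes (not the poster's data), copies identical weights into two GRUs that differ only in batch_first, and feeds both the same packed batch. As I understand the nn.GRU docs, batch_first only describes the layout of padded input/output tensors, and h_n is always (num_layers * num_directions, batch, hidden), so the final hidden states should come out identical.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

seq_len, batch, emb, hid = 5, 3, 4, 6
x = torch.randn(seq_len, batch, emb)   # time-major padded batch, like the question's input
lengths = torch.tensor([5, 3, 2])      # already in decreasing order

gru_bf = nn.GRU(emb, hid, bidirectional=True, batch_first=True)
gru_tm = nn.GRU(emb, hid, bidirectional=True, batch_first=False)
gru_tm.load_state_dict(gru_bf.state_dict())  # same weights, only the flag differs

# pack_padded_sequence's own batch_first defaults to False,
# so x is read as (seq_len, batch, emb) here
packed = pack_padded_sequence(x, lengths)

_, h_bf = gru_bf(packed)
_, h_tm = gru_tm(packed)

print(h_bf.shape)                # torch.Size([2, 3, 6]) = (num_layers * num_directions, batch, hid)
print(torch.equal(h_bf, h_tm))   # True: the GRU's batch_first is not used for packed input
```

If that holds for the poster's setup too, the small accuracy gap between the two settings may simply be run-to-run noise (random initialization, dropout), which repeating each setting with a few seeds would reveal.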

PyTorch and other frameworks only complain if the dimensions don't match the expected values. Apart from that, anything goes, and there's always a chance that the network learns at least something; it just doesn't mean that it's something meaningful. Here is some food for thought and a few steps I would try:
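
To illustrate that point, here is a minimal sketch with toy sizes (not taken from the model above): a time-major tensor handed to a GRU created with batch_first=True runs without any error, and the time axis is silently treated as the batch axis.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, emb, hid = 7, 2, 4, 5
x = torch.randn(seq_len, batch, emb)      # actually time-major: (seq_len, batch, emb)

gru = nn.GRU(emb, hid, batch_first=True)  # ...but the GRU is told the batch comes first

# No error: the tensor is 3-D and its last dimension matches input_size,
# which is all the shape check looks at
out, h = gru(x)
print(out.shape)   # torch.Size([7, 2, 5]): read as 7 "sequences" of 2 steps each
print(h.shape)     # torch.Size([1, 7, 5]): one final state per "sequence", i.e. per time step
```

Whether dimension 0 really is the batch is entirely up to the caller; PyTorch has no way of knowing.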

  • When you say the network trains fine, what does that mean? The training loss going down is not a sufficient metric on its own. How about the test loss/accuracy?
  • pack_padded_sequence also has batch_first as an optional parameter. Did you try various combinations of True/False together with your definition of the GRU? Intuitively, both should use the same value for batch_first.
  • pack_padded_sequence also has the optional parameter enforce_sorted, which is True by default. In that case, the method expects the sequences in the batch to be sorted by length in decreasing order. I cannot see where in your code you sort your input.
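
To make the last two points concrete, here is a small sketch with toy tensors (shapes and lengths are made up for illustration). It packs a batch-first input with batch_first=True, shows that the default enforce_sorted=True rejects unsorted lengths, and shows the two usual ways around it: sorting the batch yourself, or passing enforce_sorted=False.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

x = torch.randn(3, 5, 4)            # batch-first padded batch: (batch, seq_len, emb)
lengths = torch.tensor([2, 5, 3])   # deliberately NOT sorted

# With the default enforce_sorted=True, unsorted lengths raise a RuntimeError
try:
    pack_padded_sequence(x, lengths, batch_first=True)
    rejected = False
except RuntimeError:
    rejected = True
print("rejected:", rejected)

# Option 1: sort the batch yourself (longest sequence first)
sorted_lengths, order = lengths.sort(descending=True)
packed_sorted = pack_padded_sequence(x[order], sorted_lengths, batch_first=True)

# Option 2: let PyTorch sort internally and remember the permutation
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

print(packed.batch_sizes)                            # tensor([3, 3, 2, 1, 1])
print(torch.equal(packed.data, packed_sorted.data))  # both packings hold the same data
```

With enforce_sorted=False the PackedSequence stores the permutation (sorted_indices / unsorted_indices), and nn.GRU uses it to return h_n in the original batch order, so no manual unsorting is needed afterwards.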