GRU batch_first changing between layers


While reading about the ASR implementation in "Building an end-to-end Speech Recognition model in PyTorch", I came across a GRU implementation unlike any other RNN/GRU/LSTM I have seen.

I am curious because this implementation has outperformed every other network I have tried in my experiments.

The implementation is as follows. This is the GRU:

import torch.nn as nn
import torch.nn.functional as F


class BidirectionalGRU(nn.Module):
    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()

        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)  # output feature size is 2 * hidden_size
        x = self.dropout(x)
        return x

Then in the main network class the multilayer GRU is created as follows:

self.birnn_layers = nn.Sequential(*[
    BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                     hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
    for i in range(n_rnn_layers)
])

What I don’t understand is why the first layer has batch_first=True and then all subsequent layers make use of batch_first=False.

If anyone knows why this is done, I would really appreciate any help.
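To make the question concrete, here is a minimal sketch of what a batch_first=False GRU does when it is handed a batch-first tensor. The sizes are hypothetical and chosen only for illustration, and the weights are copied between the two modules so that the only difference is how each one reads the axes. It does not explain why the article's author chose this setup; it only shows the observable behaviour.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes, for illustration only.
batch, seq, feat, hidden = 4, 7, 8, 5

gru_bf = nn.GRU(feat, hidden, batch_first=True, bidirectional=True)
gru_tm = nn.GRU(feat, hidden, batch_first=False, bidirectional=True)
# Share weights so the two modules differ only in how they read the axes.
gru_tm.load_state_dict(gru_bf.state_dict())

x = torch.randn(batch, seq, feat)  # a batch-first tensor

out_bf, _ = gru_bf(x)  # recurrence runs along dim 1 (length 7)
out_tm, _ = gru_tm(x)  # no error, but recurrence runs along dim 0 (length 4)

# The shapes match, so nothing fails downstream...
same_shape = out_bf.shape == out_tm.shape
# ...but the values differ, because a different axis was treated as time.
same_values = torch.allclose(out_bf, out_tm, atol=1e-6)
# With an explicit transpose, the time-major module agrees with the batch-first one.
out_tm_fixed = gru_tm(x.transpose(0, 1))[0].transpose(0, 1)
agrees = torch.allclose(out_bf, out_tm_fixed, atol=1e-5)

print(same_shape, same_values, agrees)
```

In other words, batch_first only changes which axis the module treats as the sequence; a mismatch between the flag and the tensor layout is silent because the output shape is unchanged.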

That’s done because the output of the first GRU is time-major, despite the batch-first input.

After looking at the docs, I saw that the output is always time-major, as you said. I wrote this test to confirm that this was what was going on:

gru_bf = nn.GRU(input_size=512,

gru = nn.GRU(input_size=512,

batch = 1
seq = 100
input_size = 512

in_bf = torch.randn(batch, seq, input_size)
input = torch.randn(seq, batch, input_size)

out_bf, _ = gru_bf(in_bf)
out, _ = gru(input)

print(out_bf.shape)
print(out.shape)

The output is:

torch.Size([1, 100, 1024])
torch.Size([100, 1, 1024])

It seems that if batch_first is True, the output is also batch-first. Is this indeed the case, or am I misinterpreting the output?
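For anyone who wants to reproduce this, here is a self-contained sketch of the same check. The constructor arguments are cut off in the snippet above, so hidden_size=512 and bidirectional=True are assumptions, chosen because 512 units in two directions give the 1024-wide output shown. It also prints the final hidden state, whose layout should not depend on batch_first.

```python
import torch
import torch.nn as nn

# Assumed settings: hidden_size and bidirectional are not visible in the
# truncated snippet; 512 units * 2 directions gives the 1024 features seen.
gru_bf = nn.GRU(input_size=512, hidden_size=512,
                batch_first=True, bidirectional=True)
gru_tm = nn.GRU(input_size=512, hidden_size=512,
                batch_first=False, bidirectional=True)

batch, seq, input_size = 1, 100, 512

in_bf = torch.randn(batch, seq, input_size)  # (batch, seq, feature)
in_tm = torch.randn(seq, batch, input_size)  # (seq, batch, feature)

out_bf, h_bf = gru_bf(in_bf)
out_tm, h_tm = gru_tm(in_tm)

print(out_bf.shape)  # torch.Size([1, 100, 1024]): batch first
print(out_tm.shape)  # torch.Size([100, 1, 1024]): time first
# The hidden state is (num_directions, batch, hidden) in both cases.
print(h_bf.shape, h_tm.shape)
```

Under these assumptions, the output follows the layout of the input: batch-first in, batch-first out. Only the hidden state keeps the same layout regardless of the flag.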

Thank you for your help; I really appreciate it.