Batch size position and RNN tutorial

In the 60 Minute Blitz tutorial, it is written that:

torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample. For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width. If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
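A minimal sketch of what the quoted passage describes. It assumes a single 3-channel 32x32 image; the channel counts and kernel size are arbitrary illustrative choices:

```python
import torch
import torch.nn as nn

# Arbitrary layer: 3 input channels, 6 output channels, 5x5 kernel
conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

# A single sample: nChannels x Height x Width
img = torch.randn(3, 32, 32)

# Add a fake batch dimension in dim0: nSamples x nChannels x Height x Width
batch = img.unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 32, 32])

# 32 - 5 + 1 = 28 with no padding and stride 1
out = conv(batch)
print(out.shape)    # torch.Size([1, 6, 28, 28])
```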

However, in the RNN classification tutorial, the batch size is in dim1 (the second dimension), not dim0:

To make a word we join a bunch of those into a 2D matrix <line_length x 1 x n_letters>. That extra 1 dimension is because PyTorch assumes everything is in batches - we’re just using a batch size of 1 here.
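A sketch of one-hot encoding helpers in the spirit of the tutorial. It assumes the alphabet is string.ascii_letters plus " .,;'", which gives n_letters = 57, matching the shapes printed later in this thread. Note that the resulting tensor is 3D, with the batch of size 1 in dim1:

```python
import string
import torch

# Assumed alphabet: 52 ASCII letters + 5 punctuation characters = 57
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

def letterToIndex(letter):
    return all_letters.find(letter)

def letterToTensor(letter):
    # One letter as a one-hot vector: [1, n_letters]
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor

def lineToTensor(line):
    # One word as [line_length, 1, n_letters]; the 1 is the batch dimension
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

print(letterToTensor('A').shape)     # torch.Size([1, 57])
print(lineToTensor('Albert').shape)  # torch.Size([6, 1, 57])
```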

But when looking at the documentation of nn.Linear, it seems that the batch size (N?) is in the first position, as in Conv2d.
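As far as I know, nn.Linear only transforms the last dimension and treats every leading dimension as a batch-like dimension (recent docs write the input shape as (*, H_in)). A minimal check, using 57 input features to match the letter encoding and an arbitrary 128 output features:

```python
import torch
import torch.nn as nn

linear = nn.Linear(in_features=57, out_features=128)

# [batch, in_features]: the usual 2D case
x2d = torch.randn(1, 57)
print(linear(x2d).shape)  # torch.Size([1, 128])

# [seq_len, batch, in_features]: only the last dim is transformed
x3d = torch.randn(6, 1, 57)
print(linear(x3d).shape)  # torch.Size([6, 1, 128])
```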

Have I misunderstood something, or is there an error in the RNN tutorial?

Thanks in advance for your insights.

You are correct: as described in the nn.RNN docs, the input is expected as [seq_len, batch, input_size] in the default setup.
If you want the batch dimension to be in dim0, you can specify batch_first=True.
As far as I know, however, this might lead to worse performance if you are using the GPU with cuDNN.
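A minimal sketch comparing the two layouts of nn.RNN; the sizes are arbitrary. Note that batch_first only changes the input and output layout, not the hidden state h_n:

```python
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size = 6, 3, 57, 128

# Default layout: input is [seq_len, batch, input_size]
rnn = nn.RNN(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)
out, h_n = rnn(x)
print(out.shape)  # torch.Size([6, 3, 128])
print(h_n.shape)  # torch.Size([1, 3, 128])

# batch_first=True: input is [batch, seq_len, input_size]
rnn_bf = nn.RNN(input_size, hidden_size, batch_first=True)
x_bf = x.transpose(0, 1)
out_bf, h_n_bf = rnn_bf(x_bf)
print(out_bf.shape)  # torch.Size([3, 6, 128])
print(h_n_bf.shape)  # torch.Size([1, 3, 128]), unaffected by batch_first
```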

Thank you. I agree regarding nn.RNN, but the tutorial implements an RNN from scratch using nn.Linear, and nn.Linear has the batch dimension first (contrary to what the tutorial says), no?

Yes, nn.Linear expects the batch dimension to be in dim0.
However, in the tutorial you either provide the input letter by letter (so there is no seq_len dimension) or slice the encoded word tensor along dim0, which again removes the seq_len dimension.
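A sketch of a from-scratch RNN cell built from nn.Linear, modelled on the tutorial's i2h/i2o design. The sizes (including n_categories = 18) and the zero-filled stand-in for an encoded word are assumptions for illustration. Slicing the [line_length, 1, n_letters] tensor along dim0 leaves a [1, n_letters] tensor, so nn.Linear receives the batch in dim0:

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration
n_letters, n_hidden, n_categories = 57, 128, 18

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Both layers take the concatenated [input, hidden] vector
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # input: [batch, input_size], hidden: [batch, hidden_size]
        combined = torch.cat((input, hidden), dim=1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

rnn = RNN(n_letters, n_hidden, n_categories)

# Stand-in for an encoded 6-letter word: [line_length, 1, n_letters]
line_tensor = torch.zeros(6, 1, n_letters)
hidden = torch.zeros(1, n_hidden)

# Slicing dim0 drops the sequence dimension: [1, n_letters]
print(line_tensor[0].shape)  # torch.Size([1, 57])

# Feed the word one letter at a time
for i in range(line_tensor.size(0)):
    output, hidden = rnn(line_tensor[i], hidden)
print(output.shape)  # torch.Size([1, 18])
```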

Here are the example cells:

input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input, hidden)
> torch.Size([1, 57])

input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input[0], hidden)
> torch.Size([1, 57])

As you can see, you are only passing a [batch_size, input_size] tensor (here [1, 57]) to your model, so nn.Linear sees the batch dimension in dim0.
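The same reasoning would hold for a real batch. This is a sketch under the assumption that several equal-length words are stacked into a [line_length, batch, n_letters] tensor; the tutorial itself only uses a batch size of 1, and names of different lengths would normally need padding. Each time-step slice is then [batch, n_letters], i.e. batch-first, which is what nn.Linear expects. The single-layer cell with tanh below is hypothetical, not the tutorial's:

```python
import torch
import torch.nn as nn

n_letters, n_hidden, batch = 57, 128, 3

# Hypothetical cell: one nn.Linear over the concatenated [input, hidden]
i2h = nn.Linear(n_letters + n_hidden, n_hidden)

# Three equal-length words stacked as [line_length, batch, n_letters]
lines = torch.zeros(6, batch, n_letters)
hidden = torch.zeros(batch, n_hidden)

for i in range(lines.size(0)):
    step = lines[i]  # [batch, n_letters]: the batch is in dim0 here
    hidden = torch.tanh(i2h(torch.cat((step, hidden), dim=1)))

print(hidden.shape)  # torch.Size([3, 128])
```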