Hello,
The 60 Minute Blitz tutorial says:
torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample. For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width . If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
However, in the RNN classification tutorial, the batch size is in the second dimension:
To make a word we join a bunch of those into a 2D matrix <line_length x 1 x n_letters> . That extra 1 dimension is because PyTorch assumes everything is in batches - we’re just using a batch size of 1 here.
But looking at the documentation of nn.Linear, it seems that the batch size (N?) comes first, as in nn.Conv2d.
Have I misunderstood something, or is there an error in the RNN tutorial?
Thanks in advance for your insights.
Your understanding is correct. As described in the docs, nn.RNN expects its input as [seq_len, batch, input_size] in the default setup.
If you want the batch dimension to be in dim0, you can specify batch_first=True.
As far as I know, however, this might lead to worse performance if you are using the GPU and cuDNN.
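To make the two layouts concrete, here is a minimal sketch. The sizes (seq_len=6, batch=1, input_size=57, hidden_size=128) are assumptions picked only for illustration:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumed, not taken from any particular model).
seq_len, batch, input_size, hidden_size = 6, 1, 57, 128

# Default layout: input is [seq_len, batch, input_size].
rnn_default = nn.RNN(input_size, hidden_size)
x = torch.zeros(seq_len, batch, input_size)
out, h = rnn_default(x)
print(out.shape)  # [seq_len, batch, hidden_size]

# batch_first=True: input is [batch, seq_len, input_size].
rnn_bf = nn.RNN(input_size, hidden_size, batch_first=True)
x_bf = x.transpose(0, 1)  # swap the seq_len and batch dimensions
out_bf, h_bf = rnn_bf(x_bf)
print(out_bf.shape)  # [batch, seq_len, hidden_size]

# The hidden state keeps the [num_layers, batch, hidden_size] layout either way.
print(h.shape, h_bf.shape)
```

Note that batch_first only changes the layout of the input and output tensors, not of the hidden state.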
Thank you. I agree for nn.RNN, but the tutorial implements an RNN from scratch using nn.Linear, and the latter has the batch dimension first (contrary to what the tutorial says), doesn't it?
Yes, nn.Linear expects the batch dimension to be in dim0.
However, in the tutorial you are either providing the input letter by letter (thus the seq_len is missing) or you are slicing the encoded word tensor, so that again the seq_len dimension will be missing.
Here are the example cells:
input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden)
print(input.shape)
> torch.Size([1, 57])
input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)
output, next_hidden = rnn(input[0], hidden)
print(input[0].shape)
> torch.Size([1, 57])
As you can see, you are passing only [batch_size, input_size] to your model.
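The same shapes can be reproduced without the tutorial's helpers. This is only a sketch under assumptions: a zero tensor stands in for the output of lineToTensor('Albert'), n_letters=57 and n_hidden=128 are assumed values, and the single nn.Linear on the concatenated input and hidden state is a simplified stand-in for the tutorial's model:

```python
import torch
import torch.nn as nn

# Assumed sizes; the zero tensor is a stand-in for an encoded 6-letter word.
n_letters, n_hidden = 57, 128
line_tensor = torch.zeros(6, 1, n_letters)  # [seq_len, batch, n_letters]

# Indexing dim0 removes seq_len, leaving [batch, n_letters].
step = line_tensor[0]
print(step.shape)

# nn.Linear treats dim0 as the batch dimension.
hidden = torch.zeros(1, n_hidden)            # [batch, n_hidden]
combined = torch.cat((step, hidden), dim=1)  # [batch, n_letters + n_hidden]
i2h = nn.Linear(n_letters + n_hidden, n_hidden)
next_hidden = i2h(combined)
print(next_hidden.shape)                     # [batch, n_hidden]
```

So the batch dimension sits in dim1 of the full word tensor, but in dim0 of every slice that actually reaches nn.Linear.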