Hello,
In the 60 minutes blitz tutorial, it is written that:
torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample. For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width. If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
However, in the RNN classification tutorial, the batch size is in the second dimension (dim1):
To make a word we join a bunch of those into a 2D matrix <line_length x 1 x n_letters>. That extra 1 dimension is because PyTorch assumes everything is in batches - we’re just using a batch size of 1 here.
But when looking at the documentation of nn.Linear, it seems that the batch size (N?) is in the first position, as in Conv2d.
Have I misunderstood something, or is there an error in the RNN tutorial?
Thanks in advance for your insights.
You are correct in your understanding and, as described in the docs, the input is expected as [seq_len, batch, input_size] in the default setup.
If you want the batch dimension to be in dim0, you can specify batch_first=True.
As far as I know, however, this might lead to worse performance if you are using the GPU and cuDNN.
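To make the two layouts concrete, here is a minimal sketch using nn.RNN. The sizes (seq_len=5, batch=3, input_size=57, hidden_size=128) are arbitrary values picked for illustration, not taken from the tutorial:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not from the tutorial)
seq_len, batch, input_size, hidden_size = 5, 3, 57, 128

# Default layout: input is [seq_len, batch, input_size]
rnn = nn.RNN(input_size, hidden_size)
out, h = rnn(torch.randn(seq_len, batch, input_size))
print(out.shape)  # torch.Size([5, 3, 128])
print(h.shape)    # torch.Size([1, 3, 128])

# batch_first=True: input is [batch, seq_len, input_size]
rnn_bf = nn.RNN(input_size, hidden_size, batch_first=True)
out_bf, h_bf = rnn_bf(torch.randn(batch, seq_len, input_size))
print(out_bf.shape)  # torch.Size([3, 5, 128])
print(h_bf.shape)    # torch.Size([1, 3, 128])
```

Note that batch_first only changes the layout of the input and output tensors; the hidden state keeps its [num_layers, batch, hidden_size] shape in both cases.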
Thank you. I agree for nn.RNN, but the tutorial is implementing an RNN from scratch, using nn.Linear, and the latter has the batch dimension first (contrary to what the tutorial says), no?
Yes, nn.Linear expects the batch dimension to be in dim0.
However, in the tutorial you are either providing the input letter by letter (thus the seq_len dimension is missing) or you are slicing the encoded word tensor, so that again the seq_len dimension will be missing.
Here are the example cells:
input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden)
print(input.shape)
> torch.Size([1, 57])
input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)
output, next_hidden = rnn(input[0], hidden)
print(input[0].shape)
> torch.Size([1, 57])
As you can see, you are passing only [batch_size, input_size] to your model.
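A rough sketch of why this works with nn.Linear: indexing or iterating over dim0 of the [line_length, 1, n_letters] tensor removes the seq_len dimension, so each step is already [batch_size, input_size]. The layer name i2h, the tanh update, and the values n_hidden=128 and line_length=6 are assumptions for illustration and may not match the tutorial's actual model; the zero tensor is just a stand-in for lineToTensor('Albert'):

```python
import torch
import torch.nn as nn

# n_letters matches the printed shape above; n_hidden and line_length are assumed
n_letters, n_hidden, line_length = 57, 128, 6

# Stand-in for lineToTensor('Albert'): [line_length, 1, n_letters]
line = torch.zeros(line_length, 1, n_letters)
hidden = torch.zeros(1, n_hidden)

# Hypothetical input-to-hidden layer in the style of a from-scratch RNN cell
i2h = nn.Linear(n_letters + n_hidden, n_hidden)

for step in line:  # iterating over dim0 drops the seq_len dimension
    # step has shape [1, n_letters], i.e. [batch_size, input_size]
    combined = torch.cat((step, hidden), dim=1)  # [1, n_letters + n_hidden]
    hidden = torch.tanh(i2h(combined))           # [1, n_hidden]

print(line[0].shape)  # torch.Size([1, 57])
print(hidden.shape)   # torch.Size([1, 128])
```

So the batch dimension sits in dim1 of the full word tensor, but in dim0 of every slice that actually reaches nn.Linear.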