What is happening in this forward function?


I can’t understand what’s happening in the last two lines of the forward method of this model (I removed irrelevant code):

import torch.nn as nn

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)


        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

I’d appreciate it if someone could shed some light on this. Thanks!

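self.decoder here is just the linear layer that maps from the hidden size (nhid) to the output size (ntoken).
To decode the outputs of every time step for every element in the batch at once, the first view reshapes the output tensor from nb_steps x batch x hidden_size (where x separates the dimensions; nn.LSTM defaults to batch_first=False, so time comes first) to (nb_steps*batch) x hidden_size. The linear layer then treats each (time step, batch element) pair as a separate row.
The second view reshapes the decoded result from (nb_steps*batch) x ntoken back to nb_steps x batch x ntoken. The variable output itself is not changed by the first view (view returns a new tensor), so output.size(0) and output.size(1) still give the original nb_steps and batch sizes.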
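A small sketch of the two views with concrete shapes might help. The sizes below are made up for illustration, and a random tensor stands in for the LSTM output (assuming the default batch_first=False layout). It also checks that row t*batch + b of the flattened tensor is output[t, b], and that decoding everything at once gives the same result as decoding each time step separately.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumed, not taken from the model above)
seq_len, batch, nhid, ntoken = 5, 3, 8, 10

decoder = nn.Linear(nhid, ntoken)
output = torch.randn(seq_len, batch, nhid)  # stand-in for the LSTM output

# First view: merge the time and batch dims -> (seq_len*batch, nhid)
flat = output.view(output.size(0) * output.size(1), output.size(2))
decoded = decoder(flat)  # (seq_len*batch, ntoken)

# Second view: split them back -> (seq_len, batch, ntoken)
result = decoded.view(output.size(0), output.size(1), decoded.size(1))

print(flat.shape)     # torch.Size([15, 8])
print(decoded.shape)  # torch.Size([15, 10])
print(result.shape)   # torch.Size([5, 3, 10])

# Row t*batch + b of the flattened tensor is output[t, b]
t, b = 2, 1
assert torch.equal(flat[t * batch + b], output[t, b])

# Same as applying the decoder to each time step in a loop
looped = torch.stack([decoder(output[i]) for i in range(seq_len)])
assert torch.allclose(result, looped, atol=1e-6)
```

So the reshaping is only a way to run one batched matrix multiply instead of a Python loop over time steps.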

Note that with recent versions of PyTorch this is not necessary: nn.Linear accepts inputs with more than two dimensions and treats all dimensions except the last one as batch dimensions.
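A quick check of that claim, again with made-up sizes and a random tensor in place of the real LSTM output: applying the linear layer directly to the 3-D tensor should match the view-based version up to floating-point rounding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes only
seq_len, batch, nhid, ntoken = 5, 3, 8, 10

decoder = nn.Linear(nhid, ntoken)
output = torch.randn(seq_len, batch, nhid)  # stand-in for the LSTM output

# Reshape-based version, as in the model's forward
via_view = decoder(output.view(-1, output.size(2))).view(seq_len, batch, ntoken)

# nn.Linear applied directly: it only acts on the last dimension
direct = decoder(output)

print(direct.shape)  # torch.Size([5, 3, 10])
assert torch.allclose(direct, via_view, atol=1e-6)
```

With that, the last two lines of forward could be reduced to something like return self.decoder(output), hidden.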

