Understanding encoder of Seq2Seq model

I am currently trying to implement a Seq2Seq model for a machine translation task, following the official PyTorch tutorial:
Tutorial Link

My theoretical understanding of the encoder was that it takes a single input at each timestep and produces a new hidden state, and the process continues step by step, i.e. each step takes the next word together with the previous hidden state. But in the tutorial the whole input sequence is passed to the encoder at once. How is the encoder able to capture the relationships between words if we do this?
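In pseudocode, this is roughly what I expected the encoder to do (gru_step is just a placeholder for a single recurrent update, not an actual API):

hidden = initial_hidden
for word in sentence:                 # one token per timestep
    hidden = gru_step(word, hidden)   # new hidden from current word + previous hidden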

I have attached the code snippet below:

import torch.nn as nn

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        # input: (batch, seq_len) tensor of token indices for the whole sequence
        embedded = self.dropout(self.embedding(input))
        # nn.GRU consumes the full (batch, seq_len, hidden_size) sequence in one call
        output, hidden = self.gru(embedded)
        return output, hidden


def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)      # This line here

        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

The nn.GRU class does all of this under the hood: it takes the whole input sequence and iterates over the sequence items internally, feeding each new hidden state into the next timestep. In other words, the loop you are looking for is implemented inside nn.GRU, so the step-by-step recurrence you describe still happens; it is just not spelled out in the tutorial code.
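You can verify this yourself. Here is a minimal sketch (toy shapes, not from the tutorial) showing that one call over the full sequence matches stepping the same nn.GRU manually while carrying the hidden state forward:

import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=8, batch_first=True)
x = torch.randn(1, 5, 8)   # (batch, seq_len, features): a 5-step sequence

# One call: nn.GRU loops over the 5 timesteps internally
full_out, full_hidden = gru(x)

# Same module, stepping manually one timestep at a time
hidden = None
step_outputs = []
for t in range(x.size(1)):
    out_t, hidden = gru(x[:, t:t+1, :], hidden)   # feed previous hidden back in
    step_outputs.append(out_t)

# Should print True (tiny float differences are possible on some backends)
print(torch.allclose(full_out, torch.cat(step_outputs, dim=1)))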

If you want to implement this loop yourself, e.g. because you need some custom per-step behavior, you can have a look at nn.GRUCell, which computes a single timestep per call.
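A rough sketch of that manual loop with nn.GRUCell (illustrative sizes only):

import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=8, hidden_size=8)
x = torch.randn(1, 5, 8)             # (batch, seq_len, features)
hidden = torch.zeros(1, 8)           # initial hidden state, (batch, hidden_size)

outputs = []
for t in range(x.size(1)):           # the explicit loop over timesteps
    hidden = cell(x[:, t, :], hidden)    # new hidden from (x_t, h_{t-1})
    outputs.append(hidden)
output = torch.stack(outputs, dim=1)     # (batch, seq_len, hidden_size)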
