I don't understand the meaning of batch_size in tensor

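```python
import torch
import torch.nn as nn

# `device` was not defined in the original snippet; this is one common choice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(input_size, hidden_size)  # token id -> vector
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        # x: (batch,) token ids -> embeddings: (batch, hidden_size)
        out = self.embed(x)
        # unsqueeze(1) adds a length-1 sequence axis: (batch, 1, hidden_size)
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        # flatten to (batch, hidden_size) for the linear layer
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        # hidden/cell are (num_layers, batch, hidden_size) even with batch_first=True
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        cell = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        return hidden, cell
```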

  1. Does setting batch_size mean the model is trained with mini-batches?
  2. If so, does batch_size = 1 mean there is only one sample per batch, or that the gradient is calculated once over all the data?
  3. Also, why do hidden and cell in init_hidden need batch_size as well?

Hello, nn.LSTM in PyTorch needs a batch dimension for the input, the hidden state and the cell state. According to the documentation, they must have the following shapes:

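  • input of shape (seq_len, batch, input_size), or (batch, seq_len, input_size) when batch_first=True, as in your model
  • hidden of shape (num_layers * num_directions, batch, hidden_size)
  • cell of shape (num_layers * num_directions, batch, hidden_size); batch_first does not change the hidden and cell shapes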
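A quick way to check these shapes is to run an nn.LSTM on random tensors. The sizes below are arbitrary values chosen for illustration, not taken from the post; the second half mirrors the batch_first=True setting used in the question's model.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed, not from the original post)
seq_len, batch_size, input_size, hidden_size, num_layers = 5, 3, 10, 20, 2

# Default layout: input is (seq_len, batch, input_size)
lstm = nn.LSTM(input_size, hidden_size, num_layers)
x = torch.randn(seq_len, batch_size, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)  # num_directions = 1
c0 = torch.zeros(num_layers, batch_size, hidden_size)
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([5, 3, 20])
print(hn.shape)   # torch.Size([2, 3, 20])

# batch_first=True: input becomes (batch, seq_len, input_size),
# but h0/c0 keep the (num_layers * num_directions, batch, hidden_size) layout
lstm_bf = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x_bf = torch.randn(batch_size, seq_len, input_size)
out_bf, (hn_bf, cn_bf) = lstm_bf(x_bf, (h0, c0))
print(out_bf.shape)  # torch.Size([3, 5, 20])
print(hn_bf.shape)   # torch.Size([2, 3, 20])
```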

As you can see, the batch dimension appears in all three. Each sequence (sample) in a batch is processed independently and has its own hidden and cell state, so the initial hidden and cell tensors need one slot per sample along the batch dimension. That is why init_hidden takes batch_size.
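To see that each sample really keeps its own state, the sketch below (with assumed toy sizes) runs a batch of 3 sequences at once and then runs each sequence alone, starting from its own slice of the initial hidden and cell tensors. The final hidden states should agree up to floating-point tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes, assumed for illustration
num_layers, batch_size, seq_len, input_size, hidden_size = 1, 3, 4, 8, 16
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)
h0 = torch.randn(num_layers, batch_size, hidden_size)  # random, so each sample starts differently
c0 = torch.randn(num_layers, batch_size, hidden_size)

matches = []
with torch.no_grad():
    # Whole batch at once: hn[:, i] is the final hidden state of sample i
    _, (hn, cn) = lstm(x, (h0, c0))
    for i in range(batch_size):
        # Same sample on its own, starting from its own slice of h0/c0
        hi0 = h0[:, i:i + 1].contiguous()
        ci0 = c0[:, i:i + 1].contiguous()
        _, (hi, ci) = lstm(x[i:i + 1], (hi0, ci0))
        matches.append(torch.allclose(hi, hn[:, i:i + 1], atol=1e-5))

print(all(matches))  # True
```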

Now, back to your first question: yes, batch_size is the mini-batch size. For example, if the batch size is 3, each input is a group of 3 sentences, such as "I love PyTorch", "I love Keras" and "I love NLP".
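In practice the mini-batches usually come from a DataLoader. This minimal sketch uses hypothetical toy data (7 "sentences", each already encoded as 4 token ids) and batch_size=3. Because 7 is not divisible by 3, the last batch holds only 1 sentence, so the hidden state is sized from the actual batch rather than from a fixed number.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 7 "sentences", each already encoded as 4 token ids
sentences = torch.arange(28).reshape(7, 4)
loader = DataLoader(TensorDataset(sentences), batch_size=3, shuffle=False)

num_layers, hidden_size = 2, 16  # assumed sizes for illustration
hidden_shapes = []
for (batch,) in loader:
    # batch has shape (current_batch_size, 4); the last batch holds the leftover sentence
    hidden = torch.zeros(num_layers, batch.shape[0], hidden_size)
    hidden_shapes.append(tuple(hidden.shape))

print(hidden_shapes)  # [(2, 3, 16), (2, 3, 16), (2, 1, 16)]
```

Passing drop_last=True to the DataLoader is an alternative if every batch must have exactly batch_size samples.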

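Setting batch_size = 1 means each input contains just one sentence, so a gradient is computed (and the weights can be updated) after every single sample. Computing the gradient once over all the data corresponds to setting batch_size to the size of the whole dataset.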
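The difference between the two cases in question 2 can be sketched with a tiny linear model (an assumed stand-in for the LSTM, to keep it short). With a mean-reduced loss, the single full-batch gradient equals the average of the six per-sample gradients that batch_size = 1 would produce, as long as the weights are not updated between samples.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)     # stand-in model, assumed for illustration
loss_fn = nn.MSELoss()      # default reduction="mean" averages over the batch
X, y = torch.randn(6, 4), torch.randn(6, 1)

# Full batch: one gradient computed from all 6 samples
model.zero_grad()
loss_fn(model(X), y).backward()
full_grad = model.weight.grad.clone()

# batch_size = 1: one gradient per sample (an optimizer would normally step after each)
per_sample = []
for i in range(len(X)):
    model.zero_grad()
    loss_fn(model(X[i:i + 1]), y[i:i + 1]).backward()
    per_sample.append(model.weight.grad.clone())

# The full-batch gradient is the average of the per-sample gradients
same = torch.allclose(full_grad, torch.stack(per_sample).mean(0), atol=1e-6)
print(same)  # True
```

In real training with batch_size = 1, the optimizer steps after each sample, so later gradients are computed with already-updated weights; that is the practical difference from a single full-batch step.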

Hope that helps.


Thank you for your explanation :grinning: