Batch dimension and batch_first: unstable behaviour

Hi all,

I’ve been struggling with the batch size and the batch_first parameter for a while, and I’ve noticed a caveat. When the batch size is 1, the features in the batch have shape torch.Size([1, 1827]), but if I increase the batch size, the input has shape torch.Size([1223, 4]), which is not correct because each sentence is not 4 tokens long. Furthermore, I get an error.


When I run torch.transpose(input, 0, 1).size(), the result, torch.Size([4, 1223]), seems to suit my model better.

I have a standard dataset class whose __getitem__ returns the items as below:

    def __getitem__(self, idx):
        # this should return one sample from the dataset
        features = self.x[idx]
        target = self.y[idx]
        length = self.lengths[idx]
        return (features, target, length)
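The collate code isn’t shown, so this is an assumption: one common way to end up with a (seq_len, batch) tensor is padding variable-length samples with torch.nn.utils.rnn.pad_sequence, whose batch_first argument defaults to False. A minimal sketch of a collate_fn for (features, target, length) samples like the ones above, where pad_collate is a hypothetical name and each features entry is assumed to be a 1-D LongTensor of token ids:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(samples):
    """Hypothetical collate_fn: pad (features, target, length) samples batch-first."""
    features, targets, lengths = zip(*samples)
    # pad_sequence defaults to (max_len, batch); batch_first=True gives (batch, max_len)
    padded = pad_sequence(list(features), batch_first=True, padding_value=0)
    return padded, torch.tensor(targets), torch.tensor(lengths)


# Dummy samples: 1-D token-id tensors of different lengths
samples = [(torch.randint(1, 100, (n,)), 0, n) for n in (5, 3, 7, 2)]

x, y, lengths = pad_collate(samples)
print(x.shape)  # torch.Size([4, 7])

x1, _, _ = pad_collate(samples[:1])
print(x1.shape)  # torch.Size([1, 5])

# For comparison, the default layout puts time first
print(pad_sequence([s[0] for s in samples]).shape)  # torch.Size([7, 4])
```

It would be passed as DataLoader(dataset, batch_size=4, collate_fn=pad_collate), so the batch dimension comes first for every batch size, including 1.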

My model is as below:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


    class SimpleLSTM(nn.Module):
        """The RNN model."""

        def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_classes: int,
                     num_layers: int, dropout: float = 0.5, seq_len: int = 2024,
                     pretrained_dim=None, fixed_pad=False):
            """Initialize the model by setting up the layers."""
            super(SimpleLSTM, self).__init__()
            DEVICE = get_device()  # get_device() is defined elsewhere (not shown)
            self.pick_lstm_output = "last_layer"
            self.output_size = num_classes
            self.n_layers = num_layers
            self.hidden_dim = hidden_dim
            self.input_dim = embedding_dim
            # embedding and LSTM layers
            self.embedding = nn.Embedding(vocab_size, embedding_dim).to(DEVICE)
            # self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
            # self.embedding.weight.requires_grad = False
            self.lstm = nn.LSTM(self.input_dim, hidden_dim, num_layers,
                                dropout=dropout, batch_first=True).to(DEVICE)
            # linear output layer (returns raw logits; no sigmoid is applied here)
            input_dim = hidden_dim
            if fixed_pad:
                input_dim = seq_len * hidden_dim
            self.fc = nn.Linear(input_dim, num_classes).to(DEVICE)

        def forward(self, batch):
            """Perform a forward pass of our model on some input and hidden state."""
            input, _, lengths = batch  # input: (batch, seq_len) token ids
            # embeddings: (batch, seq_len, embedding_dim)
            embeds = self.embedding(input)
            # pack_padded_sequence expects the lengths on the CPU
            packed_input = pack_padded_sequence(embeds, lengths.to('cpu'), batch_first=True, enforce_sorted=False)
            packed_out, (h_n, c_n) = self.lstm(packed_input)
            # batch_first=True keeps lstm_output as (batch, seq_len, hidden_dim), so dim=1 below is time
            lstm_output, output_lengths = pad_packed_sequence(packed_out, batch_first=True)
            # stack up lstm outputs
            # lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)
            # we pick the last output of the sequence as the output of the lstm layer for the classification task
            # Option 1: take the final hidden state of the last LSTM layer, shape (batch, hidden_dim)
            if self.pick_lstm_output == "last_layer":
                output = h_n[-1, :, :]
            # Option 2: take the mean of the LSTM outputs over time
            elif self.pick_lstm_output == "mean":
                output = lstm_output.mean(dim=1)
            # Option 3: take the max of the LSTM outputs over time
            elif self.pick_lstm_output == "max":
                output = lstm_output.max(dim=1)[0]
            # Option 4 (illustrative option name): flatten all time steps; fits self.fc only with fixed_pad=True
            elif self.pick_lstm_output == "flatten":
                output = lstm_output.contiguous().view(input.shape[0], -1)
            out = self.fc(output)

            return out
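To see which dimension is which at each step of forward, a small shape trace may help. This sketch uses made-up sizes (vocabulary 100, embedding 8, hidden 16, two layers) and an input that is already padded batch-first; it is not the model above, just the same sequence of calls with batch_first=True passed to both the LSTM and pad_packed_sequence:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch_size, max_len, vocab, emb_dim, hidden, layers = 4, 7, 100, 8, 16, 2

embedding = nn.Embedding(vocab, emb_dim)
lstm = nn.LSTM(emb_dim, hidden, layers, batch_first=True)

tokens = torch.randint(1, vocab, (batch_size, max_len))  # (batch, seq_len)
lengths = torch.tensor([5, 3, 7, 2])  # true lengths, kept on the CPU

embeds = embedding(tokens)  # (batch, seq_len, emb_dim)
packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Without batch_first=True here, the padded output would be (seq_len, batch, hidden)
lstm_output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(embeds.shape)       # torch.Size([4, 7, 8])
print(lstm_output.shape)  # torch.Size([4, 7, 16])
print(h_n.shape)          # torch.Size([2, 4, 16]); batch_first does not change h_n
print(h_n[-1].shape)      # torch.Size([4, 16])
```

With enforce_sorted=False, out_lengths comes back in the original batch order, so it lines up with the rows of lstm_output.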

Could you post a minimal, executable code snippet that reproduces the issue? I don’t know what the input is supposed to be.

Sorry, it’s very difficult to provide an example, but I’ll try:

input: batches of 1-D tensors coming from the DataLoader. Each tensor is a tokenized/transformed text sequence.
batch_first = True
problem: when the batch size is 1, the input has shape (batch_size, seq_len);
when the batch size is greater than 1, the input has shape (seq_len, batch_size).

I think the default behaviour of batch_first is reversed here. That’s why I changed my code as follows:

    swapped = input.swapaxes(0, 1)
    embeds = self.embedding(swapped)

With the swap, the rest of the code works when the batch size is greater than 1, but this is only a temporary solution because it won’t work when the batch size is 1.
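For reference, batch_first=True only tells nn.LSTM to read its input and write its output as (batch, seq, feature); with the default False both are (seq, batch, feature), and h_n is (num_layers, batch, hidden) either way. It has no effect on what the DataLoader produces. A quick check with made-up sizes, copying the weights so both LSTMs compute the same function:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, feat, hidden = 3, 5, 8, 16

x_bf = torch.randn(batch, seq_len, feat)  # (batch, seq, feature)

lstm_bf = nn.LSTM(feat, hidden, batch_first=True)
out_bf, (h_bf, _) = lstm_bf(x_bf)

lstm_sf = nn.LSTM(feat, hidden)  # default batch_first=False
lstm_sf.load_state_dict(lstm_bf.state_dict())  # same weights
out_sf, (h_sf, _) = lstm_sf(x_bf.transpose(0, 1))  # (seq, batch, feature)

print(out_bf.shape)  # torch.Size([3, 5, 16])
print(out_sf.shape)  # torch.Size([5, 3, 16])
print(h_bf.shape, h_sf.shape)  # torch.Size([1, 3, 16]) torch.Size([1, 3, 16])
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5))  # True
```

So a tensor that arrives as (seq_len, batch) while the LSTM is built with batch_first=True is a data-layout problem upstream of the model, not a reversed flag.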

What do you mean by the input being of shape (1223, 4)? The input of what? If you mean the input to the LSTM, then your data gets mangled before it’s given to the LSTM.
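One way to find where the data gets mangled is to check, right after the batch leaves the DataLoader and again just before the embedding, that the first dimension of the token tensor matches the number of lengths. A minimal sketch; check_batch_first is a hypothetical helper, and it cannot tell the two layouts apart when seq_len happens to equal the batch size:

```python
import torch


def check_batch_first(tokens: torch.Tensor, lengths: torch.Tensor) -> None:
    """Hypothetical check: raise if tokens is not laid out as (batch, seq_len)."""
    if tokens.dim() != 2:
        raise ValueError(f"expected a 2-D (batch, seq_len) tensor, got {tuple(tokens.shape)}")
    if tokens.size(0) != lengths.numel():
        raise ValueError(
            f"tokens has shape {tuple(tokens.shape)} but there are {lengths.numel()} lengths; "
            "the batch dimension is probably not first"
        )
    if int(lengths.max()) > tokens.size(1):
        raise ValueError("a length exceeds the padded sequence dimension")


lengths = torch.tensor([1223, 800, 950, 400])
check_batch_first(torch.zeros(4, 1223, dtype=torch.long), lengths)  # passes silently

try:
    check_batch_first(torch.zeros(1223, 4, dtype=torch.long), lengths)
except ValueError as err:
    print(err)
```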