RNNs: difference between batching strategies (naive batching vs. PyTorch batching)

Hey all

I am currently experimenting with different LSTM models for time-series prediction and am noticing a strange difference in behaviour between two implementations that should, as far as I understand, be fundamentally the same.

For my first implementation I used two LSTMCells and avoided PyTorch’s built-in batching. I initially did this because my training data has variable sequence lengths and I wanted to avoid dealing with variable-length inputs when batching, which is why I kind of cheated with the batch dimension in the forward function. To handle batches (as shown below), I run each sample in the batch through the network individually, compute its loss and the associated gradients, accumulate those gradients, and then update the network using the accumulated gradients.

For my second model I tried to take advantage of PyTorch’s optimizations: instead of LSTMCell I used actual LSTM layers and the built-in batching, which should result in a considerable speedup. Note that for this model I always use num_layers=2.
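One way to sanity-check the assumption that the two architectures compute the same function is to copy the weights of a two-layer nn.LSTM into two LSTMCells and compare the outputs on one sequence. The following is only a minimal sketch: the sizes, the random input, and the variable names are made up for illustration and are not taken from the models in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, T = 3, 5, 7   # illustrative sizes (assumptions)

lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
cell1 = nn.LSTMCell(input_size, hidden_size)
cell2 = nn.LSTMCell(hidden_size, hidden_size)

# Copy layer 0 / layer 1 parameters of the LSTM into the two cells.
with torch.no_grad():
    for layer, cell in enumerate((cell1, cell2)):
        cell.weight_ih.copy_(getattr(lstm, f"weight_ih_l{layer}"))
        cell.weight_hh.copy_(getattr(lstm, f"weight_hh_l{layer}"))
        cell.bias_ih.copy_(getattr(lstm, f"bias_ih_l{layer}"))
        cell.bias_hh.copy_(getattr(lstm, f"bias_hh_l{layer}"))

x = torch.randn(1, T, input_size)      # one sequence, (batch, time, features)

with torch.no_grad():
    # Batched LSTM: zero initial states are the default.
    lstm_out, _ = lstm(x)

    # Manual loop over time with the two cells.
    h1 = c1 = h2 = c2 = torch.zeros(1, hidden_size)
    for t in range(T):
        h1, c1 = cell1(x[:, t, :], (h1, c1))
        h2, c2 = cell2(h1, (h2, c2))

# Top-layer hidden state at the final step should agree up to float error.
same = torch.allclose(lstm_out[:, -1, :], h2, atol=1e-5)
print(same)
```

If a check like this passes, any remaining difference in training behaviour would more likely come from the data handling or the training loop than from the layers themselves.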

I then compared how these two models train. For implementation 1, the training loss exhibits two phases: one where it decreases quickly, and another where it oscillates noisily around a slowly decreasing average (much like an L shape). I expected the same behaviour from implementation 2. It does show the same initial phase of quick decrease, leading into the same second phase seen in implementation 1, but once in that phase the MSELoss occasionally spikes very high, after which the loss quickly decreases again until it levels out.

Is there any obvious reason why I might be seeing this difference? For both models I am using the Adam optimizer with its default parameters and torch.nn.MSELoss() as the loss function. Ultimately I’d like to use implementation 2 because it considerably speeds up training, but I’m concerned that I am doing something wrong.

Update: I’ve realized that since I am not dividing the loss by the batch size for implementation 1, the gradients may differ significantly in magnitude. But this would mean that I’d have larger gradients for the first implementation, would it not? So if I were to see spiking behaviour, it should be in implementation 1 rather than implementation 2.
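The size of that difference can be illustrated on a toy model. The sketch below uses a single nn.Linear layer as a stand-in for either network (an assumption made only to keep the example short) and compares gradients accumulated from per-sample backward calls with the gradient of one batched MSELoss, which averages over the batch by default.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 4
model = nn.Linear(3, 1)          # toy stand-in, not the LSTM from this post
loss_fn = nn.MSELoss()           # default reduction="mean"

X = torch.randn(batch_size, 3)
y = torch.randn(batch_size, 1)

# Naive batching: one backward per sample, gradients add up in .grad.
model.zero_grad()
for i in range(batch_size):
    loss_fn(model(X[i]), y[i]).backward()
grad_sum = model.weight.grad.clone()

# Built-in batching: one backward on the batch-averaged loss.
model.zero_grad()
loss_fn(model(X), y).backward()
grad_mean = model.weight.grad.clone()

# The accumulated gradient is batch_size times the averaged one.
ratio_ok = torch.allclose(grad_sum, batch_size * grad_mean, atol=1e-5)
print(ratio_ok)
```

Under these assumptions the accumulated gradient is larger by a factor of batch_size, which matches the reasoning in the update above.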

Implementation 1: using naive batching


import torch
import torch.nn as nn

class PyLSTMCells(nn.Module):
    # init layers
    def __init__(self, input_size, hidden_size, output_size):
        super(PyLSTMCells, self).__init__()

        # attach hidden layer size to class
        self.hidden_size = hidden_size

        # define layers
        self.l1 = nn.LSTMCell(input_size, hidden_size)
        self.l2 = nn.LSTMCell(hidden_size, hidden_size)
        self.o  = nn.Linear(hidden_size, output_size)

    # define forward pass over single time-series
    def forward(self, input):
        # get tensor type of input
        Ttype = input.type()

        # add an additional "first" dimension to work
        # with the LSTMCell interface
        pad_input = input.unsqueeze(0)
        # define hidden/cell states
        h1 = torch.zeros(1, self.hidden_size).type(Ttype)
        c1 = torch.zeros(1, self.hidden_size).type(Ttype)
        h2 = torch.zeros(1, self.hidden_size).type(Ttype)
        c2 = torch.zeros(1, self.hidden_size).type(Ttype)

        # loop over time series and get final output
        # (time is the last dimension, matching the pad_input[:,:,i] indexing below)
        TS_len = pad_input.shape[2]
        for i in range(TS_len):
            h1, c1 = self.l1(pad_input[:,:,i], (h1,c1))
            h2, c2 = self.l2(h1, (h2,c2))
            output = self.o(h2)
        # return output, ignoring the false "first" dim
        # that was added
        return output[0,:]

Training Method

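opt.zero_grad() # Adam optimizer with default parameters
batch_loss = 0
for i in range(batch_size):
    X, lbl = get_example()
    output = model(X)         # where model is the network defined above
    l = loss(output, lbl)     # MSELoss
    l.backward()              # gradients accumulate in .grad across samples, as described above
    batch_loss += l.item()    # running total, used for reporting only
opt.step()                    # update using the accumulated (summed, not averaged) gradients
avg_batch_loss = batch_loss/batch_size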

Implementation 2: LSTM layers and Pytorch batching


class BatchPyLSTM(nn.Module):
    # init network
    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super(BatchPyLSTM, self).__init__()

        # attach important vars to class
        self.input_size     = input_size
        self.hidden_size    = hidden_size
        self.num_layers     = num_layers

        # define layers
        self.lstm   = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.lin    = nn.Linear(hidden_size, output_size)

    # init hidden and cell states
    def init_hidden(self, Ttype, batch_size):
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).type(Ttype)  # hidden states
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).type(Ttype)  # cell states
        return h0,c0 

    # forward pass over entire batch
    def forward(self, input, seq_lens):
        # get tensor type of input and batch size
        Ttype       = input.type()
        batch_size  = input.shape[0]

        # init hidden and cell states
        h0, c0  = self.init_hidden(Ttype, batch_size)

        # reshape before packing to match pytorch interface
        pckd_input = input.permute(0,2,1) 

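        # pack input so the LSTM knows not to consider the padded values
        # NOTE: batch_first=True is an assumption, chosen to match the LSTM defined above
        pckd_input = nn.utils.rnn.pack_padded_sequence(pckd_input, seq_lens, batch_first=True)
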
        # feed through LSTM
        lstm_otp,_ = self.lstm(pckd_input,(h0,c0))

        # repad lstm_otp into a normal tensor object so we can actually work with it
        lstm_otp,_ = nn.utils.rnn.pad_packed_sequence(lstm_otp, batch_first=True)

        # extract lstm output for final time-step of each series
        fin_lstm_otp = torch.stack([ o[l-1,:] for o,l in zip(lstm_otp, seq_lens) ])

        # feed the final lstm outputs through the linear layer
        otp = self.lin(fin_lstm_otp)
        return otp

Training Method

opt.zero_grad()                                       # Adam optimizer with default parameters
X_batch, lbl_batch, seq_lens = get_batch(batch_size)
output = model(X_batch, seq_lens)                     # implementation 2 network
avg_batch_loss = loss(output, lbl_batch)              # MSELoss, averaged over the batch
avg_batch_loss.backward()
opt.step()
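For reference, the padding and packing steps can be checked in isolation. The sketch below is not the get_batch helper used above: the series, sizes, and names are invented for illustration, and the sequences are listed longest first because pack_padded_sequence expects that ordering unless enforce_sorted=False is passed.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence)

torch.manual_seed(0)
input_size, hidden_size = 3, 5   # illustrative sizes (assumptions)

# Hypothetical variable-length series, each (T_i, input_size), longest first.
series = [torch.randn(T, input_size) for T in (6, 4, 2)]
seq_lens = torch.tensor([s.shape[0] for s in series])

# Zero-pad to a common length: (batch, max_T, input_size).
padded = pad_sequence(series, batch_first=True)

lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)

with torch.no_grad():
    packed = pack_padded_sequence(padded, seq_lens, batch_first=True)
    packed_out, (h_n, c_n) = lstm(packed)
    out, out_lens = pad_packed_sequence(packed_out, batch_first=True)

    # Output at the last valid time step of each series.
    last = out[torch.arange(len(series)), seq_lens - 1]

    # With packed input, h_n[-1] holds the same top-layer states.
    matches_h_n = torch.allclose(last, h_n[-1], atol=1e-6)

    # Running one series on its own, unpadded, should give the same result.
    single_out, _ = lstm(series[1].unsqueeze(0))
    matches_single = torch.allclose(last[1], single_out[0, -1], atol=1e-5)

print(matches_h_n, matches_single)
```

A check along these lines would show whether padded positions leak into the final-step outputs; if both comparisons hold, the packing itself is behaving like the per-sample loop.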