RNN loss with packed/padded input

Hi,

I am training an RNN model on variable-length input sequences. To speed up training, I construct mini-batches of the sequences using a custom collate_fn passed to a DataLoader object. The input to this function is a list of (sequence, target) pairs, as returned by __getitem__ of my Dataset class:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

def collate_fn(batch):
    # batch is a list of (sequence, target) tuples, one per sample
    batch_size = len(batch)
    data_lens = torch.tensor([seq[0].shape[0] for seq in batch])
    data_lens_sorted, perm_idx = torch.sort(data_lens, descending=True)
    # extract features and labels
    features = [torch.Tensor(batch[i][0]) for i in range(batch_size)]
    labels = [torch.Tensor(batch[j][1]) for j in range(batch_size)]
    # pad, then order the features and labels according to data_lens_sorted
    pad_features = pad_sequence(features, batch_first=True)
    pad_labels = pad_sequence(labels, batch_first=True)
    sorted_pad_features = pad_features[perm_idx]
    sorted_pad_labels = pad_labels[perm_idx]
    # pack the padded features and labels
    packed_features = pack_padded_sequence(sorted_pad_features, data_lens_sorted,
                                           batch_first=True, enforce_sorted=True)
    packed_labels = pack_padded_sequence(sorted_pad_labels, data_lens_sorted,
                                         batch_first=True, enforce_sorted=True)
    # return the *sorted* padded labels so they line up with packed_features
    return packed_features, sorted_pad_labels, data_lens_sorted

As shown, the collate_fn returns a packed representation of the input data (packed_features), the padded labels sorted in the same order (sorted_pad_labels), and a tensor of sequence lengths in descending order (data_lens_sorted). I return packed_features so I don't have to deal with packing inside the training loop, and padded labels so that the loss can be computed with minimal transformations (the packed features are unpacked back to their padded representation after passing through the RNN).
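For reference, this is roughly how the collate_fn is hooked up to the DataLoader (the toy Dataset below is a hypothetical stand-in for my real one, just to make the sketch self-contained):

import torch
from torch.utils.data import Dataset, DataLoader

class ToySeqDataset(Dataset):
    """Hypothetical stand-in: __getitem__ returns a (sequence, target) pair
    with shapes (seq_len, input_size) and (seq_len, output_size)."""
    def __init__(self, lengths, input_size=8, output_size=2):
        self.samples = [(torch.randn(n, input_size), torch.randn(n, output_size))
                        for n in lengths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

train_ds = ToySeqDataset(lengths=[5, 9, 3, 7])
train_dl = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collate_fn)

packed_features, sorted_pad_labels, data_lens_sorted = next(iter(train_dl))
print(type(packed_features).__name__)  # PackedSequence
print(sorted_pad_labels.shape)         # (batch, max_len_in_batch, output_size)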

My model is constructed as follows:

import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence

class LSTM(nn.Module):  # note: named LSTM, but it currently wraps a GRU

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(LSTM, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fcl = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.PReLU(),
            nn.Linear(128, output_size),
        )

    def forward(self, input_data, h0=None):
        rnn_out, h_n = self.gru(input_data, h0)  # GRU returns a single hidden tensor
        pad_out, seq_lens = pad_packed_sequence(rnn_out, batch_first=True)
        # collapse (batch, max_len, hidden) -> (batch * max_len, hidden)
        flat_pad_out = pad_out.flatten(start_dim=0, end_dim=1)  # is this right?
        fcl_out = self.fcl(flat_pad_out)
        return fcl_out

And my training loop is as follows:

for batch in train_dl:
    data = batch[0].to(device)      # PackedSequence of features
    labels = batch[1].to(device)    # padded labels, shape (batch, max_len, output_size)
    seq_lens = batch[2].to(device)  # sorted sequence lengths
    pred = model(data.float())
    optim.zero_grad()  # zero the gradients for this mini-batch
    labels = labels.flatten(start_dim=0, end_dim=1)   # flatten to match output of forward()
    non_zero_idxs = torch.abs(labels).sum(dim=1) > 0  # mask out all-zero (padded) rows
    labels = labels[non_zero_idxs]
    pred = pred[non_zero_idxs]
    loss = loss_f(pred, labels)
    loss.backward()  # compute the gradients
    optim.step()     # update the parameters
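As an aside, since seq_lens is already available in the loop, one alternative I have considered (just a sketch, not what I am currently running) is to build the padding mask from the sequence lengths instead of from all-zero label rows, which would also avoid dropping genuine all-zero targets:

# Sketch: mask padded timesteps via seq_lens instead of zero-rows.
# labels_padded: (batch, max_len, output_size); pred: (batch * max_len, output_size)
def length_masked_loss(pred, labels_padded, seq_lens, loss_f):
    batch, max_len, _ = labels_padded.shape
    # position t of sequence b is valid iff t < seq_lens[b]
    valid = torch.arange(max_len, device=seq_lens.device)[None, :] < seq_lens[:, None]
    valid = valid.flatten()                    # (batch * max_len,)
    flat_labels = labels_padded.flatten(0, 1)  # (batch * max_len, output_size)
    return loss_f(pred[valid], flat_labels[valid])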

Using this setup, I achieve decent performance (decreasing training loss) with a batch size of 1. With a batch size greater than one, the training loss is unstable and does not decrease as one would normally expect. This makes me suspect the approach I am using to batch my inputs, namely the collate_fn. Can someone please verify whether my batching method is correct?

Another doubt I have concerns the calculation of the loss. As shown in my forward function, I pad the packed output from the RNN and reshape it to be compatible with the subsequent FC layer. Instead of re-padding the output with pad_packed_sequence, can I feed the packed output directly to the FC layer? That would save me from masking the padded labels and would let me compute the loss directly from packed_labels (instead of the padded labels) and the packed output of the FC layer. Is either of these approaches the correct way of doing this?
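To make that second idea concrete, here is roughly what I have in mind (a sketch that assumes collate_fn is changed to also return packed_labels; the timesteps of the two PackedSequences line up because both were packed with the same sorted lengths):

# Sketch: compute the loss directly on packed data, skipping pad/flatten/mask.
rnn_out, h_n = model.gru(packed_features)   # rnn_out is a PackedSequence
fcl_out = model.fcl(rnn_out.data)           # .data is (total_timesteps, hidden_size)
loss = loss_f(fcl_out, packed_labels.data)  # labels packed with the same lengths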

I would appreciate any advice.

This older post might be interesting. The idea is to create a Sampler to ensure that all samples within a batch have the same combination of input and target lengths.
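A minimal sketch of that idea (hypothetical names; it buckets dataset indices by their length combination and yields each bucket in batches):

import random
from collections import defaultdict
from torch.utils.data import Sampler

class SameLengthBatchSampler(Sampler):
    """Yield batches in which all samples share the same (input_len, target_len)."""
    def __init__(self, lengths, batch_size):
        # lengths[i] = (input_len, target_len) for dataset index i
        self.buckets = defaultdict(list)
        for idx, key in enumerate(lengths):
            self.buckets[key].append(idx)
        self.batch_size = batch_size

    def __iter__(self):
        batches = []
        for idxs in self.buckets.values():
            random.shuffle(idxs)
            batches.extend(idxs[i:i + self.batch_size]
                           for i in range(0, len(idxs), self.batch_size))
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(-(-len(v) // self.batch_size) for v in self.buckets.values())

# Usage: pass as batch_sampler (batch_size/shuffle must then be left unset):
# train_dl = DataLoader(train_ds, batch_sampler=SameLengthBatchSampler(lengths, 32),
#                       collate_fn=collate_fn)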