Optimal RNN training architecture (pad/pack/collate/dataset/loader)

I’ve been working on training RNNs with PyTorch for some time now, and I’m wondering about the optimal architecture for doing so, using all the features that are available (padding, packing, etc.). I haven’t found any tutorial or documentation that provides a general framework for doing things properly from start to finish.

The objectives are the following:

  • PyTorch-heavy (using all the natively supported features)
  • parallel DataLoader
  • CUDA-heavy (push as much of the work as possible onto the GPU)
  • computationally optimal (as a result of all the above)

The assumptions of the problem I’m solving (although I believe not crucial) are the following:

  • regression problem
  • training on all signals
  • predicting/evaluating last signal

I’ve come up with the following architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader

class RNNDataset(Dataset):
    def __init__(self, num_sequences, num_features, labeled=False):
        # random lengths from 1 to 99 (torch.randint's upper bound is exclusive)
        random_sequence_lengths = torch.randint(1, 100, (num_sequences,))
        # List of tensors, 1 element representing one sequence
        self.train = [torch.rand(l, num_features) for l in random_sequence_lengths]
        # List of sequence lengths
        self.lens = random_sequence_lengths
        self.len = len(self.train)

        self.labeled = labeled
        if labeled:
            # List of tensors, 1 element representing labels for one sequence
            self.labels = [torch.rand(l, 1) for l in random_sequence_lengths]
        else:
            self.labels = None

    def __getitem__(self, idx):
        if self.labeled:
            labels = self.labels[idx]
        else:
            labels = None

        return self.train[idx], self.lens[idx], labels

    def __len__(self):
        return self.len

class RNNNet(torch.nn.Module):
    def __init__(self, num_features):
        super().__init__()

        self.pre_rnn = nn.Linear(num_features, num_features * 2)
        self.rnn = nn.GRU(num_features * 2, num_features, batch_first=True)
        self.post_rnn = nn.Linear(num_features, 1)

    def forward(self, inputs, lens):
        # pack_padded_sequence requires sequences sorted by decreasing length;
        # indices_back restores the original batch order afterwards
        indices = torch.argsort(-lens)
        indices_back = torch.argsort(indices)

        # Pre RNN FC layer
        inputs_fc_pre = F.relu(self.pre_rnn(inputs))
        # Pack the sequence (lengths must be a CPU tensor for pack_padded_sequence)
        inputs_packed = pack_padded_sequence(inputs_fc_pre[indices], lens[indices].cpu(), batch_first=True)
        # RNN layer
        inputs_rnn, _ = self.rnn(inputs_packed)
        # Reverse operation, pad the packed sequence
        inputs_rnn_padded, _ = pad_packed_sequence(inputs_rnn, batch_first=True)
        # Post RNN FC layer
        input_post_rnn = self.post_rnn(F.relu(inputs_rnn_padded))[indices_back]

        return input_post_rnn

    def predict(self, test, validation=False):
        # Predicting only the last output
        # Depending on whether it's a monitoring phase or an actual prediction, different preparation is needed
        if isinstance(test, DataLoader):
            test_loader = test
        else:
            test_loader = DataLoader(test, batch_size=100, shuffle=False,
                num_workers=2, collate_fn=rnn_collate, pin_memory=True)

        self.eval()
        device = next(self.parameters()).device
        results = []
        with torch.no_grad():
            for inputs_part, lens_part, labels in test_loader:
                results_chunk = self(
                    inputs_part.to(device, non_blocking=True),
                    lens_part.to(device, non_blocking=True)
                )
                # Take the output at the last valid (non-padded) time step of each sequence
                last_preds = [x[l - 1] for x, l in zip(results_chunk, lens_part)]

                if validation:
                    last_labels = [x[l - 1] for x, l in zip(labels, lens_part)]
                    results.extend(zip(last_preds, last_labels))
                else:
                    results.extend(last_preds)

            # Each element is a single-element tensor, so this also moves everything to the CPU
            results_final = torch.tensor(results)
        self.train()
        return results_final

def rnn_collate(batch):
    inputs, lens, labels = zip(*batch)
    # Pad the sequences and labels to equal length.
    # If processing test set, all labels will be None, in that case no processing is done.
    if labels[0] is not None:
        labels = pad_sequence(labels, batch_first=True, padding_value=-999)
    else:
        labels = None
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    lens = torch.LongTensor(lens)

    return inputs, lens, labels

def fit(train, valid, model, num_epochs, batch_size, learning_rate, device):
    criterion = nn.MSELoss()

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
        num_workers=2, collate_fn=rnn_collate, pin_memory=True)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False,
        num_workers=2, collate_fn=rnn_collate, pin_memory=True)

    model.zero_grad()
    model.train()
    model.to(device, non_blocking=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0

        for i, data in enumerate(train_loader):
            # Get the inputs
            inputs, lens, labels = data
            labels = labels.to(device, non_blocking=True)
            inputs = inputs.to(device, non_blocking=True)
            lens = lens.to(device, non_blocking=True)

            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward + backward + optimize
            outputs = model(inputs, lens)
            # Training on the outputs of all time steps - mask out the padded label positions
            labels_mask = (labels != -999)
            # Extract outputs of interest
            outputs_extracted = outputs[labels_mask]

            # Get the RMSE
            loss = torch.sqrt(criterion(outputs_extracted, labels[labels_mask]))
            loss.backward()

            optimizer.step()

            # Gather statistics
            running_loss += loss.item()

        results = model.predict(valid_loader, validation=True)

        valid_loss = torch.sqrt(F.mse_loss(results[:, 0], results[:, 1])).item()

        print(f"Train loss after epoch {epoch}: {running_loss / (i + 1)}")
        print(f"Validation loss after epoch {epoch}: {valid_loss}")

    return model

if __name__ == "__main__":
    num_sequences = 1000
    num_features = 50
    device = "cuda"
    num_epochs = 10
    batch_size = 64
    learning_rate = 1e-3

    train = RNNDataset(num_sequences, num_features, labeled=True)
    valid = RNNDataset(num_sequences, num_features, labeled=True)
    test = RNNDataset(num_sequences, num_features, labeled=False)
    model = RNNNet(num_features)

    model = fit(train, valid, model, num_epochs, batch_size, learning_rate, device)

    preds = model.predict(test)
    print(preds[:10])

It works nicely and reasonably fast in general. The main problem is that the padding of the batch is done on the CPU, not the GPU, in rnn_collate. I tried sending the batch elements to the GPU in rnn_collate or directly in RNNDataset (switching multiprocessing to spawn to avoid CUDA initialization errors), but I was getting shared-memory block exceptions (even with the shared memory limit increased), so I assume the approach is not optimal.
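
For reference, the failing variant looked roughly like this (a sketch from memory; the .to("cuda") calls are the only change from rnn_collate above):

def rnn_collate_cuda(batch):
    inputs, lens, labels = zip(*batch)
    # Same padding as rnn_collate, but the padded tensors are pushed to the GPU
    # right here in the worker processes - this is what blew up with num_workers > 0
    if labels[0] is not None:
        labels = pad_sequence(labels, batch_first=True, padding_value=-999).to("cuda")
    else:
        labels = None
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0).to("cuda")
    lens = torch.LongTensor(lens)

    return inputs, lens, labels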

My questions are the following:

  • Is there some tutorial/documentation on how to create an optimal (or close to) RNN training architecture in PyTorch?
  • Do you have any hints regarding the above code that would make it faster/more optimal/more PyTorch-heavy?

All suggestions welcome!

Is bumping allowed on this forum?
Does anyone have an opinion on the topic?

I’m not exactly sure whether my method does the padding on the GPU or not, but what I used to do was pass a list of GPU tensors to the model and pad them inside the model. Also, I have a feeling the model would be faster if you do not use a PackedSequence as input; instead, you could use a padded tensor as the input (my feeling comes from the note under the LSTM section of the docs).
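
Something along these lines, as a rough illustration (a sketch reusing the layer names from your RNNNet; not tested):

def forward(self, sequences):
    # sequences: a list of (seq_len_i, num_features) tensors already on the GPU,
    # so pad_sequence below runs on the GPU as well
    padded = pad_sequence(sequences, batch_first=True)
    # Feed the plain padded tensor to the GRU - no PackedSequence involved
    out, _ = self.rnn(F.relu(self.pre_rnn(padded)))
    return self.post_rnn(F.relu(out))

The GRU also iterates over the padded time steps this way, which wastes a bit of compute, but since only the output at the last valid step of each sequence is read, the extra outputs are simply ignored.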

Thanks for your input.

  1. I tried sending tensors to the GPU in collate_fn or in my custom Dataset, but this causes problems (described at the bottom of the first post) when I use num_workers > 0. It seems to run out of shared memory too easily. Did you have problems like that? How did you get around them?
  2. I haven’t seen that part of the documentation, although it also mentions the data should be float16, which is not my case (that seems a bit too imprecise to me).
  1. https://pytorch.org/docs/stable/data.html#multi-process-data-loading Multi-process data loading to the GPU is not recommended. You can use multiple processes to load the data and then move it to the GPU in one step.
  2. I’m not exactly sure whether that helps, but in my experience it helps a little.

My particular problem with 1. is that I can’t really see a good spot to send the tensors to the GPU before doing pad_sequence. I want to return a Tensor from collate_fn (if I return a list, I again get strange behaviour with num_workers > 0), so it seems to me (at least for now) that I need to pad in collate_fn. On the other hand, I can’t/shouldn’t send tensors to the GPU in collate_fn, to avoid problems.
Would you mind sharing a piece of code showing how you do it? It could be easier than explaining in text.

I can’t find the exact code, so I’ll just show a sample.

from typing import List, Tuple

import torch as tc
from torch import Tensor

def collate_fn(batch: List[Tuple[Tensor, int]]):
    inps, labels = zip(*batch)
    # Pin host memory so the non_blocking copies to the GPU below can overlap with compute
    return [x.pin_memory() for x in inps], tc.tensor(labels).pin_memory()

def train(...):
    for inp, labels in LOADER:
        # Rebuild the list: reassigning the loop variable would not replace its elements
        inp = [x.to(device=DEVICE, non_blocking=True) for x in inp]
        labels = labels.to(device=DEVICE, non_blocking=True)
        OPTIM.zero_grad()
        ...

My code looks something like this, and I pad the tensors inside the model.

I get it, thanks a lot! I hadn’t really thought about the pin_memory method. I didn’t have much success using num_workers > 0 and returning anything other than an iterable of tensors, but maybe the problem lies somewhere else.
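
For anyone finding this thread later, here’s roughly what I plan to try next (an untested sketch: my rnn_collate from above plus manual pinning, instead of relying on the DataLoader’s pin_memory=True; whether this behaves with num_workers > 0 is exactly what I still need to verify):

def rnn_collate_pinned(batch):
    inputs, lens, labels = zip(*batch)
    # Pad on the CPU inside the workers, then pin the padded tensors so the
    # .to(device, non_blocking=True) copies in the training loop can overlap with compute
    if labels[0] is not None:
        labels = pad_sequence(labels, batch_first=True, padding_value=-999).pin_memory()
    else:
        labels = None
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0).pin_memory()
    lens = torch.LongTensor(lens)  # stays on the CPU: pack_padded_sequence wants CPU lengths

    return inputs, lens, labels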