I’ve been working on training RNNs with PyTorch for some time now, and I’m wondering about the optimal architecture for doing so, using all the features that are available (padding, packing, etc.). I haven’t really found any tutorial or documentation that provides a general framework for doing things properly from beginning to end.
The objectives are the following:
- PyTorch-heavy (using all the natively supported features)
- parallel DataLoader
- CUDA-heavy (push as much of the work as possible onto the GPU)
- computationally optimal (as a result of all the above)
The assumptions of the problem I’m solving (although I believe they are not crucial) are the following:
- regression problem
- training on all signals
- predicting/evaluating last signal
I’ve come up with the following architecture:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader

class RNNDataset(Dataset):
    def __init__(self, num_sequences, num_features, test=False):
        # Random lengths from 1 to 99 (randint's upper bound is exclusive)
        random_sequence_lengths = torch.randint(1, 100, (num_sequences,))
        # List of tensors, one element representing one sequence
        self.train = [torch.rand(l, num_features) for l in random_sequence_lengths]
        # List of sequence lengths
        self.lens = random_sequence_lengths
        self.len = len(self.train)
        self.test = test
        if test:
            # List of tensors, one element representing the labels for one sequence
            self.labels = [torch.rand(l, 1) for l in random_sequence_lengths]
        else:
            self.labels = None

    def __getitem__(self, idx):
        if self.test:
            labels = self.labels[idx]
        else:
            labels = None
        return self.train[idx], self.lens[idx], labels

    def __len__(self):
        return self.len

class RNNNet(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.pre_rnn = nn.Linear(num_features, num_features * 2)
        self.rnn = nn.GRU(num_features * 2, num_features, batch_first=True)
        self.post_rnn = nn.Linear(num_features, 1)

    def forward(self, inputs, lens):
        # Indices for sorting by decreasing length, plus the indices that
        # restore the original order (so outputs line up with the labels)
        indices = torch.argsort(-lens)
        indices_back = torch.argsort(indices)
        # Pre-RNN FC layer
        inputs_fc_pre = F.relu(self.pre_rnn(inputs))
        # Pack the sequence; pack_padded_sequence needs the lengths on the CPU
        inputs_packed = pack_padded_sequence(inputs_fc_pre[indices], lens[indices].cpu(),
                                             batch_first=True)
        # RNN layer
        inputs_rnn, _ = self.rnn(inputs_packed)
        # Reverse operation: pad the packed sequence
        inputs_rnn_padded, _ = pad_packed_sequence(inputs_rnn, batch_first=True)
        # Post-RNN FC layer, with the batch restored to its original order
        inputs_post_rnn = self.post_rnn(F.relu(inputs_rnn_padded))[indices_back]
        return inputs_post_rnn

    def predict(self, test, validation=False):
        # Predict only the last output of each sequence.
        # Depending on whether it's a monitoring phase or an actual prediction,
        # different preparation is needed.
        if isinstance(test, DataLoader):
            test_loader = test
        else:
            test_loader = DataLoader(test, batch_size=100, shuffle=False,
                                     num_workers=2, collate_fn=rnn_collate, pin_memory=True)
        self.eval()
        device = next(self.parameters()).device
        results = []
        with torch.no_grad():
            for inputs_part, lens_part, labels in test_loader:
                results_chunk = self(
                    inputs_part.to(device, non_blocking=True),
                    lens_part.to(device, non_blocking=True)
                )
                # Get the last output of each sequence from the prediction
                last_preds = [x[l - 1] for x, l in zip(results_chunk, lens_part)]
                if validation:
                    last_labels = [x[l - 1] for x, l in zip(labels, lens_part)]
                    results.extend(zip(last_preds, last_labels))
                else:
                    results.extend(last_preds)
        results_final = torch.Tensor(results)
        self.train()
        return results_final

def rnn_collate(batch):
    inputs, lens, labels = zip(*batch)
    # Pad the sequences and labels to equal length.
    # When processing the test set, all labels are None; in that case they are skipped.
    if labels[0] is not None:
        labels = pad_sequence(labels, batch_first=True, padding_value=-999)
    else:
        labels = None
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    lens = torch.LongTensor(lens)
    return inputs, lens, labels

def fit(train, valid, model, num_epochs, batch_size, learning_rate, device):
    criterion = nn.MSELoss()
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=2, collate_fn=rnn_collate, pin_memory=True)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False,
                              num_workers=2, collate_fn=rnn_collate, pin_memory=True)
    model.zero_grad()
    model.train()
    model.to(device, non_blocking=True)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            # Get the inputs
            inputs, lens, labels = data
            labels = labels.to(device, non_blocking=True)
            inputs = inputs.to(device, non_blocking=True)
            lens = lens.to(device, non_blocking=True)
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward + backward + optimize
            outputs = model(inputs, lens)
            # Learning on the outputs of all time steps - a mask extracts the relevant
            # (non-padded) positions
            labels_mask = (labels != -999)
            # Extract the outputs of interest
            outputs_extracted = outputs[labels_mask]
            # Get the RMSE
            loss = torch.sqrt(criterion(outputs_extracted, labels[labels_mask]))
            loss.backward()
            optimizer.step()
            # Gather statistics
            running_loss += loss.item()
        results = model.predict(valid_loader, True)
        valid_loss = torch.sqrt(F.mse_loss(results[:, 0], results[:, 1])).item()
        print(f"Train loss after epoch {epoch}: {running_loss / (i + 1)}")
        print(f"Validation loss after epoch {epoch}: {valid_loss}")
    return model

if __name__ == "__main__":
    num_sequences = 1000
    num_features = 50
    device = "cuda"
    num_epochs = 10
    batch_size = 64
    learning_rate = 1e-3

    train = RNNDataset(num_sequences, num_features, test=True)
    valid = RNNDataset(num_sequences, num_features, test=True)
    test = RNNDataset(num_sequences, num_features, test=False)

    model = RNNNet(num_features)
    model = fit(train, valid, model, num_epochs, batch_size, learning_rate, device)
    preds = model.predict(test)
    print(preds[:10])
```
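As a side note on the sorting bookkeeping in `forward`: recent PyTorch versions accept `enforce_sorted=False` in `pack_padded_sequence`, which sorts internally and makes `pad_packed_sequence` return the batch in its original order, so the `argsort` dance can be dropped entirely. A minimal sketch of the simplified `forward` (same `pre_rnn` / `rnn` / `post_rnn` layers as above):

```python
# Sketch: forward without manual sorting, relying on enforce_sorted=False.
def forward(self, inputs, lens):
    inputs_fc_pre = F.relu(self.pre_rnn(inputs))
    # Lengths must still live on the CPU; enforce_sorted=False sorts internally
    inputs_packed = pack_padded_sequence(inputs_fc_pre, lens.cpu(),
                                         batch_first=True, enforce_sorted=False)
    inputs_rnn, _ = self.rnn(inputs_packed)
    # pad_packed_sequence restores the original batch order automatically
    inputs_rnn_padded, _ = pad_packed_sequence(inputs_rnn, batch_first=True)
    return self.post_rnn(F.relu(inputs_rnn_padded))
```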
It works pretty nicely and fast in general. The main problem is that the padding of the batch is done on the CPU rather than on the GPU, in `rnn_collate`. I tried sending the batch elements to the GPU in `rnn_collate`, or directly in `RNNDataset` (changing the multiprocessing start method to `spawn` to avoid CUDA initialization errors), but I was getting shared-memory block exceptions (even with the shared memory limit already increased), so I assume the approach is not optimal.
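One direction I haven’t properly benchmarked is to keep the collate function trivial and defer the padding to the GPU inside the training loop, after the per-sequence transfer. A rough sketch (the `rnn_collate_raw` name is mine, not a PyTorch API); the many small host-to-device copies may well eat up the gains, which is partly why I’m asking:

```python
# Sketch only: skip padding in collate and pad on the GPU instead.
# rnn_collate_raw is a hypothetical replacement for rnn_collate above.
def rnn_collate_raw(batch):
    inputs, lens, labels = zip(*batch)
    # Return the raw CPU tensors; padding is deferred to the training loop
    return list(inputs), torch.stack(list(lens)), list(labels)

# In the training loop: transfer each sequence, then pad on the device
# inputs = pad_sequence([x.to(device, non_blocking=True) for x in inputs],
#                       batch_first=True, padding_value=0)
```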
My questions are the following:
- Is there some tutorial/documentation on how to create an optimal (or close to) RNN training architecture in PyTorch?
- Do you have any hints regarding the above code that would make it faster/more optimal/more PyTorch-heavy?
All suggestions welcome!