I’m using sequence data as input to an LSTMCell. As per the LSTMCell documentation, the input shape should be (time_index, batch, n_features). I reshaped my training data into a similar shape: (time_index, n_sequences_in_training_set, n_features) for the input data X_train and (time_index, n_sequences_in_training_set, 1) for the target data y_train.
When I iterate over a data loader created from this data, it batches over the time_index (first) dimension, which is what I would expect. However, in this case I need to batch over the second dimension, and I can’t figure out how to do that. I have read that this can be done using a custom collate_fn, but when I looked at this it appeared that the data passed into the collate_fn is already batched along the first dimension (is this correct?). Is it possible to return data batched over the second dimension?
I believe the input to collate_fn is a list of Tensors. You should be able to write a custom collate function to do what you would like.
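To illustrate the point about what collate_fn receives, here is a minimal sketch. ToyDataset and inspect_collate are hypothetical names made up for this example, not part of the original question. It records what the DataLoader hands to collate_fn: a plain Python list containing one `dataset[i]` result per sample in the batch, with no stacking done yet.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset: item idx is a length-2 tensor filled with idx."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.full((2,), float(idx))


received = []  # records (container type, number of samples) per call


def inspect_collate(samples):
    # samples is a list of whatever __getitem__ returned, one entry per index.
    received.append((type(samples).__name__, len(samples)))
    # Stacking is done here, so the batch dimension can be chosen freely.
    return torch.stack(samples, dim=0)


# num_workers is left at its default (0) so collate runs in this process.
loader = DataLoader(ToyDataset(5), batch_size=2, collate_fn=inspect_collate)
batches = list(loader)
print(received)
```

So if the items arriving in collate_fn look like slices along the wrong dimension, the place to look is `__getitem__`, since that decides what a single item is.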
Thanks Kevin, what you described helped me solve the problem I was seeing. The issue was actually in the CustomDataSet class I was using, specifically in the __getitem__ method, which was indexing over the first dimension and so passing a list of Tensors sliced along that dimension into the collate_fn (which shows why I should have posted a code snippet!). Below is a code snippet showing the solution I came to.
I wonder if it is possible to stack the Tensors more efficiently in the for loop in the collate_fn, but I couldn’t figure out a way of doing this without first unpacking the X and y tensors using the loop.
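import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataSet(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        # One item per sequence, i.e. the size of the second dimension.
        return self.X.shape[1]

    def __getitem__(self, idx):
        # Index the sequence (second) dimension, not the time dimension.
        X = self.X[:, idx, :]
        y = self.y[:, idx, :]
        return X, y

def my_collate(data_in):
    # data_in is a list of (X, y) pairs returned by __getitem__.
    X_data = []
    y_data = []
    for seq in data_in:
        X_data.append(seq[0])
        y_data.append(seq[1])
    # Stack along dim=1 so each batch is (time_step, batch, n_features).
    X_data = torch.stack(X_data, dim=1)
    y_data = torch.stack(y_data, dim=1)
    return X_data, y_data

# X shape = (time_step, n_samples, n_features)
# y shape = (time_step, n_samples, 1)
X = torch.Tensor(np.zeros((4, 10, 6)))
y = torch.Tensor(np.zeros((4, 10, 1)))
for i in range(X.shape[0]):
    # Fill each time step with its index so the batching is easy to check.
    X[i, :, :] = i
    y[i, :, :] = i

dataset = CustomDataSet(X=X, y=y)
loader = DataLoader(dataset, batch_size=3, shuffle=False, num_workers=1, collate_fn=my_collate)
for i, batch in enumerate(loader):
    X_batch, y_batch = batch
    print(i, X_batch.shape, y_batch.shape)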
Glad you were able to resolve the issue! It might be marginally faster if you use a list comprehension instead of .append, but it shouldn’t matter too much.
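As a rough sketch of that suggestion, the explicit loop can be dropped by unzipping the list of (X, y) pairs before stacking. The name my_collate_zip and the dummy samples are assumptions made up for illustration. It assumes every sample is a pair of tensors shaped (time_step, n_features) and (time_step, 1), as in the dataset from the question.

```python
import torch


def my_collate_zip(data_in):
    # data_in: list of (X_seq, y_seq) pairs, one pair per sample in the batch.
    # zip(*data_in) regroups them into a tuple of X tensors and a tuple of y tensors.
    X_seqs, y_seqs = zip(*data_in)
    # Stack along dim=1 so the batch dimension comes second.
    return torch.stack(X_seqs, dim=1), torch.stack(y_seqs, dim=1)


# Three dummy samples with 4 time steps and 6 features each.
samples = [(torch.zeros(4, 6), torch.zeros(4, 1)) for _ in range(3)]
X_batch, y_batch = my_collate_zip(samples)
print(X_batch.shape, y_batch.shape)
```

Whether this is measurably faster than the .append loop probably depends on batch size. In most cases the torch.stack calls should dominate the cost either way.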