LSTM with pad_packed_sequence

Hi,
I’m using PyTorch to create an LSTM autoencoder that receives a 1D input time series and outputs its reconstruction. The model takes variable-length sequences as input, one timestep at a time.
This is the model:

import torch
import torch.nn as nn

class packetAE(nn.Module):

    def __init__(self, lstm1_h: int):
        super().__init__()

        self.encoder = nn.LSTM(1, lstm1_h, 1, batch_first=True)
        self.decoder = nn.LSTM(lstm1_h, 1, 1, batch_first=True)

    def forward(self, input):
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence
        encoding, _ = self.encoder(input)
        decoded, _ = self.decoder(encoding)
        return decoded
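A minimal sketch of what nn.LSTM actually returns, since forward hands the encoder's result to the decoder: the call returns a tuple (output, (h_n, c_n)), and when the input is a PackedSequence the output is a PackedSequence too, which can be passed directly to the next LSTM. The hidden size 8 and the lengths 5 and 3 are arbitrary values chosen only for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

encoder = nn.LSTM(input_size=1, hidden_size=8, num_layers=1, batch_first=True)
decoder = nn.LSTM(input_size=8, hidden_size=1, num_layers=1, batch_first=True)

x = torch.randn(2, 5, 1)   # (batch, max_len, features); 2nd sequence has 3 real steps
packed = pack_padded_sequence(x, [5, 3], batch_first=True, enforce_sorted=False)

# the LSTM call returns a tuple, not a tensor: unpack it before reusing the output
enc_out, (h_n, c_n) = encoder(packed)
dec_out, _ = decoder(enc_out)    # a PackedSequence can be passed straight on

print(isinstance(enc_out, PackedSequence), enc_out.data.shape, h_n.shape, dec_out.data.shape)
```

Note that h_n is always a padded (num_layers, batch, hidden) tensor, even for packed input, while the per-timestep outputs stay packed.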

This is how I create the DataLoader and feed the data to the network:

def collate_padding(batch):
    (xx, yy) = zip(*batch)
    
    # retrieve the original lengths of the batch sequences
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    # add padding
    xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)
    yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)

    return xx_pad, yy_pad, x_lens, y_lens

train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=16, pin_memory=True, collate_fn=collate_padding)

model = packetAE(128)

for batch_index, (X, y, x_len, y_len) in enumerate(train_loader):
    X_packed = torch.nn.utils.rnn.pack_padded_sequence(X, x_len, batch_first=True, enforce_sorted=False)
    pred_packed = model(X_packed)
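The loop above stops at the forward pass. A minimal sketch of a full training step, assuming the batch has already been given shape (batch, max_len, 1) and dtype float32, could unpack the decoder output with pad_packed_sequence and compute the reconstruction loss only over real timesteps. The model is a self-contained copy of packetAE whose forward unpacks the LSTM tuples; the masked MSE loss, the Adam optimizer and the random batch are assumptions for illustration, not part of the original code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

class packetAE(nn.Module):
    def __init__(self, lstm1_h: int):
        super().__init__()
        self.encoder = nn.LSTM(1, lstm1_h, 1, batch_first=True)
        self.decoder = nn.LSTM(lstm1_h, 1, 1, batch_first=True)

    def forward(self, packed):
        encoding, _ = self.encoder(packed)   # discard (h_n, c_n)
        decoded, _ = self.decoder(encoding)
        return decoded                       # still a PackedSequence

model = packetAE(16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# hypothetical padded float batch of shape (batch, max_len, 1) and its true lengths
x_len = [6, 4, 2]
X = torch.rand(3, 6, 1)
X[1, 4:] = 0.0
X[2, 2:] = 0.0

X_packed = pack_padded_sequence(X, x_len, batch_first=True, enforce_sorted=False)
pred_packed = model(X_packed)

# unpack to (batch, max_len, 1); steps past each sequence's length are filled with 0
pred, out_lens = pad_packed_sequence(pred_packed, batch_first=True, total_length=X.size(1))

# boolean mask (batch, max_len) that is True only on real timesteps
mask = torch.arange(X.size(1))[None, :] < out_lens[:, None]
loss = ((pred - X) ** 2).squeeze(-1)[mask].mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(pred.shape, out_lens.tolist(), loss.item())
```

Because enforce_sorted=False is used, pad_packed_sequence returns the sequences and lengths in the original batch order, so pred lines up with X row by row.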

However, when I run the training code, I get the following error as soon as the input is passed to the encoder stage of the model:

RuntimeError: input must have 2 dimensions, got 1

Does anyone have an idea what the reason might be?
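The error is consistent with the shape of the data: each dataset item is a 1D tensor of shape (L,), so the padded batch is (batch, max_len) and the packed data is 1D with shape (sum of lengths,). For a packed input, nn.LSTM expects the packed data to be 2D, (sum of lengths, input_size), which is where "input must have 2 dimensions, got 1" comes from. The sketch below uses random float sequences as hypothetical stand-ins for the dataset items to show the shapes with and without a trailing feature dimension.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# two hypothetical 1D sequences of shape (L,), like the dataset items (no feature dim)
seqs = [torch.rand(5), torch.rand(3)]
lens = [5, 3]
lstm = nn.LSTM(1, 4, batch_first=True)

padded = pad_sequence(seqs, batch_first=True)              # (2, 5)
packed = pack_padded_sequence(padded, lens, batch_first=True, enforce_sorted=False)
print(packed.data.shape)                                   # 1D: (sum of lengths,)

try:
    lstm(packed)
    raised = False
except RuntimeError as err:                                # the dimension check fails here
    raised = True
    print(err)

# adding a trailing feature dimension makes the packed data (sum of lengths, 1)
packed_3d = pack_padded_sequence(padded.unsqueeze(-1), lens, batch_first=True, enforce_sorted=False)
out, _ = lstm(packed_3d)
print(packed_3d.data.shape, out.data.shape)
```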

The dataset contains items like this one:

tensor([184, 172, 111,  54,  10, 139,   0, 193, 177,  20, 235,  49,   8,   0,
         69,   0,   0,  40,  27,  52,   0,   0,  55,   6, 154, 114,   8, 254,
        250, 126, 192, 168,  10,   5,   0,  80, 192,  36, 223, 215, 173, 139,
        148, 184,  56, 155,  80,  17,   1,  73, 197,  52,   0,   0,   0,   0,
          0,   0,   0,   0], dtype=torch.uint8)
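Besides the missing feature dimension, the item above is uint8, while the LSTM weights are float32, so the input also needs a cast. A minimal sketch of a collate function that does both, assuming each dataset element is an (x, y) pair of 1D uint8 tensors as collate_padding expects; the division by 255 is an assumption, not something the original code does. The second item below is hypothetical, built from later values of the same tensor.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

def collate_padding(batch):
    xx, yy = zip(*batch)
    # original (unpadded) lengths, needed later by pack_padded_sequence
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]
    # uint8 -> float32, scale bytes to [0, 1] (an assumption),
    # and add a trailing feature dimension: (L,) -> (L, 1)
    xx = [(x.float() / 255.0).unsqueeze(-1) for x in xx]
    yy = [(y.float() / 255.0).unsqueeze(-1) for y in yy]
    # pad to (batch, max_len, 1)
    xx_pad = pad_sequence(xx, batch_first=True, padding_value=0.0)
    yy_pad = pad_sequence(yy, batch_first=True, padding_value=0.0)
    return xx_pad, yy_pad, x_lens, y_lens

# first values of the item shown above, plus a shorter hypothetical item;
# for an autoencoder the target is the input itself
a = torch.tensor([184, 172, 111, 54, 10, 139, 0], dtype=torch.uint8)
b = torch.tensor([193, 177, 20], dtype=torch.uint8)
X, y, x_len, y_len = collate_padding([(a, a), (b, b)])

X_packed = pack_padded_sequence(X, x_len, batch_first=True, enforce_sorted=False)
encoder = nn.LSTM(1, 128, 1, batch_first=True)
encoding, _ = encoder(X_packed)   # no dimension or dtype error now
print(X.shape, X.dtype, X_packed.data.shape, encoding.data.shape)
```

Scaling is one option among several; the motivation here is that the decoder is an LSTM with hidden size 1, and an LSTM's output h_t = o_t * tanh(c_t) always lies in (-1, 1), so raw byte values up to 255 could never be reconstructed exactly.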