Training LSTM Nightmare

I wanted to train a char-LSTM with the following parameters, using one-hot encoding of the characters. My input was a 3D tensor of shape [batch size x sequence length x one-hot dimension], and my target was also a 3D tensor of the same shape.
To use cross-entropy loss, I first had to flatten my input to a 2D tensor of shape [seq len * batch size x one-hot dim] and convert my target from one-hot vectors to a suitable 1D tensor. I then had to use .squeeze() because .view() was not sufficient, and I also had to cast from FloatTensor to LongTensor. I hope these compatibility issues are fixed in the next release so that cross-entropy loss can be run on the inputs and targets as they are.

Questions for you:

  1. Why couldn’t you just load target labels as indices? That would be more memory efficient and arguably faster when computing the cross-entropy loss.

Isn’t it natural that you get a LongTensor when converting one-hot vectors to label indices? Every straightforward way to do that conversion in PyTorch that I can think of gives this property, e.g. torch.max.
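A minimal sketch of that conversion, assuming made-up one-hot targets of shape [batch=2, seq=3, vocab=5]: torch.max along the last dimension returns a (values, indices) pair, and the indices are already int64, so no separate cast to LongTensor is needed.

```python
import torch

# Hypothetical ground-truth indices, only used to build one-hot targets.
indices = torch.tensor([[0, 3, 1], [4, 4, 2]])
one_hot = torch.nn.functional.one_hot(indices, num_classes=5).float()  # [2, 3, 5]

# torch.max over the last dim returns (values, indices); indices are int64.
values, recovered = torch.max(one_hot, dim=-1)

print(recovered.dtype)  # torch.int64
print(recovered.shape)  # torch.Size([2, 3])
```

one_hot.argmax(dim=-1) gives the same index tensor if the max values are not needed.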

.view can always do what .squeeze does.
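A small sketch comparing the two on a hypothetical [N, 1] column of class indices. Both yield the same 1D tensor; the one difference shown is for a [1, 1] tensor, where .squeeze() drops every size-1 dim and returns a 0-dim scalar, while .view(-1) keeps a [1] shape.

```python
import torch

# A column of class indices with a trailing singleton dim: shape [6, 1].
col = torch.arange(6).unsqueeze(1)

squeezed = col.squeeze()  # removes all size-1 dims -> shape [6]
viewed = col.view(-1)     # flattens to 1D -> shape [6]
print(torch.equal(squeezed, viewed))  # True

# Edge case: a single element.
single = torch.tensor([[3]])
print(single.squeeze().shape)  # torch.Size([])
print(single.view(-1).shape)   # torch.Size([1])
```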

You don’t need to “flatten” your input. Nowhere in the docs does it say that, and I believe RNN modules only accept 3D input anyway. Simply setting batch_first=True should be fine.
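A minimal sketch, assuming hypothetical sizes (vocabulary of 27 characters, hidden size 64): with batch_first=True, nn.LSTM takes the one-hot input as [batch, seq, vocab] directly and returns output as [batch, seq, hidden]. The hidden state h_n keeps its [num_layers, batch, hidden] layout regardless of batch_first.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration.
batch_size, seq_len, vocab_size, hidden_size = 4, 10, 27, 64

lstm = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size, batch_first=True)

# One-hot input kept as a 3D [batch, seq, vocab] tensor; no flattening.
x = torch.zeros(batch_size, seq_len, vocab_size)
x[..., 0] = 1.0  # every position encodes character 0, just to get valid one-hot rows

out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([4, 10, 64])
print(h_n.shape)  # torch.Size([1, 4, 64])
```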
Moreover, since the data dim (i.e. the one-hot dim) is the last dimension, shouldn’t this be simple and straightforward? Something along the lines of

loss = criterion(out.view(-1, data_dim), labels.max(-1)[1].view(-1))

should just work.
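An end-to-end sketch of that one-liner, assuming a hypothetical CharLSTM model (an LSTM followed by a Linear layer that projects back to data_dim, since a raw LSTM output has hidden_size features, not data_dim) and random one-hot inputs and labels. One deliberate change from the suggestion: .reshape is used on `out` instead of .view, because .view raises an error on non-contiguous tensors and .reshape does not.

```python
import torch
import torch.nn as nn

class CharLSTM(nn.Module):
    """Hypothetical minimal char-LSTM: one-hot in, per-character scores out."""

    def __init__(self, data_dim, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(data_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, data_dim)  # project back to vocab size

    def forward(self, x):
        h, _ = self.lstm(x)  # [batch, seq, hidden]
        return self.head(h)  # [batch, seq, data_dim]

torch.manual_seed(0)
batch, seq, data_dim = 3, 7, 20
model = CharLSTM(data_dim, hidden_size=32)
criterion = nn.CrossEntropyLoss()

# One-hot inputs and one-hot labels, both [batch, seq, data_dim].
one_hot = nn.functional.one_hot
inputs = one_hot(torch.randint(0, data_dim, (batch, seq)), data_dim).float()
labels = one_hot(torch.randint(0, data_dim, (batch, seq)), data_dim).float()

out = model(inputs)
# Scores become [batch * seq, data_dim]; labels become [batch * seq] int64 indices.
loss = criterion(out.reshape(-1, data_dim), labels.max(-1)[1].view(-1))
loss.backward()
```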

Thanks for your reply! 🙂
I agree that index labels are more memory efficient, but having different encoders for inputs and labels just made things slightly more difficult.
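One way to avoid two separate encoders is to encode the text once as integer indices from a single shared vocabulary, then derive the one-hot inputs from those indices and use the shifted indices directly as targets. A minimal sketch, assuming a toy string and a next-character prediction setup (char_to_idx and the sample text are hypothetical):

```python
import torch
import torch.nn.functional as F

text = "hello world"
# Hypothetical shared vocabulary built from the training text.
vocab = sorted(set(text))
char_to_idx = {ch: i for i, ch in enumerate(vocab)}

# Encode the text once as integer indices (int64).
encoded = torch.tensor([char_to_idx[ch] for ch in text])

# Next-character prediction: targets are the inputs shifted by one position.
input_idx, target_idx = encoded[:-1], encoded[1:]

inputs = F.one_hot(input_idx, num_classes=len(vocab)).float()  # [seq, vocab] for the LSTM
targets = target_idx                                           # [seq] int64 for CrossEntropyLoss
```

For a batch of one sequence, inputs.unsqueeze(0) gives the [1, seq, vocab] shape expected with batch_first=True.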

The flattening of my input was for the cross-entropy loss.

In your command, out.view() yields a [1, data_dim] tensor, which won’t pass through the cross-entropy loss. You need a .squeeze() to get a 1D tensor. At least that’s how it was for me 😕

It will give you a [batch*seq, data_dim] tensor. CrossEntropyCriterion certainly requires input to be 2D.
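A quick shape check, assuming random stand-in scores and one-hot labels: out.view(-1, data_dim) has batch * seq rows, so it is [1, data_dim] only when batch * seq == 1, and even then nn.CrossEntropyLoss accepts it together with a [1] target without any .squeeze().

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
data_dim = 5
shapes = []

for batch, seq in [(4, 6), (1, 1)]:
    out = torch.randn(batch, seq, data_dim)  # stand-in for model scores
    labels = nn.functional.one_hot(torch.randint(0, data_dim, (batch, seq)), data_dim).float()
    logits = out.view(-1, data_dim)          # [batch * seq, data_dim]
    targets = labels.max(-1)[1].view(-1)     # [batch * seq]
    loss = criterion(logits, targets)        # works for both sizes
    shapes.append((tuple(logits.shape), tuple(targets.shape)))
    print(tuple(logits.shape), tuple(targets.shape))
# prints (24, 5) (24,) then (1, 5) (1,)
```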