RNN/LSTMs: Feeding in a single time-step or an entire sequence?

For a task where I want to accumulate the losses computed at each time-step of a sequence, I can see two ways to do it. I assumed they were equivalent, but they do not seem to be:

  1. Feeding in a single time-step and computing a loss for it, until the sequence is consumed:
# note that lstm is initialized with batch_first = True
# the model structure is the common embedding -> lstm -> linear -> softmax
hidden = model.initial_hidden()
loss = 0 # initialize loss for a batch of sequences

for t in range(seq_length):
    output, hidden = model(input[:, t, :], hidden)
    loss += loss_function(output, y[:, t]) # accumulate loss per time-step

loss.backward()

Here input.size() -> (batch_size, seq_length, input_size),
output.size() -> (batch_size, n_label), which is the output of softmax(), and
y[:, t].size() -> (batch_size); loss_function is NLLLoss().

  2. Feeding in a whole sequence, and accumulating the losses in a for loop:
hidden = model.initial_hidden()
output, hidden = model(input, hidden)
loss = 0

for t in range(seq_length):
    loss += loss_function(output[:, t, :], y[:, t])

loss.backward()

In this setting, output.size() -> (batch_size, seq_length, n_label).

So my question is whether these two settings are equivalent. In my experiments, setting one is about four times slower than setting two, but its accumulated loss is about two times lower (observed over the first two epochs). Could someone help me by pointing out the essential differences between the two settings, if any?
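
One way to narrow this down is to compare the two settings on a tiny model with fixed weights. The sketch below is not the model from the post: it assumes a plain nn.LSTM followed by nn.Linear and log_softmax, with no embedding and no dropout, and the sizes and the initial_hidden() helper are made up for illustration. Each time-step is passed as a length-1 sequence (x[:, t:t+1, :]) so the LSTM still sees a 3-D input. Under these assumptions the two losses and their gradients should agree up to floating-point error, so a large gap in a real model would point to something else that differs between the two code paths (for example dropout, how the hidden state is carried, or how the model handles a 2-D input).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this check.
batch_size, seq_length, input_size, hidden_size, n_label = 4, 6, 5, 8, 3

# Stand-in model: lstm -> linear -> log_softmax (no embedding, no dropout).
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
linear = nn.Linear(hidden_size, n_label)
loss_function = nn.NLLLoss()  # expects log-probabilities

x = torch.randn(batch_size, seq_length, input_size)
y = torch.randint(0, n_label, (batch_size, seq_length))


def initial_hidden():
    # (h_0, c_0), each of shape (num_layers, batch_size, hidden_size)
    return (torch.zeros(1, batch_size, hidden_size),
            torch.zeros(1, batch_size, hidden_size))


# Setting one: feed one time-step at a time, carrying the hidden state forward.
hidden = initial_hidden()
loss_stepwise = 0.0
for t in range(seq_length):
    out, hidden = lstm(x[:, t:t + 1, :], hidden)  # out: (batch_size, 1, hidden_size)
    log_probs = torch.log_softmax(linear(out[:, 0, :]), dim=-1)
    loss_stepwise = loss_stepwise + loss_function(log_probs, y[:, t])

# Setting two: feed the whole sequence, then accumulate the per-step losses.
out, _ = lstm(x, initial_hidden())  # out: (batch_size, seq_length, hidden_size)
log_probs = torch.log_softmax(linear(out), dim=-1)
loss_full = 0.0
for t in range(seq_length):
    loss_full = loss_full + loss_function(log_probs[:, t, :], y[:, t])

# Compare the losses and the gradients w.r.t. the recurrent weights.
(grad_stepwise,) = torch.autograd.grad(loss_stepwise, lstm.weight_hh_l0)
(grad_full,) = torch.autograd.grad(loss_full, lstm.weight_hh_l0)

losses_match = torch.allclose(loss_stepwise, loss_full, atol=1e-4)
grads_match = torch.allclose(grad_stepwise, grad_full, atol=1e-4)
print(losses_match, grads_match)
```

Setting one is still expected to be slower, since it calls the LSTM seq_length times instead of once.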

Why don’t you compute the loss after having supplied the entire sequence?

NLLLoss() doesn’t know about seq_length, and the following would cause an error:

hidden = model.initial_hidden()
output, hidden = model(input, hidden)
loss = loss_function(output, y)

output.size() -> (batch_size, seq_length, n_label); and
y.size() -> (batch_size, seq_length)

The error: ValueError: Expected target size (batch_size, n_label), got torch.Size([batch_size, seq_length])

So, in order to calculate the loss for each time-step and add them up, I came up with settings one and two.
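
For what it's worth, the per-step loop can probably be avoided. NLLLoss() accepts an input of shape (N, C, d1) with a target of shape (N, d1), so moving the class dimension to position 1 lets it take the whole sequence at once; flattening batch and time into one dimension is another option. The sketch below uses random log-probabilities as a stand-in for the model output (the sizes are made up) and assumes the default reduction="mean" with no class weights or ignore_index. Since the mean is then taken over batch_size * seq_length elements, it is multiplied by seq_length to match the sum of per-step batch means.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this example.
batch_size, seq_length, n_label = 4, 6, 3

# Stand-in for the model output: log-probabilities, (batch_size, seq_length, n_label).
output = torch.log_softmax(torch.randn(batch_size, seq_length, n_label), dim=-1)
y = torch.randint(0, n_label, (batch_size, seq_length))

loss_function = nn.NLLLoss()  # default reduction="mean"

# Reference: accumulate the per-step losses in a loop, as in settings one and two.
loss_loop = sum(loss_function(output[:, t, :], y[:, t]) for t in range(seq_length))

# Option A: move the class dimension to position 1 -> (batch_size, n_label, seq_length).
loss_permuted = loss_function(output.permute(0, 2, 1), y) * seq_length

# Option B: merge batch and time -> (batch_size * seq_length, n_label).
loss_flat = loss_function(output.reshape(-1, n_label), y.reshape(-1)) * seq_length

print(torch.allclose(loss_loop, loss_permuted), torch.allclose(loss_loop, loss_flat))
```

Dropping the factor of seq_length gives the per-element mean instead, which only rescales the loss and its gradients by a constant.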