How to train a many-to-many LSTM with BPTT


(Sean Chen) #1

Hi everyone, I am learning about LSTMs and am trying the official LSTM example, whose training loop is as follows:

for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Also, we need to clear out the hidden state of the LSTM,
        # detaching it from its history on the last instance.
        model.hidden = model.init_hidden()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Variables of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # Step 3. Run our forward pass.
        tag_scores = model(sentence_in)

        # Step 4. Compute the loss, gradients, and update the parameters by
        #  calling optimizer.step()
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

However, I have a question about the backpropagation:

loss = loss_function(tag_scores, targets)
loss.backward()
optimizer.step()

These three lines don't seem to involve the sequence steps at all, but I thought an LSTM has to be trained with BPTT. Could you explain why? I also wonder when BPTT should be applied and how to implement it in PyTorch. Thank you in advance.


(jpeg729) #2

sentence_in contains many timesteps, right? So backpropagation goes all the way back to the start of sentence_in.

This example does do BPTT without truncation.
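A minimal sketch to check this, assuming a standalone torch.nn.LSTM and a placeholder loss rather than the tutorial's tagger model: run one whole sequence forward, compute the loss on the last output only, and call backward() once. Because autograd recorded every timestep, the gradient still reaches the very first input; that single backward() call is BPTT over the full, untruncated sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, input_size, hidden_size = 6, 1, 3, 4
lstm = nn.LSTM(input_size, hidden_size)  # default layout: (seq_len, batch, features)

# Ask autograd for gradients w.r.t. the input so we can see how far backprop reaches.
x = torch.randn(seq_len, batch, input_size, requires_grad=True)

out, _ = lstm(x)                 # out: (seq_len, batch, hidden_size)
loss = out[-1].pow(2).sum()      # placeholder loss on the LAST timestep only
loss.backward()                  # a single call that unrolls back through every step

# Gradient magnitude reaching each input timestep; step 0 is nonzero as well,
# because the gradient travels back through the recurrent hidden state.
grad_per_step = x.grad.abs().sum(dim=(1, 2))
print(grad_per_step)
```

In a many-to-many model the loss covers every timestep instead of just the last one, but the mechanism is the same: one backward() call on the sequence loss.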


(Lev E Givon) #3

Doesn’t the initialization of the hidden state at each iteration over the data in the quoted example effectively result in truncated BPTT (compare this posting)?


(jpeg729) #4

Yes, but doesn’t each batch contain entire sentences?


(Lev E Givon) #5

Ah, yes - that’s right.


(Sean Chen) #6

Thank you very much!


(Sean Chen) #7

So, if I want to do BPTT with truncation for this example, could you tell me how to modify the code?


(jpeg729) #8

If each data sample is just one sentence, then it already does as much BPTT as it needs.


(Sean Chen) #9

Thank you, but I should have been clearer. This example does BPTT over a whole sentence without truncation. Could you tell me how to implement BPTT with truncation for this example?


(jpeg729) #10

To do BPTT with truncation you would need to cut the input into subsequences and train on each subsequence, separately but in order. The subsequences need to be fed to the model in order so that the last hidden state from the end of each subsequence is used at the beginning of the next subsequence.
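Why the carried-over state also has to be cut from its history can be checked with a small sketch (a bare torch.nn.LSTM and a placeholder loss, both assumptions for illustration). If the final hidden state of one chunk is fed into the next chunk as-is, the next chunk's graph still reaches back into the previous chunk, whose graph was already freed by its own backward() call, so the second backward() raises an error. Detaching the state first keeps its values but starts a fresh graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4)
chunks = torch.randn(10, 1, 3).split(5, dim=0)  # one sequence cut into two 5-step chunks

def train_on_chunks(detach_between_chunks):
    hidden = (torch.zeros(1, 1, 4), torch.zeros(1, 1, 4))
    for chunk in chunks:
        if detach_between_chunks:
            hidden = tuple(h.detach() for h in hidden)  # same values, no history
        out, hidden = lstm(chunk, hidden)  # final state is carried to the next chunk
        out.pow(2).mean().backward()       # placeholder loss; frees this chunk's graph

try:
    train_on_chunks(detach_between_chunks=False)
    carried_graph_ok = True
except RuntimeError:  # typically "Trying to backward through the graph a second time ..."
    carried_graph_ok = False

train_on_chunks(detach_between_chunks=True)  # runs: each backward stays in its chunk
detached_ok = True
```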

The data input is sentence_in and the corresponding targets are targets.
You would have to split sentence_in and targets into corresponding chunks along the sequence dimension.

sentence_parts = torch.split(sentence_in, chunk_size, dim=0)
targets_parts = torch.split(targets, chunk_size, dim=0)
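To see what torch.split produces, here is a small sketch with hypothetical stand-in tensors (the values and chunk_size=3 are made up for illustration). Both tensors are split identically, so chunk i of the inputs lines up with chunk i of the targets.

```python
import torch

# Hypothetical stand-ins for the tutorial's tensors: 7 word indices and 7 tag indices.
sentence_in = torch.arange(7)
targets = torch.tensor([0, 1, 2, 0, 1, 2, 0])
chunk_size = 3

# Split both tensors the same way along the sequence (first) dimension.
sentence_parts = torch.split(sentence_in, chunk_size, dim=0)
targets_parts = torch.split(targets, chunk_size, dim=0)

print([p.tolist() for p in sentence_parts])  # [[0, 1, 2], [3, 4, 5], [6]]
print([p.tolist() for p in targets_parts])   # [[0, 1, 2], [0, 1, 2], [0]]
```

Note that the last chunk is shorter when the sequence length is not a multiple of chunk_size, so the model has to accept chunks of varying length (a model that reshapes its input using len(sentence) does).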

then instead of

model.zero_grad()
model.hidden = model.init_hidden()
tag_scores = model(sentence_in)
loss = loss_function(tag_scores, targets)
loss.backward()
optimizer.step()

you would do

model.hidden = model.init_hidden()
for (sentence_part, targets_part) in zip(sentence_parts, targets_parts):
    model.zero_grad()
    # Keep the hidden state's values but cut its history, so backprop stops here.
    model.hidden = tuple(h.detach() for h in model.hidden)
    tag_scores_part = model(sentence_part)
    loss = loss_function(tag_scores_part, targets_part)
    loss.backward()
    optimizer.step()

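Detaching the hidden state tells PyTorch to stop backpropagating at that point: the state's values are still carried into the next chunk, but gradients from a chunk's loss do not flow back into earlier chunks. Without the detach, the second chunk's backward() would try to go back through the first chunk's graph, which has already been freed.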
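Putting the pieces together, here is a self-contained sketch of truncated BPTT on one long toy sequence. The Tagger class, its sizes, the toy data, and the optimizer settings are all hypothetical, loosely modelled on the tutorial's tagger that stores its state in model.hidden; only the loop structure follows the answer above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Tagger(nn.Module):
    """Hypothetical tagger modelled loosely on the tutorial's LSTM model."""

    def __init__(self, vocab_size, tagset_size, embedding_dim=8, hidden_dim=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, tagset_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # (h_0, c_0), each shaped (num_layers=1, batch=1, hidden_dim)
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, sentence):
        embeds = self.embed(sentence).view(len(sentence), 1, -1)  # (seq, 1, emb)
        # Start from the stored state and store the final state for the next call.
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        return torch.log_softmax(self.out(lstm_out.squeeze(1)), dim=1)

model = Tagger(vocab_size=10, tagset_size=3)
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Toy data: one long sequence of 20 word indices with 20 tags.
sentence_in = torch.randint(0, 10, (20,))
targets = torch.randint(0, 3, (20,))
chunk_size = 5

losses = []
for epoch in range(30):
    model.hidden = model.init_hidden()  # fresh state at the start of the sequence
    for sentence_part, targets_part in zip(torch.split(sentence_in, chunk_size),
                                           torch.split(targets, chunk_size)):
        model.zero_grad()
        # Carry the values over but cut the graph: backprop stops at this boundary.
        model.hidden = tuple(h.detach() for h in model.hidden)
        tag_scores_part = model(sentence_part)
        loss = loss_function(tag_scores_part, targets_part)
        loss.backward()  # gradients flow through at most chunk_size timesteps
        optimizer.step()
        losses.append(loss.item())
```

The truncation length is chunk_size: the forward pass still sees the whole history through the carried state, but each backward() only reaches back chunk_size timesteps.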


(Sean Chen) #11

Great! Thank you again for your help.


#12

As I understand it, this demo uses a batch size of 1 (a mini-batch containing a single sentence) and a sequence length equal to the length of the sentence.

In other words, each sentence is one batch (the loss is computed and the parameters are updated once per sentence), and the loss propagates back through every timestep of that sentence.
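The shapes involved can be sketched as follows, assuming the model reshapes the embeddings to (len(sentence), 1, embedding_dim); the sentence, layer sizes, and tags are hypothetical. The LSTM returns one output per word, and all of them feed a single scalar loss, so one backward() per sentence covers every timestep.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

sentence = torch.tensor([4, 1, 7, 2, 9])   # hypothetical sentence of 5 word indices
tags = torch.tensor([0, 1, 2, 1, 0])       # hypothetical tag for each word
embed = nn.Embedding(10, 6)
lstm = nn.LSTM(input_size=6, hidden_size=8)
hidden2tag = nn.Linear(8, 3)

embeds = embed(sentence).view(len(sentence), 1, -1)  # (seq_len=5, batch=1, 6)
lstm_out, _ = lstm(embeds)                           # (5, 1, 8): one output per word
print(embeds.shape, lstm_out.shape)  # torch.Size([5, 1, 6]) torch.Size([5, 1, 8])

tag_scores = torch.log_softmax(hidden2tag(lstm_out.squeeze(1)), dim=1)  # (5, 3)
loss = nn.NLLLoss()(tag_scores, tags)  # one scalar for the whole sentence
loss.backward()                        # BPTT through all 5 timesteps
```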

I hope I got the right logic since I’m also a noob playing around with PyTorch.