Multiple forward passes, 1 backward pass instead of packed sequences

Hi, the packed-sequence documentation is spotty and hard to work with (it's also unclear how packing of the targets y should be handled in the loss function).

However, it should be mathematically equivalent to do batch_size forward passes, one per sequence, and then a single backward pass on the sum of all those losses. Will PyTorch handle this?


loss_1 = model(x_1)
loss_2 = model(x_2)
loss_3 = model(x_3)
total_loss = loss_1 + loss_2 + loss_3
total_loss.backward()  # a single backward pass through all three graphs

# equivalent to a padded input with batch_size = 3, no?
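A minimal sketch to check the equivalence numerically, assuming a hypothetical GRU encoder whose final hidden state feeds a linear head with one scalar target per sequence (the model, sizes, and random data are made up for illustration). It compares gradients from one forward pass per sequence plus a single backward on the summed loss against one forward pass over a `PackedSequence`. With `reduction="sum"` both should agree up to float rounding.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence

torch.manual_seed(0)

# Hypothetical model: a GRU encoder whose final hidden state feeds a linear head
rnn = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
params = list(rnn.parameters()) + list(head.parameters())
loss_fn = nn.MSELoss(reduction="sum")  # sum, so per-sequence losses add up exactly

# Three variable-length sequences (lengths 5, 3, 2) with one target each
seqs = [torch.randn(n, 4) for n in (5, 3, 2)]
ys = torch.randn(3, 1)


def clear_grads():
    for p in params:
        p.grad = None


def snapshot_grads():
    return [p.grad.clone() for p in params]


# A: one forward pass per sequence, then a single backward on the summed loss
clear_grads()
total = 0.0
for seq, y in zip(seqs, ys):
    _, h = rnn(seq.unsqueeze(0))  # h: (num_layers, batch=1, hidden)
    total = total + loss_fn(head(h[-1]), y.unsqueeze(0))
total.backward()
grads_a = snapshot_grads()

# B: one forward pass over a PackedSequence holding all three sequences
clear_grads()
packed = pack_sequence(seqs)  # default enforce_sorted=True: lengths must be descending
_, h = rnn(packed)  # h[-1] holds each sequence's last valid hidden state
loss_b = loss_fn(head(h[-1]), ys)
loss_b.backward()
grads_b = snapshot_grads()

print(torch.allclose(total, loss_b, atol=1e-5))
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_a, grads_b)))
```

With the default `reduction="mean"`, the batched loss in B would be the sum divided by 3, so the summed per-sequence loss in A would have to be divided by the number of sequences to match.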

Yes, PyTorch will handle this correctly.


@smth Thanks. I just re-read the docs, and it looks like I can basically do this as well?

Simulate a batch of 32 by running one sequence at a time and calling .step() only once:

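optimizer.zero_grad()
for ex_i in range(0, 32):
    # gradients accumulate on every call to backward()
    # by calling step() after 32 backward calls we can simulate a batch of 32
    # (divide the loss by 32 if the batched loss would use mean reduction)
    loss = someLoss(f(x[ex_i]), y[ex_i])
    loss.backward()
optimizer.step()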
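A minimal sketch of the accumulation loop, assuming a hypothetical linear model, random data, plain SGD, and `nn.MSELoss` with its default mean reduction standing in for `someLoss`. It runs 32 single-example backward calls followed by one `step()` on one copy of the model, and one full-batch step on an identical copy, then checks that the updated weights match.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 4)
y = torch.randn(32, 1)
loss_fn = nn.MSELoss()  # default reduction="mean"

model_acc = nn.Linear(4, 1)  # hypothetical stand-in for f
model_batch = copy.deepcopy(model_acc)  # identical starting weights
opt_acc = torch.optim.SGD(model_acc.parameters(), lr=0.1)
opt_batch = torch.optim.SGD(model_batch.parameters(), lr=0.1)

# Gradient accumulation: 32 backward calls, one optimizer step
opt_acc.zero_grad()
for ex_i in range(32):
    loss = loss_fn(model_acc(x[ex_i:ex_i + 1]), y[ex_i:ex_i + 1])
    (loss / 32).backward()  # scale so the summed grads equal the batch-mean grad
opt_acc.step()

# Reference: a single forward/backward over the full batch of 32
opt_batch.zero_grad()
loss_fn(model_batch(x), y).backward()
opt_batch.step()

for p_acc, p_batch in zip(model_acc.parameters(), model_batch.parameters()):
    print(torch.allclose(p_acc, p_batch, atol=1e-5))
```

The division by 32 is what makes the accumulated gradient equal the gradient of the mean-reduced batch loss. Layers whose output depends on other samples in the batch, such as BatchNorm, break this equivalence.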