Hi, the packed sequence documentation is spotty and hard to work with (it's also unclear how the packing of `y` is handled in the loss function).

However, it should be mathematically equivalent to run `batch_size` separate forward passes and a single backward pass on the sum of their losses. Will PyTorch handle this?

i.e.:

```python
loss_1 = model(x_1)
loss_2 = model(x_2)
loss_3 = model(x_3)

total_loss = loss_1 + loss_2 + loss_3
total_loss.backward()  # equivalent to a padded batch_size = 3 input, no?
```
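For what it's worth, the equivalence can be checked numerically. Here is a minimal sketch using a toy `nn.Linear` model and MSE loss with `reduction="sum"` (the model, shapes, and data are illustrative, not from the original post); because autograd is linear, backpropagating the sum of per-sample losses accumulates the same gradients as one batched pass:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)           # toy stand-in for the real model
x = torch.randn(3, 4)             # "batch" of 3 samples
y = torch.randn(3, 1)
loss_fn = nn.MSELoss(reduction="sum")

# Three separate forward passes, summed loss, one backward pass
total_loss = sum(loss_fn(model(x[i:i + 1]), y[i:i + 1]) for i in range(3))
total_loss.backward()
grads_separate = [p.grad.clone() for p in model.parameters()]

# Reset accumulated gradients, then do a single batched pass
model.zero_grad()
loss_fn(model(x), y).backward()
grads_batched = [p.grad for p in model.parameters()]

# Gradients agree up to floating-point tolerance
same = all(torch.allclose(a, b) for a, b in zip(grads_separate, grads_batched))
print(same)
```

Note this checks equivalence for a `reduction="sum"` loss; with the default `reduction="mean"` the batched loss divides by the batch size, so the summed per-sample losses would need the same scaling. For padded RNN inputs the equivalence additionally requires that padded timesteps are masked out of the loss.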