Hi, the PackedSequence documentation is spotty and hard to work with (it's also unclear how the packing of the targets y is handled in the loss function).
However, it should be mathematically equivalent to do batch_size separate forward passes and a single backward pass on the sum of those losses. Will PyTorch handle this?
i.e.:
loss_1 = model(x_1)
loss_2 = model(x_2)
loss_3 = model(x_3)
total_loss = loss_1 + loss_2 + loss_3
total_loss.backward()  # equivalent to a padded batch_size = 3 input, no?
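For what it's worth, a minimal sketch of the check I have in mind (the `nn.Linear` model, the `criterion`, and the data here are all placeholders I made up, not from any real code): per-sample forward passes followed by one backward on the summed loss should produce the same parameter gradients as a single batched forward, as long as the loss uses `reduction="sum"` and the model has no cross-sample layers like BatchNorm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 1)
# "sum" reduction, so summing three per-sample losses matches the batched loss
criterion = nn.MSELoss(reduction="sum")

x = torch.randn(3, 4)  # a "batch" of 3 samples
y = torch.randn(3, 1)

# Scheme A: three separate forward passes, one backward on the summed loss.
model.zero_grad()
total_loss = sum(criterion(model(x[i : i + 1]), y[i : i + 1]) for i in range(3))
total_loss.backward()
grads_a = [p.grad.clone() for p in model.parameters()]

# Scheme B: a single batched forward pass.
model.zero_grad()
criterion(model(x), y).backward()
grads_b = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_a, grads_b)))  # True
```

With `reduction="mean"` the two schemes would differ by a factor of batch_size, and with stochastic layers (dropout) or cross-sample layers (BatchNorm) they are not equivalent at all.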