Will doing two forward and backward passes work correctly?

I recently found that code like the following doesn’t work well with DDP (broadcast_buffers=True):

loss1 = L1(model(x1), y1)
loss2 = L1(model(x2), y2)
loss = loss1+loss2
loss.backward()
optimizer.step()
optimizer.zero_grad()

Someone recommended using a wrapper module that runs both forward passes inside a single forward call, something like the following:

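import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x1, x2, y1, y2):
        # both forward passes happen inside one call to the wrapper
        loss1 = L1(self.model(x1), y1)
        loss2 = L1(self.model(x2), y2)
        loss = loss1 + loss2
        return loss

# assumption: the wrapper (not the bare model) is what gets wrapped in DDP
wrapper = DDP(Wrapper(model))
loss = wrapper(x1, x2, y1, y2)

loss.backward()
optimizer.step()
optimizer.zero_grad()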

I’d rather not use the wrapper, because my actual code is complicated enough that writing a wrapper like the one above would be awkward.
But what if I change the code to this:

# L1 represents L1 loss; model represents my model
loss1 = L1(model(x1), y1)
loss1.backward()

loss2 = L1(model(x2), y2)
loss2.backward()

optimizer.step()
optimizer.zero_grad()

Is this code equivalent to the version that uses the wrapper?

Hi Chener!

Yes, your third version, where you call .backward() twice, is mathematically the same
as your first two versions, where you sum the two losses and then call .backward()
once. (It could well have slightly different numerical round-off error.)

When you call .backward() twice (with no intervening .zero_grad()) the two gradients
get accumulated into the .grad properties of the various parameters. But the sum of the
gradients is the gradient of the sum, so you get the same result as summing the losses and
calculating the gradient of the sum with a single call to .backward().
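A quick single-process check of the “sum of the gradients is the gradient of the sum” point might look like the sketch below. It assumes a toy torch.nn.Linear model and torch.nn.L1Loss(reduction="sum") as hypothetical stand-ins for your model and L1; it does not exercise DDP itself, only how .backward() accumulates into .grad.

```python
import torch

torch.manual_seed(0)

# hypothetical stand-ins for the real model and loss
model = torch.nn.Linear(4, 1)
L1 = torch.nn.L1Loss(reduction="sum")

x1, y1 = torch.randn(8, 4), torch.randn(8, 1)
x2, y2 = torch.randn(8, 4), torch.randn(8, 1)

# Version A: add the two losses, then a single backward pass
model.zero_grad()
loss = L1(model(x1), y1) + L1(model(x2), y2)
loss.backward()
grads_sum = [p.grad.clone() for p in model.parameters()]

# Version B: two backward passes with no zero_grad() in between,
# so the second gradient is accumulated on top of the first in .grad
model.zero_grad()
L1(model(x1), y1).backward()
L1(model(x2), y2).backward()
grads_accum = [p.grad.clone() for p in model.parameters()]

# equal up to floating-point round-off
same = all(torch.allclose(a, b) for a, b in zip(grads_sum, grads_accum))
print(same)
```

The comparison uses torch.allclose rather than exact equality because, as noted above, the two versions may differ slightly in round-off.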

Depending on your use case and the amount of memory you have, you could combine
your inputs, x1 and x2, into a single batch tensor, x_both, and, likewise, combine
your targets into a single batch tensor, y_both, and perform a single forward and
backward pass:

loss_both = L1(model(x_both), y_both)
loss_both.backward()
optimizer.step()
optimizer.zero_grad()

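Assuming that your loss function, L1(), uses (the equivalent of) reduction = 'sum',
this will be mathematically equivalent to your three versions. With reduction = 'mean'
it is nearly equivalent: if x1 and x2 have the same batch size, the mean over x_both
is half of loss1 + loss2, so the gradients differ only by an overall factor of 2.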
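The sketch below compares the separate-losses version with a single concatenated batch for both reductions. It again assumes a toy torch.nn.Linear model and torch.nn.L1Loss as hypothetical stand-ins, plus two batches of the same size (8 samples each); torch.cat joins the inputs along the batch dimension to form x_both and y_both.

```python
import torch

torch.manual_seed(0)

model = torch.nn.Linear(4, 1)  # hypothetical stand-in model
x1, y1 = torch.randn(8, 4), torch.randn(8, 1)
x2, y2 = torch.randn(8, 4), torch.randn(8, 1)

# combine the two batches along dim 0 (the batch dimension)
x_both = torch.cat((x1, x2))
y_both = torch.cat((y1, y2))

def param_grads(loss):
    """Return a copy of every parameter's gradient for the given loss."""
    model.zero_grad()
    loss.backward()
    return [p.grad.clone() for p in model.parameters()]

results = {}
for reduction in ("sum", "mean"):
    L1 = torch.nn.L1Loss(reduction=reduction)
    separate = param_grads(L1(model(x1), y1) + L1(model(x2), y2))
    combined = param_grads(L1(model(x_both), y_both))
    results[reduction] = (separate, combined)

# 'sum': the combined batch gives the same gradients
sum_match = all(torch.allclose(a, b) for a, b in zip(*results["sum"]))
# 'mean' with equal batch sizes: separate gradients are twice the combined ones
mean_match_x2 = all(torch.allclose(a, 2 * b) for a, b in zip(*results["mean"]))
print(sum_match, mean_match_x2)
```

The factor of 2 with 'mean' comes from each separate loss averaging over 8 samples while the combined loss averages over 16.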

Best.

K. Frank
