Will PyTorch keep all the intermediate results from the forward pass?

Hello, everyone,
Since I have to deal with variable-length inputs, I know the following code works (input1 and input2 have the same feature dimension but different lengths):

loss1 = LossFunction(model(input1), target1)
loss1.backward()
loss2 = LossFunction(model(input2), target2)
loss2.backward()

However, I want to do something “batch-like” and reduce the number of backward calls, so the code looks like this:

loss1 = LossFunction(model(input1), target1)
loss2 = LossFunction(model(input2), target2)

loss = loss1 + loss2
loss.backward()

Although no error is reported, it seems like it could be problematic. My question is whether the intermediate outputs saved in the forward pass for input1 will be overwritten by those for input2. That would mean the gradients are computed only with respect to input2, while the loss is computed over both input1 and input2.

Thanks,
Shuai


The intermediate results are not stored in the nn.Module; autograd saves them in the graph built by each forward call. That means you can reuse the module multiple times without any problem.
The two snippets above will give you the same gradients, and both are correct.
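The equivalence can be checked directly. Below is a minimal sketch, assuming a hypothetical GRU-based `Model` and `nn.MSELoss` as stand-ins for the thread's `model` and `LossFunction`, with two sequences of the same feature dimension but different lengths. It computes the parameter gradients once with two separate `backward()` calls and once with a single `backward()` on the summed loss, then compares them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the thread's `model`: a GRU that accepts
# sequences of any length, followed by a linear head.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        out, _ = self.rnn(x)           # out: (batch, seq_len, hidden)
        return self.head(out[:, -1])   # prediction from the last time step

model = Model()
loss_fn = nn.MSELoss()  # hypothetical stand-in for `LossFunction`

# Same feature dimension (4), different lengths (5 and 9).
input1, target1 = torch.randn(1, 5, 4), torch.randn(1, 1)
input2, target2 = torch.randn(1, 9, 4), torch.randn(1, 1)

def grads_two_backwards():
    model.zero_grad()
    loss1 = loss_fn(model(input1), target1)
    loss1.backward()
    loss2 = loss_fn(model(input2), target2)
    loss2.backward()  # accumulates into .grad on top of loss1's gradients
    return [p.grad.clone() for p in model.parameters()]

def grads_one_backward():
    model.zero_grad()
    loss1 = loss_fn(model(input1), target1)
    # The second forward call builds its own graph; the tensors saved
    # for input1 are not overwritten.
    loss2 = loss_fn(model(input2), target2)
    (loss1 + loss2).backward()
    return [p.grad.clone() for p in model.parameters()]

g_a = grads_two_backwards()
g_b = grads_one_backward()
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_a, g_b)))  # True
```

The two results agree up to floating-point rounding because the gradient of `loss1 + loss2` is the sum of the individual gradients, which is exactly what the two successive `backward()` calls accumulate into `.grad`.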


That really helps, thanks a lot!

I've been stuck on the same question for a while now. Just to confirm: when the same module is used at different positions in the graph, will each use get separate gradients and forward buffers? And will PyTorch handle averaging these gradients when backward is called?

Sorry if this is very obvious.

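Yes, it will handle it properly. Each use saves its own intermediate results, and the gradient contributions from every use are summed (accumulated) into the parameters' .grad, not averaged.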
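A small hand-checkable sketch of reusing one module at two positions in the same graph. It assumes a bias-free `nn.Linear(1, 1)` with its weight set to 3 purely for illustration, so the layer computes f(x) = w * x and applying it twice gives y = w² * x. Each use contributes w * x = 6 to dy/dw, and autograd sums the two contributions.

```python
import torch
import torch.nn as nn

# Bias-free linear layer so the math is easy to check by hand: f(x) = w * x.
layer = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    layer.weight.fill_(3.0)

x = torch.tensor([[2.0]])

# Use the same module twice in one graph: y = w * (w * x) = w^2 * x.
# Each call saves its own input for the backward pass.
y = layer(layer(x))
y.sum().backward()

# dy/dw = 2 * w * x = 2 * 3 * 2 = 12.
# Each use contributes w * x = 6; the contributions are summed (6 + 6 = 12),
# not averaged (which would give 6).
print(layer.weight.grad)  # tensor([[12.]])
```

If an average is what you actually want, you have to scale the loss (or the gradients) yourself; autograd only ever sums.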
