Hi, I’m new to the forum and hope this is the right place to ask.
I want to truncate my backward pass in two different ways for two different losses, but I don’t know how to split the forward and backward passes.
Here is a toy example (w1 and w2 are just weights):
x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)
I’m trying to efficiently compute these grads:
- grad(loss1) with respect to w1
- grad(loss2) with respect to w2
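Written out with the chain rule (shapes and transposes ignored), both gradients reuse the same forward value of z, but only the first one has to be propagated back through f:

```latex
\frac{\partial\,\mathrm{loss}_1}{\partial w_1}
  = \frac{\partial h}{\partial z}\,
    \frac{\partial g}{\partial x}\,
    \frac{\partial f}{\partial w_1},
\qquad
\frac{\partial\,\mathrm{loss}_2}{\partial w_2}
  = \frac{\partial L}{\partial z}\,
    \frac{\partial g}{\partial w_2}
```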
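If I backpropagate each loss separately, the model above does only one forward pass through g but two backward passes through f, because the backward pass for loss2 also goes through f even though I only need its gradient with respect to w2.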
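A minimal runnable sketch of the setup, to make the pass counts concrete. Everything specific here is a hypothetical stand-in, not the real model: f is a matmul wrapped in a custom `torch.autograd.Function` (`CountedF`) that counts its backward calls, g is `tanh(x @ w2)` with a forward-call counter, h is a sum and L is a mean of squares. The "naive" strategy assumed is calling `.backward()` once per loss and keeping only the gradient wanted from each call.

```python
import torch

torch.manual_seed(0)


class CountedF(torch.autograd.Function):
    """Hypothetical stand-in for f: x = upstream @ w1, counting backward calls."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, upstream, w1):
        ctx.save_for_backward(upstream, w1)
        return upstream @ w1

    @staticmethod
    def backward(ctx, grad_x):
        CountedF.backward_calls += 1
        upstream, w1 = ctx.saved_tensors
        grad_upstream = grad_x @ w1.T if ctx.needs_input_grad[0] else None
        grad_w1 = upstream.T @ grad_x if ctx.needs_input_grad[1] else None
        return grad_upstream, grad_w1


g_forward_calls = 0


def g(x, w2):
    """Hypothetical stand-in for g: tanh(x @ w2), counting forward calls."""
    global g_forward_calls
    g_forward_calls += 1
    return torch.tanh(x @ w2)


def h(z):
    """Hypothetical loss head for loss1."""
    return z.sum()


def L(z):
    """Hypothetical loss head for loss2."""
    return (z ** 2).mean()


upstream = torch.randn(4, 3)                # treated as a constant input
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = CountedF.apply(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)

# Naive: one full backward per loss, keeping only the gradient we want each time.
loss1.backward(retain_graph=True)           # walks h -> g -> f
grad_w1 = w1.grad.clone()
w1.grad = None
w2.grad = None
loss2.backward()                            # walks L -> g -> f again; w1.grad is unwanted here
grad_w2 = w2.grad.clone()

print(g_forward_calls, CountedF.backward_calls)  # prints: 1 2
```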
I can trade the two backward passes through f for two forward passes through g by feeding a detached x into a second call to g:
x = f(upstream, w1)
z1 = g(x, w2)
z2 = g(x.detach(), w2)
loss1 = h(z1)
loss2 = L(z2)
These two forward passes through g compute the same values, though, so one of them is still wasted work.
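The detach trade-off can be checked with the same hypothetical stand-ins (redefined so the block runs on its own). One assumption goes beyond the snippet: w2 is also detached in the loss1 branch, so that a single `(loss1 + loss2).backward()` leaves only the loss1 gradient in `w1.grad` and only the loss2 gradient in `w2.grad`. Without that extra `.detach()`, loss1 would also contribute to `w2.grad`.

```python
import torch

torch.manual_seed(0)


class CountedF(torch.autograd.Function):
    """Hypothetical stand-in for f: x = upstream @ w1, counting backward calls."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, upstream, w1):
        ctx.save_for_backward(upstream, w1)
        return upstream @ w1

    @staticmethod
    def backward(ctx, grad_x):
        CountedF.backward_calls += 1
        upstream, w1 = ctx.saved_tensors
        grad_upstream = grad_x @ w1.T if ctx.needs_input_grad[0] else None
        grad_w1 = upstream.T @ grad_x if ctx.needs_input_grad[1] else None
        return grad_upstream, grad_w1


g_forward_calls = 0


def g(x, w2):
    """Hypothetical stand-in for g: tanh(x @ w2), counting forward calls."""
    global g_forward_calls
    g_forward_calls += 1
    return torch.tanh(x @ w2)


def h(z):
    """Hypothetical loss head for loss1."""
    return z.sum()


def L(z):
    """Hypothetical loss head for loss2."""
    return (z ** 2).mean()


upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = CountedF.apply(upstream, w1)
z1 = g(x, w2.detach())                      # loss1 branch: gradient can only reach w1
z2 = g(x.detach(), w2)                      # loss2 branch: gradient can only reach w2
loss1 = h(z1)
loss2 = L(z2)

(loss1 + loss2).backward()                  # f's backward runs once
print(g_forward_calls, CountedF.backward_calls)  # prints: 2 1
```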
Is there a way to do only one backward pass through f and also only one forward pass through g?
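One possible approach, sketched with the same hypothetical stand-ins and not verified against the real model: keep a single forward pass and request each gradient with `torch.autograd.grad`, naming only the weight that is wanted. When inputs are given, the autograd engine only executes backward nodes that lie on a path from the loss to those inputs, so the call for w2 should never enter f's backward.

```python
import torch

torch.manual_seed(0)


class CountedF(torch.autograd.Function):
    """Hypothetical stand-in for f: x = upstream @ w1, counting backward calls."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, upstream, w1):
        ctx.save_for_backward(upstream, w1)
        return upstream @ w1

    @staticmethod
    def backward(ctx, grad_x):
        CountedF.backward_calls += 1
        upstream, w1 = ctx.saved_tensors
        grad_upstream = grad_x @ w1.T if ctx.needs_input_grad[0] else None
        grad_w1 = upstream.T @ grad_x if ctx.needs_input_grad[1] else None
        return grad_upstream, grad_w1


g_forward_calls = 0


def g(x, w2):
    """Hypothetical stand-in for g: tanh(x @ w2), counting forward calls."""
    global g_forward_calls
    g_forward_calls += 1
    return torch.tanh(x @ w2)


def h(z):
    """Hypothetical loss head for loss1."""
    return z.sum()


def L(z):
    """Hypothetical loss head for loss2."""
    return (z ** 2).mean()


upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = CountedF.apply(upstream, w1)
z = g(x, w2)                                # the only forward pass through g
loss1 = h(z)
loss2 = L(z)

# Only w1 requested: the engine walks h -> g -> f.
(grad_w1,) = torch.autograd.grad(loss1, [w1], retain_graph=True)
# Only w2 requested: f's backward node is not on any path from loss2 to w2, so it is skipped.
(grad_w2,) = torch.autograd.grad(loss2, [w2])

print(g_forward_calls, CountedF.backward_calls)  # prints: 1 1
```

If the gradients should end up in `.grad` instead, `Tensor.backward` also accepts an `inputs=` argument (PyTorch 1.8 and later) that restricts the traversal in the same way. Because the two losses send different gradients into z, g's backward is still visited by both calls, but each call should only need the part that leads to its own weight.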