Hi, I’m new to the forum and hope this is the right place to ask.
I want to truncate my backwards pass in two different ways for two different losses, but I don’t know how to split up the forwards and backwards passes.
Here is a toy example (`w1` and `w2` are just weights):
```python
x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)
```
I’m trying to efficiently compute these grads:
- `grad(loss1)` with respect to `w1`
- `grad(loss2)` with respect to `w2`
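The model above only does one forwards pass through `g`, but two backwards passes through `f`: backpropagating `loss2` continues through `g` into `f` to reach `w1`, even though I only need its gradient with respect to `w2`.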
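A minimal runnable version of the toy example illustrates the duplicated work. It assumes PyTorch (the `.detach()` call suggests it); `f`, `g`, `h`, `L`, the tensor shapes and the `f_backward_calls` counter are hypothetical stand-ins. A hook on `x` records each time a gradient flows back into `f`.

```python
import torch

# Hypothetical stand-ins for f, g, h and L; the real functions aren't given.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.abs().mean()

torch.manual_seed(0)
upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

# Each time a gradient reaches x, backprop is about to continue through f.
f_backward_calls = []

x = f(upstream, w1)
x.register_hook(lambda grad: f_backward_calls.append(1))
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)

# Plain backward() reaches every leaf that requires grad, so both calls go
# through f down to w1. retain_graph=True keeps the shared graph alive for
# the second call.
loss1.backward(retain_graph=True)
loss2.backward()

print(len(f_backward_calls))  # 2
```

With plain `backward()`, `w1.grad` also ends up holding the gradient of `loss2`, which is exactly the part that isn't needed.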
I can trade the extra backwards pass through `f` for a second forwards pass through `g`:
```python
x = f(upstream, w1)
z1 = g(x, w2)
z2 = g(x.detach(), w2)
loss1 = h(z1)
loss2 = L(z2)
```
These two forwards passes through `g` are identical, though, so this is still wasteful.
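The same sketch for the detached variant (same PyTorch assumption and hypothetical stand-ins, plus a hypothetical `g_forward_calls` counter) shows the trade: `f` is backpropagated through once, but `g` now runs forwards twice.

```python
import torch

g_forward_calls = []   # forwards passes through g
f_backward_calls = []  # backwards passes through f

# Hypothetical stand-ins for f, g, h and L; the real functions aren't given.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    g_forward_calls.append(1)
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.abs().mean()

torch.manual_seed(0)
upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = f(upstream, w1)
x.register_hook(lambda grad: f_backward_calls.append(1))
z1 = g(x, w2)
z2 = g(x.detach(), w2)  # identical computation, but cut off from f's graph
loss1 = h(z1)
loss2 = L(z2)

# Only loss1's branch still leads back into f, so the hook on x fires once.
(loss1 + loss2).backward()

print(len(g_forward_calls), len(f_backward_calls))  # 2 1
```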
Is there a way to do only one backwards pass through `f` and also only one forwards pass through `g`?
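One possible direction, sketched under the same assumptions (PyTorch, hypothetical stand-ins and counters) and not checked against the real model: run the forward pass once and ask autograd only for the leaf each loss should update. As far as I understand, `torch.autograd.grad` only executes the parts of the graph needed to reach the tensors passed as `inputs`, so the `loss2` call should stop at `w2` without entering `f`'s backward.

```python
import torch

g_forward_calls = []   # forwards passes through g
f_backward_calls = []  # backwards passes through f

# Hypothetical stand-ins for f, g, h and L; the real functions aren't given.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    g_forward_calls.append(1)
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.abs().mean()

torch.manual_seed(0)
upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = f(upstream, w1)
x.register_hook(lambda grad: f_backward_calls.append(1))
z = g(x, w2)  # the only forwards pass through g
loss1 = h(z)
loss2 = L(z)

# Request only w1 for loss1; retain_graph=True keeps g's part of the graph
# alive because the second call reuses it.
(grad_w1,) = torch.autograd.grad(loss1, w1, retain_graph=True)
# Request only w2 for loss2; autograd prunes the path into f, so the hook
# on x does not fire again.
(grad_w2,) = torch.autograd.grad(loss2, w2)

print(len(g_forward_calls), len(f_backward_calls))  # 1 1
```

If accumulating into `.grad` is preferred, `Tensor.backward` also accepts an `inputs=` argument in recent PyTorch versions, e.g. `loss1.backward(inputs=[w1], retain_graph=True)` followed by `loss2.backward(inputs=[w2])`.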