Hi, I’m new to the forum and hope this is the right place to ask.
I want to truncate my backwards pass in two different ways for two different losses, but I don't know how to split the forwards and backwards passes.
Here is a toy example (w1 and w2 are just weights):

    x = f(upstream, w1)
    z = g(x, w2)
    loss1 = h(z)
    loss2 = L(z)

I’m trying to efficiently compute these grads:
- grad(loss1) with respect to w1
- grad(loss2) with respect to w2
The model above only does one forwards pass through g, but two backwards passes through f. I can trade the two backwards passes through f for two forwards passes through g:

    x = f(upstream, w1)
    z1 = g(x, w2)
    z2 = g(x.detach(), w2)
    loss1 = h(z1)
    loss2 = L(z2)

These two forwards passes through g are identical, though, so this still wastes compute.
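A runnable version of the detach variant, as a minimal sketch: the concrete f, g, h and L below are hypothetical stand-ins (small tanh layers and simple reductions), and loss2 uses L as in the first snippet. It shows that the backward pass of loss2 stops at x.detach() and never reaches w1, at the cost of running g twice on identical inputs.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L from the toy example
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.sum()

upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = f(upstream, w1)
z1 = g(x, w2)           # first forward pass through g
z2 = g(x.detach(), w2)  # second forward pass through g, same input values
loss1 = h(z1)           # would be backpropagated separately for w1
loss2 = L(z2)

# loss2's graph is cut at x.detach(), so this backward never enters f:
# w2.grad gets populated, w1.grad stays None
loss2.backward()
```

The duplicated work is the second call to g, which recomputes exactly the same values as z1.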
Is there a way to do only one backwards pass through f and also only one forwards pass through g?
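One possible approach, sketched under assumptions: keep the single forward pass from the first snippet and request only the gradients that are needed with torch.autograd.grad. As I understand PyTorch's autograd engine, when explicit inputs are given it only runs the backward functions that lie on a path to those inputs, so asking for d(loss2)/d(w2) should not enter f's backward, and asking for d(loss1)/d(w1) should skip the gradient with respect to w2. The concrete f, g, h and L are hypothetical stand-ins, and retain_graph=True on the first call keeps g's saved tensors alive for the second call.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L from the toy example
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.sum()

upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

# A single forward pass through f and g
x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)

# d(loss1)/d(w1): backpropagates through h, g and f;
# retain_graph keeps g's saved tensors for the next call
(grad_w1,) = torch.autograd.grad(loss1, w1, retain_graph=True)

# d(loss2)/d(w2): w1 is not requested, so f's backward is not needed
(grad_w2,) = torch.autograd.grad(loss2, w2)

# autograd.grad returns the gradients instead of accumulating them,
# so w1.grad and w2.grad are left as None
```

If the gradients should end up in .grad for an optimizer, loss1.backward(inputs=[w1], retain_graph=True) followed by loss2.backward(inputs=[w2]) restricts accumulation in the same way.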