Sharing computation between detached and non-detached variables

Hi, I’m new to the forum and hope this is the right place to ask.

I want to truncate my backward pass in two different ways for two different losses, but I don't know how to split the forward and backward passes to do this efficiently.

Here is a toy example (w1 and w2 are just weights):

x = f(upstream, w1)

z = g(x, w2)

loss1 = h(z)
loss2 = L(z)

I’m trying to efficiently compute these grads:

  • grad(loss1) with respect to w1
  • grad(loss2) with respect to w2
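Written out with the chain rule (Jacobians evaluated at the forward values x and z), the two target gradients are shown below. Only the first one contains a factor from f. Both reuse the same forward values, so the second gradient never needs ∂f/∂w1:

```latex
\frac{\partial\,\mathrm{loss}_1}{\partial w_1}
  = \frac{\partial h}{\partial z}\,\frac{\partial g}{\partial x}\,\frac{\partial f}{\partial w_1},
\qquad
\frac{\partial\,\mathrm{loss}_2}{\partial w_2}
  = \frac{\partial L}{\partial z}\,\frac{\partial g}{\partial w_2}
```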

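The code above does only one forward pass through g. However, if I backpropagate each loss separately, it does two backward passes through f, even though grad(loss2) with respect to w2 never needs to go through f.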
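A minimal runnable sketch of the toy model with each loss backpropagated separately. The bodies of f, g, h and L, the tensor shapes, and the use of matmul/tanh are hypothetical choices made only so the snippet runs. They stand in for whatever the real model uses:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L; any differentiable functions work.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return (z ** 2).sum()

def L(z):
    return z.mean()

upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

# Shared forward pass: f and g each run once.
x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)

# First backward pass: h -> g -> f. Keep the graph for the second pass.
loss1.backward(retain_graph=True)
grad_loss1_w1 = w1.grad.clone()  # grad(loss1) w.r.t. w1

# Clear accumulated grads so the second pass can be read on its own.
w1.grad = None
w2.grad = None

# Second backward pass: L -> g -> f. Because w1 is a leaf that requires grad,
# f's backward runs again, although only grad(loss2) w.r.t. w2 is wanted.
loss2.backward()
grad_loss2_w2 = w2.grad.clone()  # grad(loss2) w.r.t. w2
```

After the second call, w1.grad is populated again. That is the extra trip through f.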

I can trade the extra backward pass through f for an extra forward pass through g by detaching x:

x = f(upstream, w1)

z1 = g(x, w2)
z2 = g(x.detach(), w2)

loss1 = h(z1)
loss2 = L(z2)

However, these two forward passes through g compute exactly the same values, so this is still wasteful.
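A minimal runnable sketch of the detached version, using the same hypothetical stand-ins for f, g, h and L (the function bodies and shapes are assumptions, not the real model). It requests exactly the two gradients from the list with torch.autograd.grad. It also checks that z1 and z2 hold the same values, which is the duplicated forward work:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L; any differentiable functions work.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return (z ** 2).sum()

def L(z):
    return z.mean()

upstream = torch.randn(4, 3)
w1 = torch.randn(3, 5, requires_grad=True)
w2 = torch.randn(5, 2, requires_grad=True)

x = f(upstream, w1)
z1 = g(x, w2)           # first forward pass through g, connected to f
z2 = g(x.detach(), w2)  # second forward pass through g, cut off from f
loss1 = h(z1)
loss2 = L(z2)

# Backward through h, g and f.
grad_loss1_w1, = torch.autograd.grad(loss1, w1)
# Backward through L and g only: f is not part of loss2's graph.
grad_loss2_w2, = torch.autograd.grad(loss2, w2)

# The two forward passes through g produce the same values.
print(torch.allclose(z1, z2))  # True
```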

Is there a way to do only one backward pass through f and also only one forward pass through g?