Hi, I’m new to the forum and hope this is the right place to ask.

I want to truncate my backwards pass in two different ways for two different losses, but I don’t know how to split the forwards and backwards passes.

Here is a toy example (`w1` and `w2` are just weights):

```python
x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)
```

I’m trying to efficiently compute these grads:

- grad(loss1) with respect to w1
- grad(loss2) with respect to w2
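To make the two targets concrete, here is a minimal PyTorch sketch. The bodies of `f`, `g`, `h` and `L` (tanh layers, a squared loss and an absolute loss) and the tensor shapes are made up for illustration; only the structure matches the toy example. `torch.autograd.grad` returns the gradient only for the inputs you list.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L from the toy example.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.abs().mean()

upstream = torch.randn(8, 4)
w1 = torch.randn(4, 5, requires_grad=True)
w2 = torch.randn(5, 3, requires_grad=True)

x = f(upstream, w1)
z = g(x, w2)
loss1 = h(z)
loss2 = L(z)

# grad(loss1) w.r.t. w1: backprop through h, g and then f.
# retain_graph=True because the next call reuses the saved tensors of g.
(grad_w1,) = torch.autograd.grad(loss1, w1, retain_graph=True)
# grad(loss2) w.r.t. w2: only w2 is requested here.
(grad_w2,) = torch.autograd.grad(loss2, w2)

print(grad_w1.shape, grad_w2.shape)  # torch.Size([4, 5]) torch.Size([5, 3])
```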

The model above only does one forwards pass through `g`, but two backwards passes through `f`.
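One way to see where the two backwards passes through `f` come from, assuming both losses are backpropagated with `.backward()`: `CountedF` below is a hypothetical stand-in for `f`, written as a custom `torch.autograd.Function` so it can count how many times its backward runs. `g`, `h` and `L` are inlined as made-up tanh, squared and absolute-value ops.

```python
import torch

torch.manual_seed(0)


class CountedF(torch.autograd.Function):
    """Hypothetical f(upstream, w1) = tanh(upstream @ w1) that counts backward calls."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, upstream, w1):
        out = torch.tanh(upstream @ w1)
        ctx.save_for_backward(upstream, w1, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        CountedF.backward_calls += 1
        upstream, w1, out = ctx.saved_tensors
        grad_pre = grad_out * (1 - out.pow(2))  # d tanh(a)/da = 1 - tanh(a)^2
        grad_upstream = grad_pre @ w1.t() if ctx.needs_input_grad[0] else None
        grad_w1 = upstream.t() @ grad_pre if ctx.needs_input_grad[1] else None
        return grad_upstream, grad_w1


upstream = torch.randn(8, 4)
w1 = torch.randn(4, 5, requires_grad=True)
w2 = torch.randn(5, 3, requires_grad=True)

x = CountedF.apply(upstream, w1)  # f: one forwards pass
z = torch.tanh(x @ w2)            # g: one forwards pass
loss1 = z.pow(2).mean()           # h
loss2 = z.abs().mean()            # L

loss1.backward(retain_graph=True)  # backprops through g and f
loss2.backward()                   # backprops through g and f again
print(CountedF.backward_calls)     # 2
```

Both calls also accumulate into `w1.grad` and `w2.grad`, so each weight ends up with gradient from both losses.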

I can trade the two backwards passes through `f` for two forwards passes through `g` by detaching `x` for the second loss:

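```python
x = f(upstream, w1)
z1 = g(x, w2)           # this copy of g backprops into f (for loss1)
z2 = g(x.detach(), w2)  # this copy stops at x, so loss2 never reaches f
loss1 = h(z1)
loss2 = L(z2)
```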
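A minimal sketch of the detached version, with hypothetical tanh-based `f` and `g` and made-up losses `h` and `L` (none of these bodies come from the original). It checks that `loss2.backward()` leaves `w1.grad` unchanged, i.e. the gradient from `loss2` stops at `x.detach()` and `f` is only backpropagated once.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for f, g, h and L.
def f(upstream, w1):
    return torch.tanh(upstream @ w1)

def g(x, w2):
    return torch.tanh(x @ w2)

def h(z):
    return z.pow(2).mean()

def L(z):
    return z.abs().mean()

upstream = torch.randn(8, 4)
w1 = torch.randn(4, 5, requires_grad=True)
w2 = torch.randn(5, 3, requires_grad=True)

x = f(upstream, w1)
z1 = g(x, w2)           # first forwards pass through g (connected to f)
z2 = g(x.detach(), w2)  # second forwards pass through g (cut off from f)
loss1 = h(z1)
loss2 = L(z2)

loss1.backward()  # runs the backward of g and f once
w1_grad_after_loss1 = w1.grad.clone()

loss2.backward()  # stops at x.detach(), so f's backward is not run again
# w2.grad now holds contributions from both loss1 and loss2.
print(torch.equal(w1.grad, w1_grad_after_loss1))  # True
```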

These two forwards passes through `g` compute the same values, though, so one of them is still wasted work.

Is there a way to do only one backwards pass through `f` and also only one forwards pass through `g`?