Backpropagate second loss without detach or second forward

Hi everyone,
can the following be done with a single forward pass instead of two forward passes?

outA = fnA(inputs)
loss = fnB(outA, outA)
optim.zero_grad()
loss.backward()
optim.step()

outA = fnA(inputs)
loss = fnB(outA.detach(), outA)
optim.zero_grad()
loss.backward()
optim.step()

It depends on the actual use case and on what optim.step() updates. E.g. if any parameters used in fnA or fnB are updated by the first optim.step(), you won’t be able to use a single forward pass, since the output of a second forward pass would differ from the first.
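
If it is acceptable to fold both updates into a single optimizer step (i.e. the second loss does not need to see the parameters after the first update), then one forward pass is enough. Here is a minimal sketch with placeholder definitions for fnA, fnB, inputs, and optim, since the real ones aren't shown in the question:

import torch
import torch.nn as nn

# Placeholders standing in for the question's fnA, fnB, inputs, and optim
fnA = nn.Linear(8, 8)
optim = torch.optim.SGD(fnA.parameters(), lr=0.1)
inputs = torch.randn(4, 8)

def fnB(pred, target):
    # placeholder loss; the real fnB from the question is unknown
    return ((pred - target) ** 2).mean()

outA = fnA(inputs)                       # single forward pass

loss1 = fnB(outA, outA)                  # first loss from the question
loss2 = fnB(outA.detach(), outA)         # second loss, target detached

optim.zero_grad()
# Sum the losses and backpropagate once; alternatively call
# loss1.backward(retain_graph=True) followed by loss2.backward()
# to accumulate the gradients in two passes over the same graph.
(loss1 + loss2).backward()
optim.step()

Note that this is not equivalent to your original two-step version: there, the second loss.backward() and optim.step() see the parameters as they are after the first update, which is exactly why a second forward pass is needed in that case.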