Intermediate data processing between consecutive networks and the backward pass

Hello all! I have run into a tricky problem.

I have two networks, net1 and net2, to train. The structure of my model is roughly as follows:

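import torch.nn as nn

class Model(nn.Module):
    def __init__(self, net1, net2):
        super().__init__()
        self.net1 = net1  # network 1 (pre-trained)
        self.net2 = net2  # network 2

    def forward(self, input1):
        self.output1 = self.net1(input1)  # feed-forward of network 1
        input2 = self.function()          # iterative preparation of network 2's input; depends on output1
        self.output2 = self.net2(input2)  # feed-forward of network 2
        return self.output1, self.output2

    def function(self):
        # iteration using self.output1
        # takes relatively long to run (a few minutes)
        ...  # returns input2

    def backward(self, loss1, loss2):
        self.loss = loss1 + loss2  # losses based on output1 and output2
        self.loss.backward()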
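For anyone who wants to reproduce the behaviour, here is a minimal, runnable toy version of this setup. Everything concrete in it is an assumption for illustration: net1 and net2 are small nn.Linear layers, function() is a hypothetical loop of elementwise operations standing in for the real iteration, and the losses are plain MSE against random targets.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy two-network model; layer sizes and function() are assumptions."""

    def __init__(self, n_iter=100):
        super().__init__()
        self.net1 = nn.Linear(8, 8)  # stand-in for the pre-trained network 1
        self.net2 = nn.Linear(8, 1)  # stand-in for network 2
        self.n_iter = n_iter         # hypothetical iteration count inside function()

    def function(self):
        # Hypothetical iterative preprocessing that depends on output1.
        x = self.output1
        for _ in range(self.n_iter):
            x = torch.tanh(x) * 0.9 + 0.1 * self.output1
        return x

    def forward(self, input1):
        self.output1 = self.net1(input1)  # feed-forward of network 1
        input2 = self.function()          # prepare the input for network 2
        self.output2 = self.net2(input2)  # feed-forward of network 2
        return self.output1, self.output2


torch.manual_seed(0)
model = Model(n_iter=100)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
input1 = torch.randn(4, 8)
target1 = torch.randn(4, 8)  # hypothetical targets for loss1
target2 = torch.randn(4, 1)  # hypothetical targets for loss2

out1, out2 = model(input1)
loss1 = nn.functional.mse_loss(out1, target1)
loss2 = nn.functional.mse_loss(out2, target2)
loss = loss1 + loss2
opt.zero_grad()
loss.backward()  # gradients of loss2 also flow back through function() into net1
opt.step()
```

Because output1 is not detached, the gradient of loss2 reaches net1 by passing backwards through every step of function().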

net1 was pre-trained for a few epochs so that it produces meaningful output, and everything worked correctly. Then I added net2, and the two networks are now supposed to be trained simultaneously. The problem is that loss.backward() takes very long (around 20 minutes). I suspect this is related to function() in the forward() method, because when I reduced the number of iterations, the backward pass became much faster. However, I cannot reduce the number of iterations for the full model.
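One way to check the suspicion that the iteration count drives the backward time is to count how many autograd nodes the forward pass records. The sketch below assumes a hypothetical iterate() loop in place of the real function() and walks the graph via grad_fn.next_functions; with tensors that require gradients, every operation in the loop adds nodes, so the graph that backward() has to traverse grows roughly linearly with the number of iterations.

```python
import torch


def count_graph_nodes(t):
    """Count the distinct autograd nodes reachable from t.grad_fn."""
    seen = set()
    stack = [t.grad_fn] if t.grad_fn is not None else []
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                stack.append(next_fn)
    return len(seen)


def iterate(x, n_iter):
    # Hypothetical stand-in for function(): each step records new graph nodes.
    y = x
    for _ in range(n_iter):
        y = torch.tanh(y) * 0.9 + 0.1 * x
    return y


x = torch.randn(4, 8, requires_grad=True)
counts = {}
for n_iter in (10, 100, 1000):
    out = iterate(x, n_iter).sum()
    counts[n_iter] = count_graph_nodes(out)
    print(f"n_iter={n_iter}: {counts[n_iter]} autograd nodes")
```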

So my question is: does loss.backward() somehow "re-run" function() from forward()? If so, do you have any suggestions for avoiding this, e.g. by changing the structure of my model?
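The question of whether backward() re-executes function() can be tested directly with a call counter. In the sketch below, function() is again a hypothetical stand-in and the losses are placeholder mean-of-squares terms. Variant A keeps the graph through function(). Variant B assumes that net1 only needs to be trained by loss1, so it passes a detached copy of output1 into function() and no graph is recorded for the iterations.

```python
import torch
import torch.nn as nn

calls = {"function": 0}


def function(output1, n_iter=50):
    # Hypothetical stand-in for the slow iterative preprocessing.
    calls["function"] += 1
    x = output1
    for _ in range(n_iter):
        x = torch.tanh(x) * 0.9 + 0.1 * output1
    return x


torch.manual_seed(0)
net1 = nn.Linear(8, 8)
net2 = nn.Linear(8, 1)
input1 = torch.randn(4, 8)

# Variant A: gradients of loss2 flow back through function() into net1.
output1 = net1(input1)
output2 = net2(function(output1))
loss = output1.pow(2).mean() + output2.pow(2).mean()  # placeholder loss1 + loss2
loss.backward()
calls_after_a = calls["function"]  # forward called function() once; backward did not call it

# Variant B (assumption: net1 is trained by loss1 only).
# Detaching output1 means no graph is recorded inside function().
net1.zero_grad()
net2.zero_grad()
output1 = net1(input1)
input2 = function(output1.detach())
output2 = net2(input2)
loss = output1.pow(2).mean() + output2.pow(2).mean()
loss.backward()
```

Which variant fits depends on whether net1 must learn from loss2. Detaching (or running function() under torch.no_grad()) makes the backward pass independent of the iteration count, but loss2 then no longer updates net1. If that gradient is required, the recorded graph is needed; a custom torch.autograd.Function with a hand-derived backward is one possible way to avoid storing every iteration, provided function() admits such a derivative.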

Many thanks for any help!