How to handle the gradient of a neural network's learned initial state?

I have a neural network that computes results sequentially: it reads in a learnable initial state and calculates each step's result from the previous step's result. Training runs without errors, but when I get to the 2nd epoch the time it takes to run backward() starts to rise unreasonably: the 1st epoch takes about 20 s, but the 2nd epoch's backward() takes hours. After searching online for a while, I found that this can be caused by not detaching some variables between epochs. I tried detaching the parameters shared between epochs, but that doesn't seem to have solved the bug. Any help is appreciated, thanks! Attached is the code for the complete model:

"""
Definitions for parametric functions M0, M1. Can use multiple features,
but only the first feature will be used to calculate labels.
"""
class M0_simplex(nn.Module):
    def __init__(self,num_nodes):
        super(M0_simplex, self).__init__()
        # c0*x + c1
        self.c0 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
        self.c1 = nn.Parameter(torch.tensor(0.0, requires_grad=True))

    def reset_parameters(self):
        nn.init.constant_(self.c0,-1.0)
        nn.init.constant_(self.c1,1.0)
    
    def get_feature_num(self):
        return 1

    def forward(self,x):
        res = self.c0 * x + self.c1
        return res

class M1_simplex(nn.Module):
    def __init__(self,num_nodes):
        super(M1_simplex, self).__init__()
        # c0 * (x-x_neighb) + c1
        self.c0 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
        self.c1 = nn.Parameter(torch.tensor(0.0, requires_grad=True))

    def forward(self,x,neighb_x):
        if x<0:
            return 0
        res = self.c0 * (x-neighb_x) + self.c1
        return res
    
    def reset_parameters(self):
        nn.init.uniform_(self.c0)
        nn.init.uniform_(self.c1)

"""
Definitions for architecture
"""
class info_flow(nn.Module):
    def __init__(self,adj, M0, M1, num_nodes):
        super(info_flow, self).__init__()
        #Current implementation for binary labels only. Need to fix
        self.adj = torch.tensor(adj)
        self.num_nodes = num_nodes
        self.M0 = M0(num_nodes)
        self.M1 = M1(num_nodes)
        self.features = self.M0.get_feature_num()
        self.threshold = nn.Parameter(torch.tensor( 0.0, requires_grad=True ))
        self.init_state = nn.Parameter(torch.zeros( (num_nodes,self.features), requires_grad=True ))
    
    def fit(self,y,optim,device,epochs=50,summary_writer=None,add_tag='',steps=-1):
        """ The function to fit information flow model
        Args:
            y: The target to be predicted. Needs to be one episode.
            optim: The optimizer for model
            device: The device for computation
            epochs: The epochs to fit
            summary_writer: TensorboardX summary writer to keep track of acc/loss
            add_tag: The tag to use when adding acc/loss info to the summary
            steps: The number of time steps to fit on; -1 (default) uses all steps in y
        Returns:
            loss: The loss at last epoch
            F1: The F1 score for each class
        """
        assert y.shape[0] == self.num_nodes
        self.train()
        if steps == -1:
            steps = y.shape[1]
        else:
            steps = min(steps, y.shape[1])
        y = y.to(device)
        self.adj = self.adj.to(device)
        
        for epoch in range(epochs):
            optim.zero_grad()
            labels = self.forward(steps,device)

            loss = F.nll_loss( torch.log(labels), y[:,:steps] )
            print('Finished loss compute',epoch)
            L = torch.argmax(labels,dim=1).detach().cpu().numpy()
            F1 = f1_score( y[:,:steps].cpu().numpy().flatten(), L.flatten() )
          
            loss.backward()
            optim.step()
            loss = loss.detach()

            if summary_writer is not None: #Write loss and F1 to the summary
                summary_writer.add_scalar( add_tag+'/loss', loss, epoch)
                summary_writer.add_scalar( add_tag+'/F1', F1, epoch)

            #Detach from next epoch's gradient computation
            self.init_state.detach_()
            self.threshold.detach_()

        self.eval()
        return loss, F1

    def forward_step(self,x,device=None):
        #x shape: [node_num, features]
        if device is None:
            res = torch.zeros( x.shape[0],x.shape[1]  )
        else:
            res = torch.zeros( x.shape[0],x.shape[1], device=device  )

        for node in range(x.shape[0]):
            self_feat = x[node,:]
            res[node,:] = res[node,:] + self.M0( self_feat )

            # delta is assumed to be a module-level threshold on adjacency weights, defined elsewhere
            neighbs = (torch.abs(self.adj[:,node])>delta).nonzero()
            for neighb_node in neighbs:
                #Iterate over all neighbors to find sum of neighbor influences
                w = self.adj[node,neighb_node]
                res[node,:] = res[node,:] + self.M1( self_feat, x[neighb_node,:] ) * w

        #return updated values at each node
        return res

    def forward(self,time_steps,device=None,label=True,binary=False):
        """ This function produces the output for 'time_steps' steps
        Args:
            time_steps: How many steps of results to calculate
            device: The device on which computations are done
            label: A boolean variable for return type. If true, will return
            the 2 labels' probability. If false, will return the raw features.
            binary: A boolean variable for return type. If true, will return 
            max of 2 labels' probability.
        Returns:
            x: The raw features of size [num_node, time_steps, features] or 
            the label probabilities of size [num_node, classes, time_steps] or 
            the labels of size [num_node, time_steps]
        """
        x = [self.init_state.unsqueeze(1)]
        for step in range(time_steps-1):
            x.append(self.forward_step(x[step],device).unsqueeze(2))
        x = torch.cat( x, dim=1 )
        if label:
            #Threshold is the point where CSD vs. no CSD probs are 0.5 vs. 0.5
            x = 0.5*x[:,:,0]/self.threshold
            x[x>.95] = .95
            x[x<.15] = .05 #Don't force to be 0 for numerical stability
            x = torch.cat( [x.unsqueeze(2), (1-x).unsqueeze(2)], dim=2 ).transpose(2,1)
            if binary:
                x = torch.argmax( x, dim=1 )
        return x
            
    def reset_parameters(self):
        nn.init.uniform_(self.threshold)
        nn.init.uniform_(self.init_state)
        self.M0.reset_parameters()
        self.M1.reset_parameters()

Are you seeing an increase in memory usage during the first epoch? And is each iteration slowing down gradually, or does the slowdown appear suddenly in the second epoch?

Thanks for the reply. It slows down during the second epoch. There are about 400 time steps to compute in each epoch; in epoch 1 each step takes roughly the same amount of time, but in epoch 2 the per-step time starts to increase at around step 300. In fact, I've run the model overnight and the optimizer.step() call in epoch 2 still hasn't returned, so I can't tell whether epoch 3 would take even longer because I haven't reached it yet.

Could you check whether your system might be overheating and reducing the hardware clocks?
If that's not the case, could you profile which part of the code is causing the slowdown?
You could use the timers from the ImageNet example to profile the data loading etc.
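
As a rough sketch (not the exact ImageNet utilities, just plain time.perf_counter timers; the cuda synchronization only matters if you run on the GPU), you could wrap the expensive calls like this:

import time
import torch

def timed(name, fn, *args, **kwargs):
    # Synchronize first so queued GPU work doesn't get attributed to the wrong call
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f'{name}: {time.perf_counter() - start:.3f}s')
    return out

# e.g. inside fit():
# labels = timed('forward', self.forward, steps, device)
# timed('backward', loss.backward)
# timed('step', optim.step)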

I checked with the admin and there was no problem with the server itself. Since the code that takes the most time to run is the backward() call and the step() call, do you mean I should profile the code inside backward() and step()?

Yes, profiling would be useful to see which part of the code slows down.
Just to clarify the issue: do you see a sudden jump at epoch 2, iteration 300, or is the time increasing steadily after a certain iteration?

If the system isn't a bottleneck due to overheating etc., you would often see a slowdown if you are growing the computation graph (e.g. if you are not detaching some states in an RNN), but this should also increase the device memory usage.
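
To illustrate what a growing graph looks like (a toy recurrence, not your model):

import time
import torch

w = torch.randn(1, requires_grad=True)
state = torch.zeros(1)

for i in range(200):
    state = w * state + 1.0              # recurrence through a learned parameter
    loss = state.sum()
    t0 = time.perf_counter()
    loss.backward(retain_graph=True)     # retain_graph only so the loop can keep running
    if i % 50 == 0:
        print(f'iter {i}: backward took {time.perf_counter() - t0:.6f}s')
    # Comment out the next line and backward() has to walk through every past
    # iteration, so its runtime and memory grow steadily with i.
    state = state.detach()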

To clarify the question: the per-step time starts increasing steadily in epoch 2, but the impact doesn't become visible until later in the epoch.

I found that the loss.backward() calls are taking the most time to run. Do you have any quick guesses as to where it might be going wrong? I tried detaching the learned state and the threshold (a learned variable defined in the model) at the end of each epoch, but that doesn't seem to help. Should I try detaching all the other variables in the model as well?

You could try to check all tensors for a valid .grad_fn and detach every tensor that is not needed anymore. If the backward pass causes the slowdown, this might indeed point towards a growing computation graph.
Do you see an increase in memory usage while you are training the model?
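
Something like this inside the fit() loop would show which tensors still carry a grad_fn (just a sketch using the attribute names from the posted model):

# Leaf parameters should show grad_fn=None; anything created in the forward
# pass that you keep around between iterations will show a grad_fn.
for name, t in [('init_state', self.init_state),
                ('threshold', self.threshold),
                ('adj', self.adj),
                ('labels', labels),
                ('loss', loss)]:
    print(name, 'requires_grad:', t.requires_grad, 'grad_fn:', t.grad_fn)

# If you train on the GPU, also watch the allocated memory per epoch:
if torch.cuda.is_available():
    print(f'allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB')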

Just a wild guess, but could you also detach the loss when you are passing it to the summary_writer? I don't know how variables are handled by it.
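
For example (in the posted fit() the loss is already detached before logging, so this is mostly defensive; .item() returns a plain Python float):

summary_writer.add_scalar(add_tag + '/loss', loss.item(), epoch)  # plain float, no graph reference
summary_writer.add_scalar(add_tag + '/F1', F1, epoch)             # F1 from sklearn is already a float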