LSTM with LBFGS optimizer, purpose of detach()

I am building an LSTM network for time-series prediction. I found an example on GitHub and tried to implement it myself.

I have a class that manages the training process. The training function is as follows:

def trainModel(self, epochMax=100):
    for i in range(epochMax):
        self.h = torch.zeros(layerNo, datasetNo, D_H)
        self.c = torch.zeros(layerNo, datasetNo, D_H)
        self.h = self.h.to(device)
        self.c = self.c.to(device)

        def closure():  # this closure is needed for the LBFGS optimizer
            # *** Why do I need these four lines? ***
            self.h = self.h.detach()
            self.c = self.c.detach()
            self.h = self.h.requires_grad_()
            self.c = self.c.requires_grad_()
            
            yPred, h_temp, c_temp = self.model(self.XTrainAll, self.h, self.c)
            self.optimizer.zero_grad()
            self.h, self.c = h_temp, c_temp
            loss = self.lossFn(yPred, self.YTrainAll)
            # print('loss:', loss.item())
            loss.backward()
            return loss
        loss = self.optimizer.step(closure)

As can be seen from the example code, it never calls detach(). But in my case, if these four lines are missing, I get the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
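
For reference, here is a stripped-down sketch of the pattern that (I believe) reproduces the same error, using a toy LSTM with made-up sizes rather than my actual model:

import torch

# toy example only: the hidden state is carried across iterations without detach()
lstm = torch.nn.LSTM(input_size=1, hidden_size=4)
x = torch.randn(10, 1, 1)        # (seq_len, batch, input_size)
h = torch.zeros(1, 1, 4)         # (num_layers, batch, hidden_size)
c = torch.zeros(1, 1, 4)

for i in range(2):
    y, (h, c) = lstm(x, (h, c))  # h, c keep a reference to this iteration's graph
    loss = y.sum()
    loss.backward()              # the second iteration raises the RuntimeError above

Adding h = h.detach() and c = c.detach() at the top of the loop makes the error go away, which matches what the four lines do in my real code.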

I have gone through this post, but I am still not clear about the difference between my code and the GitHub code.

Also, is there a way to pass h and c into the closure() function without defining them as class attributes? Python complains about undefined variables when I define them simply as h and c outside of the function, and putting global h at the beginning of closure() does not help.

Thanks in advance.

Regarding the h and c variables: if there is an assignment to h anywhere inside the function body, Python treats h as a local name that shadows the h from the enclosing scope. (This is also why global h does not help: global only binds module-level names, whereas your h is presumably created inside trainModel, i.e. in an enclosing function scope.)
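
For example, this raises an UnboundLocalError even though h exists outside the function, because the assignment inside closure() turns h into a local name there:

h = 1

def closure():
    h = h + 1   # UnboundLocalError: this assignment makes h local to closure()
    return h

closure()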

Assignment to an element of h (h[0] = ...) does not trigger this behavior, so the standard workaround is to wrap the variable from the outer context in a singleton list:

myvar = 5
myvar = [myvar]          # wrap the outer variable in a singleton list

def closure():
    myvar[0] = 6         # element assignment: myvar itself is never rebound

closure()
print(myvar[0])  # => 6
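
Applied to your closure, the same trick lets h and c live as plain locals (wrapped in a list) instead of attributes on self. Below is a self-contained sketch with a toy LSTM and made-up sizes, just to show the pattern, not your actual model:

import torch

lstm = torch.nn.LSTM(input_size=1, hidden_size=4)
optimizer = torch.optim.LBFGS(lstm.parameters(), lr=0.1)
lossFn = torch.nn.MSELoss()
x = torch.randn(10, 1, 1)
target = torch.randn(10, 1, 4)

state = [torch.zeros(1, 1, 4), torch.zeros(1, 1, 4)]   # [h, c] wrapped in a list

def closure():
    optimizer.zero_grad()
    h = state[0].detach()                  # cut ties to the previous graph
    c = state[1].detach()
    y, (hNew, cNew) = lstm(x, (h, c))
    state[0], state[1] = hNew, cNew        # element assignment, so no shadowing
    loss = lossFn(y, target)
    loss.backward()
    return loss

for epoch in range(3):
    loss = optimizer.step(closure)

The detach() calls at the top of the closure play the same role as the four lines in your code; the only difference is that the closure mutates state[0] and state[1] instead of rebinding names, so the shadowing rule never applies.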