Convolutional LSTM - retain_graph Error

Hey everyone,
I am working on semantic segmentation and would like to extend an existing DeepLabV3+ (MobileNet backbone) with a recurrent unit (a convolutional LSTM).

My idea was to concatenate the segmentation result at the current timestep T with the previous segmentation results (T-1 and T-2) and feed everything into the ConvLSTM (see picture).
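
Roughly, the stacking I have in mind looks like this (just a shape sketch with made-up sizes, not my actual model code):

import torch

# hypothetical sizes: batch 4, 2 classes, 128x128 frames
cur_pred = torch.randn(4, 2, 128, 128).unsqueeze(1)  # current segmentation at T:  [B, 1, C, H, W]
prev_1 = torch.zeros(4, 1, 2, 128, 128)              # prediction at T-1 (zeros while it doesn't exist yet)
prev_2 = torch.zeros(4, 1, 2, 128, 128)              # prediction at T-2
seq = torch.cat([cur_pred, prev_1, prev_2], dim=1)   # [B, 3, C, H, W] -> time dimension for the ConvLSTM
print(seq.shape)                                     # torch.Size([4, 3, 2, 128, 128])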

During training I always get this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

Setting retain_graph to True, however, causes CUDA to run out of memory.

I’ve tried following this thread, and it seems like the problem has to do with hidden_state.detach().

My network looks like this:

class Deep_mobile_lstmV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = deeplabv3plus_mobilenet(num_classes=2, pretrained_backbone=True)
        
        self.lstm = ConvLSTM(input_dim=2, hidden_dim=[2], kernel_size=(3, 3), num_layers=1, batch_first=True,
                             bias=True,
                             return_all_layers=False)
        self.hidden = None

    def forward(self, x, *args):
        # get the semantic segmentation from the base model
        out = self.base(x)
        out = out.unsqueeze(1)  # [bs, channels, H, W] --> [bs, timestep, channels, H, W]

        if len(args) != 0:
            # the old prediction tensors are passed in as a list
            old_pred = args[0]

            # set old predictions to zeros if they don't exist yet (for the first 2 frames)
            if None in old_pred:
                for i in range(len(old_pred)):
                    old_pred[i] = torch.zeros_like(out)

            # concatenate old predictions with the current prediction: [_,1,_,_,_] --> [_,3,_,_,_]
            out = torch.cat([out] + old_pred, dim=1)

        out, self.hidden = self.lstm(out, self.hidden)
        out = out[0][:, -1, :, :, :]  # return only 1 segmentation
        return out

(Please note that I left out some unimportant code that asserts that the shapes match for concatenation, to make it easier to read.)

The DeepLab implementation is from this Git repo and the ConvLSTM is from this repository (with line 142 in convlstm.py replaced by hidden_state=hidden_state).

My training loop looks like this:

for epoch in range(num_epochs):
    old_pred = [None, None]
    for batch in train_loader:
        images, labels = batch
        pred = net(images, old_pred)
        loss = criterion(pred, labels.long())
        optimizer.zero_grad()
        # loss.backward(retain_graph=True) leads to memory issues
        loss.backward() # throws error
        optimizer.step()
        
        old_pred[1] = old_pred[0]
        old_pred[0] = pred.unsqueeze(1)

If I change
out, self.hidden = self.lstm(out, self.hidden)
into
out, self.hidden = self.lstm(out)
it trains, but then the information in self.hidden would be lost with each new sample, right?

I have been stuck on this problem for days and have no idea how to solve it. Any help would be appreciated!

So what happens when you don't detach the hidden state is that the previous computation graph stays attached to it (if you use retain_graph=True). When you then calculate the next hidden state, the previous computation graph becomes part of the current one. After a couple of iterations the complete graph is so big that you run out of memory.
Not detaching is what allows you to backprop through time (I will use BPTT as an abbreviation).
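
Here is a tiny standalone example that reproduces the same error (a toy recurrence, nothing to do with your model, just to show the mechanism):

import torch

w = torch.randn(1, requires_grad=True)
h = torch.zeros(1)
for step in range(2):
    h = h * w            # h still carries the graph from the previous iteration
    loss = h.sum()
    loss.backward()      # second iteration: "Trying to backward through the graph a second time"

# changing the update to h = h.detach() * w makes it run, because the old graph is cut off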

By instead detaching the hidden state you stop BPTT and the computation graph will not grow with each iteration. It is also possible to detach the hidden state at certain intervals to keep some BPTT.

out, self.hidden = self.lstm(out, self.hidden)
# If hidden is similar to LSTM it should be a tuple of two tensors
self.hidden = tuple(state.detach() for state in self.hidden)

This should run without needing to specify retain_graph=True.

Thanks for the answer! Yes, hidden is indeed a tuple of two tensors.

Unfortunately, this does not work and I really don't get why not… I am still getting the error from above:

Does the position where state.detach() is called matter?
At the moment I call detach() in the forward() method of my network.
Does it need to be in the training loop? So I would call:

new_hidden = None
for batch in train_loader:
    hidden_state = new_hidden if new_hidden is not None else init_hidden_with_zeros()
    images, labels = batch
    pred, new_hidden = net(images, hidden_state, old_pred)
    new_hidden = tuple(state.detach() for state in new_hidden)
    (...)

so basically shifting the process of “taking care of the hidden state” from the forward() method into the training loop…

Does that change anything? :smiley:

No problem, hopefully we can track it down!

You just need to make sure you detach before using the hidden state again. It does not matter if you do it when you store the hidden state in the forward pass or when you retrieve the previous hidden state and feed it together with the new input to the LSTM.
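
For example, doing it in the training loop when you retrieve the state would look roughly like this (a sketch, assuming the hidden state is stored on the model as net.hidden like in your class):

# right before the forward pass
if net.hidden is not None:
    # cut off the graph from the previous iteration before reusing the state
    net.hidden = tuple(state.detach() for state in net.hidden)
pred = net(images, old_pred)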

The old_pred tensors need to be detached as well if they are concatenated with tensors in the current forward pass, because the old_pred tensors still carry a computation graph!

# In training loop
old_pred[1] = old_pred[0]
old_pred[0] = pred.unsqueeze(1).detach()

Hopefully it works this way. I'm going to be unavailable soon, so it might take a day for me to respond next time :smiley:

Mate, you saved my day! Detaching the pred works! Thank you so much :smiley:

Did I get this right that now backpropagation only happens for the new timestep T and not for T-1 and T-2?

Awesome, glad I could help!

You have two recurrent parts, the previous prediction as well as the hidden state of the LSTM. Detaching these will stop gradients flowing back in time, so you only get gradients for the current timestep.
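
So in your setup there are exactly two places where the graph gets cut, both already shown above:

# 1) inside forward(), right after the ConvLSTM call
self.hidden = tuple(state.detach() for state in self.hidden)

# 2) in the training loop, when storing the prediction for the next frames
old_pred[0] = pred.unsqueeze(1).detach()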

If you want to backprop through time you can always detach at a certain interval! For this to work correctly I think you need something like this:

detach_interval = 5
for i, batch in enumerate(data):
    input, target = batch
    pred = model(input)
    loss = criterion(pred, target)
    optimizer.zero_grad()
    if i > 0 and i % detach_interval == 0:
        loss.backward(retain_graph=False)
        # here you need to detach the tensors that will be reused in the next forward pass
        ...
    else:
        loss.backward(retain_graph=True)
    optimizer.step()

Note that you need to retain the computation graph (backward(retain_graph=True)) if you are going to reuse tensors in the next forward pass without detaching them. On the last backward() before detaching you can let backprop free the graph.

Hopefully this makes sense!

Makes sense to me,
thank you once more!
Recurrence is still confusing to me :smiley:

Hope you have a great weekend!