CUDA out-of-memory error when using retain_graph=True

Dear all,

I cannot figure out how to get rid of this out-of-memory error:
RuntimeError: CUDA out of memory. Tried to allocate 7.50 MiB (GPU 0; 11.93 GiB total capacity; 5.47 GiB already allocated; 4.88 MiB free; 81.67 MiB cached).

Because of the recurrent architecture of my network, I have to pass retain_graph=True; otherwise I get this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
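A minimal sketch of why this error shows up in a recurrent loop, using a hypothetical one-weight toy cell rather than the actual network: the hidden state `h` carries the graph of the previous iteration, so the second `backward()` call tries to walk back into nodes whose saved tensors the first call already freed.

```python
import torch

# Hypothetical toy recurrent cell: one weight shared across time steps
torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)
h = torch.zeros(1, 3)  # hidden state carried from one iteration to the next

errors = []
for it in range(2):
    h = torch.tanh(h @ w)   # h now depends on the graph built in the previous iteration
    loss = h.pow(2).sum()
    try:
        loss.backward()     # default retain_graph=False frees the saved tensors it traverses
    except RuntimeError as err:
        errors.append(str(err))

print(len(errors))  # only the second iteration fails
```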

Here is my main training loop:

for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    states = None  # torch.empty().to(device)
    for idx, image in enumerate(loader):
        tensor = image[0].clone().to(device)

        if states is None:
            states = prednet.get_initial_states(tensor)

        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance.
        prednet.zero_grad()

        # Step 3. Run our forward pass.
        # tensor = tensor.reshape(tensor.shape[0], 1, tensor.shape[1], tensor.shape[2], tensor.shape[3])
        tag_scores, states = prednet(tensor, states)

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss = loss_function(tag_scores, torch.zeros_like(tag_scores))
        print(loss)
        loss.backward(retain_graph=True)
        for state in states:
            state.detach()
        optimizer.step()
        print('1 backward')
        torch.cuda.empty_cache()

Here is the forward function:

def forward(self, a, states=None):
    r_tm1 = states[:self.nb_layers]
    c_tm1 = states[self.nb_layers:2*self.nb_layers]
    e_tm1 = states[2*self.nb_layers:3*self.nb_layers]

    if self.extrap_start_time is not None:
        t = states[-1].clone()
        # if past self.extrap_start_time, the previous prediction will be treated as the actual
        a = torch.where(t >= self.t_extrap, states[-2], a)

    c = []
    r = []
    e = []

    # top-down pass: update the ConvLSTM representation of each layer
    for l in reversed(range(self.nb_layers)):
        inputs = [r_tm1[l], e_tm1[l]]
        if l < self.nb_layers - 1:
            inputs.append(r_up)

        inputs = torch.cat(inputs, self.channel_axis)

        i = self.conv_layers['i'][l](inputs)
        f = self.conv_layers['f'][l](inputs)
        o = self.conv_layers['o'][l](inputs)

        _c = f * c_tm1[l] + i * self.conv_layers['c'][l](inputs)
        _r = o * self.LSTM_activation(_c)
        c.insert(0, _c)
        r.insert(0, _r)

        if l > 0:
            r_up = self.upsample(_r)

    # bottom-up pass: predictions and errors
    for l in range(self.nb_layers):
        ahat = self.conv_layers['ahat'][l](r[l])

        if l == 0:
            # clip the frame prediction at pixel_max
            ahat = torch.clamp(ahat, max=self.pixel_max)
            frame_prediction = ahat

        # compute errors
        e_up = self.error_activation(ahat - a)
        e_down = self.error_activation(a - ahat)

        e.append(torch.cat((e_up, e_down), dim=self.channel_axis))
        if l < self.nb_layers - 1:
            a = self.conv_layers['a'][l](e[l])
            a = self.pool(a)  # target for next layer

    if self.output_mode == 'prediction':
        output = frame_prediction
    else:
        for l in range(self.nb_layers):
            layer_error = torch.mean(torch.flatten(e[l], start_dim=1), dim=-1, keepdim=True)
            if l == 0:
                all_error = layer_error
            else:
                all_error = torch.cat((all_error, layer_error), dim=-1)

        # note: image_n is not defined anywhere in this snippet
        if self.output_mode == 'error' and image_n == 0:
            output = all_error
            output = output.unsqueeze(1)
        # elif self.output_mode == 'error':
        #     all_error = all_error.unsqueeze(1)
        #     output = torch.cat((output, all_error), dim=1)
        else:
            output = torch.cat((torch.flatten(frame_prediction, start_dim=1), all_error), dim=-1)

    states = r + c + e
    if self.extrap_start_time is not None:
        states += [frame_prediction, t + 1]

    return output, states

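I assume you need the retain_graph=True setting because you are not detaching the states tensor, so every new backward call also traverses the graphs of all previous iterations.
If backpropagating through all previous iterations is really your use case, you would have to lower the batch size so that all of these computation graphs fit on the device.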
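To see why keeping every graph alive eventually exhausts memory, here is a small sketch on a toy cell that counts the autograd nodes reachable from the loss at each iteration. `graph_size` is a hypothetical helper that walks `grad_fn.next_functions`, and the node count is only a rough proxy for the memory held by saved tensors. Without detaching, the graph grows by one time step per iteration; with `detach()`, it stays the same size.

```python
import torch

def graph_size(t):
    """Hypothetical helper: count autograd nodes reachable from t.grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # next_functions holds (node, input_index) pairs; node is None for
        # inputs that do not require grad
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)

torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)  # toy recurrent weight

def run(detach, steps=5):
    h = torch.zeros(1, 3)
    sizes = []
    for _ in range(steps):
        if detach:
            h = h.detach()  # cut the graph at the incoming state
        h = torch.tanh(h @ w)
        loss = h.pow(2).sum()
        sizes.append(graph_size(loss))
        # without detach, later iterations backprop into earlier graphs,
        # so those graphs must be retained
        loss.backward(retain_graph=not detach)
    return sizes

grow = run(detach=False)
flat = run(detach=True)
print(grow)  # increases every iteration
print(flat)  # constant
```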
If you don’t need to backpropagate through multiple steps, you might want to detach states via:

tag_scores, states = prednet(tensor, states.detach())
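A sketch of this per-step detach, with `nn.GRUCell` standing in for `prednet` (an assumption, chosen because its single hidden tensor matches the `states.detach()` call above). Because the incoming state is detached, each `backward()` only traverses the current step, so `retain_graph=True` is no longer needed. The trade-off is that no gradient flows into earlier time steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=4, hidden_size=8)  # hypothetical stand-in for prednet
optimizer = torch.optim.SGD(cell.parameters(), lr=0.01)
loss_function = nn.MSELoss()

states = torch.zeros(2, 8)  # batch of 2, a single hidden-state tensor
losses = []
for step in range(10):
    x = torch.randn(2, 4)
    optimizer.zero_grad()
    # detaching stops the backward pass at the state from the previous step
    states = cell(x, states.detach())
    loss = loss_function(states, torch.zeros_like(states))
    loss.backward()  # plain backward, no retain_graph
    optimizer.step()
    losses.append(loss.item())
```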

Thank you for your answer! I do need backpropagation through time, but reducing the batch size only delays the out-of-memory error. I also already tried:

for state in states:
    state.detach()

It does not change anything: I still get the out-of-memory error after 5-10 batches.

state.detach() is not an in-place method, so you would have to reassign the result:

state = state.detach()

If that doesn’t help, could you post an executable code snippet?
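The difference is easy to check on toy tensors. Note that reassigning inside a `for state in states:` loop only rebinds the loop variable, so the list itself still holds the attached tensors; to actually replace them, a new list has to be built.

```python
import torch

w = torch.randn(2, requires_grad=True)
states = [w * 2, w * 3]  # non-leaf tensors that are part of the graph

for state in states:
    state.detach()  # returns a new tensor; the result is discarded
after_discard = [s.requires_grad for s in states]
print(after_discard)  # [True, True]: list unchanged

for state in states:
    state = state.detach()  # rebinds the loop variable only
after_rebind = [s.requires_grad for s in states]
print(after_rebind)  # [True, True]: list still unchanged

states = [state.detach() for state in states]  # new list of detached tensors
after_rebuild = [s.requires_grad for s in states]
print(after_rebuild)  # [False, False]
```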

Hey! Thank you so much for your time. It is still not working. I am working on a 2-3 GB dataset and the network is “fairly” complex. I can send you the code, but making an executable snippet that reproduces the error would take me a long time, especially since I do not know how to make one. I will post the solution here if I ever find one.

But isn’t there a way to set retain_graph=False from time to time to save memory? I tried something like this, but every time I do it, at some step I get:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

The error is raised because the intermediate tensors were already freed by the backward() call that used retain_graph=False, while the next backward() call tries to backpropagate through operations whose intermediates have already been deleted. Detaching the tensor would solve the problem (the backward pass would stop at this point and not backpropagate further), but I understand that it might not be trivial if the code base is complicated.
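Rather than switching retain_graph on and off, a common pattern is truncated backpropagation through time: accumulate the loss over a window of k steps, call backward() once for the whole window (without retain_graph), then detach the states before the next window starts. A minimal sketch, assuming an `nn.LSTMCell` as the recurrent model and a hypothetical window size `k = 5`; graph memory then scales with k instead of with the number of batches seen.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=8)  # hypothetical recurrent model
optimizer = torch.optim.Adam(cell.parameters(), lr=1e-3)
k = 5           # hypothetical truncation window: backpropagate through k steps
num_steps = 20

h = torch.zeros(2, 8)
c = torch.zeros(2, 8)
loss = 0.0
updates = 0
for step in range(1, num_steps + 1):
    x = torch.randn(2, 4)
    h, c = cell(x, (h, c))
    loss = loss + h.pow(2).mean()  # accumulate the loss over the window
    if step % k == 0:
        optimizer.zero_grad()
        loss.backward()            # one backward per window, graph freed afterwards
        optimizer.step()
        updates += 1
        # cut the graph so the next window starts from detached states
        h, c = h.detach(), c.detach()
        loss = 0.0
```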

In case anyone comes across this thread: I ended up solving the issue by replacing the line

tag_scores, states = prednet(tensor, states)

with

tag_scores, states = prednet(tensor, [ state.detach() for state in states])

This worked well! It is similar to the solution from @ptrblck, except that states is a list of tensors.
Thank you for the help!
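For reference, a self-contained sketch of the final fix. `ToyListStateNet` is a hypothetical stand-in that only mimics prednet's interface (`forward(a, states)` returning `(output, states)`, with `states` as a list of tensors). It also shows why `states.detach()` on its own fails in this case: a Python list has no `detach` method.

```python
import torch
import torch.nn as nn

class ToyListStateNet(nn.Module):
    """Hypothetical stand-in for prednet: forward(a, states) -> (output, list of states)."""
    def __init__(self, nb_layers=2, size=4):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(2 * size, size) for _ in range(nb_layers)])

    def forward(self, a, states):
        new_states = []
        for layer, s in zip(self.layers, states):
            a = torch.tanh(layer(torch.cat([a, s], dim=-1)))
            new_states.append(a)
        return a, new_states

torch.manual_seed(0)
net = ToyListStateNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
states = [torch.zeros(2, 4) for _ in range(2)]

list_detach_failed = False
try:
    states.detach()  # a Python list has no .detach()
except AttributeError:
    list_detach_failed = True

for step in range(10):
    x = torch.randn(2, 4)
    optimizer.zero_grad()
    # detach every tensor in the list before feeding it back in
    out, states = net(x, [state.detach() for state in states])
    loss = out.pow(2).mean()
    loss.backward()  # no retain_graph needed
    optimizer.step()
```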